// Files
// OpenVPNAdapter/Sources/OpenVPN3/openvpn/ws/websocket.hpp
//
// 565 lines
// 12 KiB
// C++

// OpenVPN -- An application to securely tunnel IP networks
// over a single port, with support for SSL/TLS-based
// session authentication and key exchange,
// packet encryption, packet authentication, and
// packet compression.
//
// Copyright (C) 2012-2020 OpenVPN Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License Version 3
// as published by the Free Software Foundation.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program in the COPYING file.
// If not, see <http://www.gnu.org/licenses/>.
#pragma once
#include <string>
#include <cstdint>
#include <ostream>
#include <tuple>
#include <utility>
#include <openvpn/common/exception.hpp>
#include <openvpn/common/rc.hpp>
#include <openvpn/common/base64.hpp>
#include <openvpn/common/socktypes.hpp>
#include <openvpn/common/endian64.hpp>
#include <openvpn/crypto/hashstr.hpp>
#include <openvpn/buffer/buffer.hpp>
#include <openvpn/random/randapi.hpp>
namespace openvpn {
namespace WebSocket {
// Exception type, declared through the OPENVPN_EXCEPTION macro, that the
// code in this namespace throws (constructed from a message string) for
// framing and configuration errors.
OPENVPN_EXCEPTION(websocket_error);
// Forward declaration: Status names Receiver as a friend further down.
class Receiver;
/**
 * Derive the handshake confirmation string from a Sec-WebSocket-Key value.
 *
 * The key text is concatenated with a fixed GUID, hashed with SHA-1 and the
 * digest is returned base64-encoded.  The server side emits this value as
 * Sec-WebSocket-Accept; the client side recomputes it to check the reply.
 *
 * @param digest_factory  factory passed to HashString to obtain the SHA-1 digest
 * @param websocket_key   the Sec-WebSocket-Key header value
 * @return base64 encoding of SHA1(websocket_key + GUID)
 */
inline std::string accept_confirmation(DigestFactory& digest_factory,
                                       const std::string& websocket_key)
{
  // Fixed GUID that is appended to the key before hashing.
  static const char handshake_guid[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

  HashString hasher(digest_factory, CryptoAlgs::SHA1);
  hasher.update(websocket_key + handshake_guid);
  return hasher.final_base64();
}
/**
 * Constants and small helpers for WebSocket framing that the Sender and
 * Receiver classes in this file build on.
 */
class Protocol
{
public:
  // Presumably the headroom callers should reserve in front of a payload for
  // the frame header; the largest header Sender builds is 14 bytes
  // (1 head + 1 length code + 8 extended length + 4 mask).
  static constexpr size_t MAX_HEAD = 16;

  // Frame opcodes (low 4 bits of the first header byte).
  enum Opcode
  {
    Text = 0x1,
    Binary = 0x2,
    Close = 0x8,
    Ping = 0x9,
    Pong = 0xA,
  };

  /**
   * Human-readable name for an opcode.
   *
   * @param opcode numeric opcode
   * @return the enumerator name for known opcodes, otherwise
   *         "WS-OPCODE-<n>" with the decimal value appended
   */
  static std::string opcode_to_string(const unsigned int opcode)
  {
    const char* known = nullptr;
    switch (opcode)
      {
      case Text:
        known = "Text";
        break;
      case Binary:
        known = "Binary";
        break;
      case Close:
        known = "Close";
        break;
      case Ping:
        known = "Ping";
        break;
      case Pong:
        known = "Pong";
        break;
      default:
        break;
      }
    if (known != nullptr)
      return known;
    return "WS-OPCODE-" + std::to_string(opcode);
  }

  /**
   * 4-byte masking key.
   *
   * The union overlays the 32-bit value with a 4-byte array so the key can
   * be applied byte by byte in the same order in which it is stored in
   * memory (and therefore in which prepend_mask() puts it on the wire).
   */
  union MaskingKey
  {
  public:
    MaskingKey(std::uint32_t mask)
      : mask32(mask)
    {
    }

    // XOR every byte of buf in place with the key byte at (index mod 4).
    void xor_buf(Buffer& buf) const
    {
      const size_t count = buf.size();
      std::uint8_t* bytes = buf.data();
      for (size_t pos = 0; pos < count; ++pos)
        bytes[pos] ^= mask8[pos % 4];
    }

    // Prepend the raw 4 key bytes to buf.
    void prepend_mask(Buffer& buf) const
    {
      buf.prepend(&mask32, sizeof(mask32));
    }

  private:
    std::uint32_t mask32;
    std::uint8_t mask8[4];
  };
};
/**
 * Per-frame status: opcode, FIN flag and (for Close frames) the close
 * status code.
 *
 * A default-constructed Status has opcode 0, which defined() reports as
 * "not set".
 */
class Status
{
public:
  // Construct an undefined status (opcode 0, FIN clear, status code 0).
  Status()
    : opcode_(0),
      fin_(false),
      close_status_code_(0)
  {
  }

  /**
   * @param opcode             frame opcode (see Protocol::Opcode)
   * @param fin                FIN flag, defaults to true
   * @param close_status_code  status code, only rendered/sent for Close frames
   */
  Status(unsigned int opcode,
         bool fin=true,
         unsigned int close_status_code=0)
    : opcode_(opcode),
      fin_(fin),
      close_status_code_(close_status_code)
  {
  }

  // Copy FIN and close status code from ref, but substitute a new opcode.
  Status(const Status& ref,
         const unsigned int opcode)
    : opcode_(opcode),
      fin_(ref.fin_),
      close_status_code_(ref.close_status_code_)
  {
  }

  // True once an opcode other than 0 has been set.
  bool defined() const
  {
    return opcode_ != 0;
  }

  unsigned int opcode() const
  {
    return opcode_;
  }

  bool fin() const
  {
    return fin_;
  }

  unsigned int close_status_code() const
  {
    return close_status_code_;
  }

  // Member-wise equality over opcode, FIN flag and close status code.
  bool operator==(const Status& rhs) const
  {
    // The right-hand tuple must use rhs.fin_; comparing fin_ with itself
    // would make two statuses that differ only in FIN compare equal.
    return std::tie(opcode_, fin_, close_status_code_)
        == std::tie(rhs.opcode_, rhs.fin_, rhs.close_status_code_);
  }

  bool operator!=(const Status& rhs) const
  {
    return !(*this == rhs);
  }

  /**
   * Render as "[op=<name> fin=<0|1>]", with " status=<code>" inserted
   * before the closing bracket for Close frames.
   */
  std::string to_string() const
  {
    std::string ret;
    ret.reserve(64);
    ret += "[op=";
    ret += Protocol::opcode_to_string(opcode_);
    ret += " fin=";
    ret += std::to_string(fin_);
    if (opcode_ == Protocol::Close)
      {
        ret += " status=";
        ret += std::to_string(close_status_code_);
      }
    ret += ']';
    return ret;
  }

private:
  // Receiver fills in the fields directly while parsing a frame header.
  friend class Receiver;

  unsigned int opcode_;
  bool fin_;
  unsigned int close_status_code_;
};
class Sender
{
public:
Sender(RandomAPI::Ptr cli_rng_arg) // only provide rng on client side
: cli_rng(std::move(cli_rng_arg))
{
if (cli_rng)
cli_rng->assert_crypto();
}
void frame(Buffer& buf, const Status& s) const
{
if (s.opcode() == Protocol::Close)
{
const std::uint16_t cs = htons(s.close_status_code());
buf.prepend(&cs, sizeof(cs));
}
const size_t payload_len = buf.size();
if (cli_rng)
{
const Protocol::MaskingKey mk(cli_rng->rand_get<std::uint32_t>());
mk.xor_buf(buf);
mk.prepend_mask(buf);
}
prepend_payload_length(buf, payload_len);
std::uint8_t head = s.opcode() & 0xF;
if (s.fin())
head |= 0x80;
buf.prepend(&head, sizeof(head));
//OPENVPN_LOG("WS SEND HEAD\n" << dump_hex(buf));
}
private:
void prepend_payload_length(Buffer& buf, const size_t len) const
{
std::uint8_t len8;
if (len <= 125)
len8 = len;
else if (len <= 65535)
{
len8 = 126;
const std::uint16_t len16 = htons(len);
buf.prepend(&len16, sizeof(len16));
}
else
{
len8 = 127;
const std::uint64_t len64 = Endian::rev64(len);
buf.prepend(&len64, sizeof(len64));
}
if (cli_rng)
len8 |= 0x80;
buf.prepend(&len8, sizeof(len8));
}
RandomAPI::Ptr cli_rng;
};
class Receiver
{
public:
Receiver(const bool is_client_arg)
: is_client(is_client_arg)
{
reset_pod();
}
Buffer buf_unframed()
{
verify_message_complete();
if (size > buf.size())
throw websocket_error("Receiver::buf_unframed: internal error");
return Buffer(buf.data(), size, true);
}
// return true if message is complete
bool complete()
{
// already complete?
if (header_complete)
return complete_();
// we need at least 2 bytes before we can do anything
if (buf.size() < 2)
return false;
// get first 2 bytes of header
Buffer b(buf.data(), buf.size(), true);
const std::uint8_t* head = b.read_alloc(2);
s.opcode_ = head[0] & 0xF;
s.fin_ = bool(head[0] & 0x80);
if (head[0] & 0x70)
throw websocket_error("Receiver: reserved bits are set");
if (bool(head[1] & 0x80) == is_client)
throw websocket_error("Receiver: bad masking direction");
// process payload length
const std::uint8_t pl = head[1] & 0x7f;
if (pl <= 125)
{
size = pl;
}
else if (pl == 126)
{
std::uint16_t len16;
if (b.size() < sizeof(len16))
return false;
b.read(&len16, sizeof(len16));
size = ntohs(len16);
}
else // pl == 127
{
std::uint64_t len64;
if (b.size() < sizeof(len64))
return false;
b.read(&len64, sizeof(len64));
size = Endian::rev64(len64);
}
// read mask (server side only)
if (!is_client)
{
if (b.size() < sizeof(mask))
return false;
b.read(&mask, sizeof(mask));
}
buf.advance(b.offset());
header_complete = true;
return complete_();
}
void add_buf(BufferAllocated&& inbuf)
{
if (!buf.allocated())
{
buf = std::move(inbuf);
buf.or_flags(BufferAllocated::GROW);
}
else
buf.append(inbuf);
}
void reset()
{
verify_message_complete();
s = Status();
reset_buf();
reset_pod();
}
Status status() const
{
verify_message_complete();
return s;
}
private:
void reset_buf()
{
if (buf.allocated())
{
if (size < buf.size())
{
buf.advance(size);
buf.realign(0);
}
else if (size == buf.size())
buf.clear();
else
throw websocket_error("Receiver::reset_buf: bad size");
}
}
void reset_pod()
{
header_complete = false;
message_complete = false;
mask = 0;
size = 0;
}
void verify_message_complete() const
{
if (!message_complete)
throw websocket_error("Receiver: message incomplete");
}
bool complete_()
{
if (message_complete)
return true;
if (header_complete && size <= buf.size())
{
// un-xor the data on the server side only
if (!is_client)
{
Buffer b(buf.data(), size, true);
const Protocol::MaskingKey mk(mask);
mk.xor_buf(b);
}
// get close status code
if (s.opcode_ == Protocol::Close && size >= 2)
{
std::uint16_t cs;
buf.read(&cs, sizeof(cs));
size -= sizeof(cs);
s.close_status_code_ = ntohs(cs);
}
message_complete = true;
return true;
}
return false;
}
const bool is_client;
bool header_complete;
bool message_complete;
std::uint32_t mask;
std::uint64_t size;
Status s;
BufferAllocated buf;
};
namespace Client {
/**
 * Client-side WebSocket settings, shared by reference count.
 */
struct Config : public RC<thread_unsafe_refcount>
{
  using Ptr = RCPtr<Config>;

  // Sent as the "Origin" request header when non-empty.
  std::string origin;

  // Sent as the "Sec-WebSocket-Protocol" request header when non-empty.
  std::string protocol;

  // Random source for the handshake key and the frame masking keys.
  RandomAPI::Ptr rng;

  // Digest provider used to verify the server's accept value.
  DigestFactory::Ptr digest_factory;

  // compression
  // NOTE(review): neither field is read anywhere in this file; presumably
  // consumed by the code that owns the connection.
  bool compress = false;
  size_t compress_threshold = 256;
};
/**
 * Client-side state for one WebSocket request: the handshake key plus a
 * masking Sender and a client-mode Receiver.
 */
class PerRequest : public RC<thread_unsafe_refcount>
{
private:
  // Declared ahead of sender so it is initialized first; the sender's
  // initializer reads conf->rng.
  Config::Ptr conf;

public:
  typedef RCPtr<PerRequest> Ptr;

  /**
   * @param conf_arg  client configuration; must be non-null and carry both
   *                  an rng and a digest factory
   * @throws websocket_error if the configuration is missing or incomplete
   */
  PerRequest(Config::Ptr conf_arg)
    : conf(validate_conf(std::move(conf_arg))),
      sender(conf->rng),
      receiver(true)
  {
  }

  /**
   * Write the WebSocket upgrade request headers to os.
   *
   * A fresh random Sec-WebSocket-Key is generated on every call and kept
   * for the later confirm_websocket_key() check.
   */
  void client_headers(std::ostream& os)
  {
    generate_websocket_key();
    os << "Sec-WebSocket-Key: " << websocket_key << "\r\n";
    os << "Sec-WebSocket-Version: 13\r\n";
    if (!conf->protocol.empty())
      os << "Sec-WebSocket-Protocol: " << conf->protocol << "\r\n";
    os << "Connection: Upgrade\r\n";
    os << "Upgrade: websocket\r\n";
    if (!conf->origin.empty())
      os << "Origin: " << conf->origin << "\r\n";
  }

  /**
   * @param ws_accept  the accept value received from the server
   * @return true if it equals the confirmation computed from the key
   *         generated by the most recent client_headers() call
   */
  bool confirm_websocket_key(const std::string& ws_accept) const
  {
    return ws_accept == accept_confirmation(*conf->digest_factory, websocket_key);
  }

  Sender sender;
  Receiver receiver;

private:
  // Reject unusable configurations up front and hand the pointer back so
  // this can be used inside the constructor's initializer list.
  static Config::Ptr validate_conf(Config::Ptr conf)
  {
    if (!conf)
      throw websocket_error("no config");
    // The rng is dereferenced right below and again for every key and mask,
    // so a missing one must be reported rather than dereferenced.
    if (!conf->rng)
      throw websocket_error("no rng in config");
    conf->rng->assert_crypto();
    if (!conf->digest_factory)
      throw websocket_error("no digest factory in config");
    return conf;
  }

  // Fill websocket_key with the base64 encoding of 16 random bytes.
  void generate_websocket_key()
  {
    std::uint8_t data[16];
    conf->rng->rand_bytes(data, sizeof(data));
    websocket_key = base64->encode(data, sizeof(data));
  }

  std::string websocket_key;
};
}
namespace Server {
/**
 * Server-side WebSocket settings, shared by reference count.
 */
struct Config : public RC<thread_unsafe_refcount>
{
  using Ptr = RCPtr<Config>;

  // Sent as the "Sec-WebSocket-Protocol" response header when non-empty.
  std::string protocol;

  // Digest provider used to compute the Sec-WebSocket-Accept value.
  DigestFactory::Ptr digest_factory;
};
/**
 * Server-side state for one WebSocket request: the computed accept value
 * plus a non-masking Sender and a server-mode Receiver.
 */
class PerRequest : public RC<thread_unsafe_refcount>
{
private:
  Config::Ptr conf;

public:
  typedef RCPtr<PerRequest> Ptr;

  /**
   * @param conf_arg  server configuration; must be non-null and carry a
   *                  digest factory
   * @throws websocket_error if the configuration is missing or incomplete
   */
  PerRequest(Config::Ptr conf_arg)
    : conf(validate_conf(std::move(conf_arg))),
      sender(RandomAPI::Ptr()), // no rng: server frames are sent unmasked
      receiver(false)
  {
  }

  // Compute and remember the accept value for the client's handshake key.
  void set_websocket_key(const std::string& websocket_key)
  {
    websocket_accept = accept_confirmation(*conf->digest_factory, websocket_key);
  }

  /**
   * Write the upgrade response headers to os.  The accept and protocol
   * headers are emitted only when their values are non-empty.
   */
  void server_headers(std::ostream& os)
  {
    os << "Upgrade: websocket\r\n"
       << "Connection: Upgrade\r\n";

    if (!websocket_accept.empty())
      os << "Sec-WebSocket-Accept: " << websocket_accept << "\r\n";

    if (!conf->protocol.empty())
      os << "Sec-WebSocket-Protocol: " << conf->protocol << "\r\n";
  }

  Sender sender;
  Receiver receiver;

private:
  // Reject unusable configurations up front and hand the pointer back so
  // this can be used inside the constructor's initializer list.
  static Config::Ptr validate_conf(Config::Ptr conf)
  {
    if (!conf)
      {
        throw websocket_error("no config");
      }
    if (!conf->digest_factory)
      {
        throw websocket_error("no digest factory in config");
      }
    return conf;
  }

  std::string websocket_accept;
};
}
}
}