#include "sockutil.h"
#include "util.h"
//
// NOTE(review): the targets of the angle-bracket includes were lost from this
// file. The list below is reconstructed from the symbols the code uses
// (THROW_ERROR / THROW_WINDOWS_ERROR, common::string::ToAnsi, std::stringstream,
// std::setw, std::make_pair, std::make_unique, assert) -- confirm against the
// project's real include paths.
//
#include <libcommon/error.h>
#include <libcommon/string.h>
#include <cassert>
#include <iomanip>
#include <memory>
#include <sstream>
#include <utility>

//
// NOTE(review): the element type of every std::vector in this file was lost
// together with the other template arguments. uint8_t is assumed here; it must
// match the declarations in sockutil.h.
//

//
// Formats a Winsock error code as "0x" followed by the value in hexadecimal,
// zero-padded to a width of at least eight characters.
//
std::string FormatWsaError(int errorCode)
{
	std::stringstream ss;

	ss << "0x" << std::setw(8) << std::setfill('0') << std::hex << errorCode;

	return ss.str();
}

//
// Creates a TCP or UDP socket (see CreateSocket) and binds it to ip:port.
//
// The socket is closed again if bind() fails, and an error carrying the
// formatted Winsock error code is thrown.
//
SOCKET CreateBindSocket(const IN_ADDR &ip, uint16_t port, bool tcp)
{
	auto s = CreateSocket(tcp);

	sockaddr_in endpoint = {};

	endpoint.sin_family = AF_INET;
	endpoint.sin_port = htons(port);
	endpoint.sin_addr = ip;

	const auto status = bind(s, reinterpret_cast<const sockaddr *>(&endpoint), sizeof(endpoint));

	if (SOCKET_ERROR == status)
	{
		// Read the error code first; closesocket() may replace it.
		const auto errorCode = WSAGetLastError();

		closesocket(s);

		const auto err = std::string("Failed to bind socket: ")
			.append(FormatWsaError(errorCode));

		THROW_ERROR(err.c_str());
	}

	return s;
}

//
// Overload taking the address as text; the string is converted with ParseIpv4().
//
SOCKET CreateBindSocket(const std::wstring &ip, uint16_t port, bool tcp)
{
	return CreateBindSocket(ParseIpv4(ip), port, tcp);
}

//
// Creates an AF_INET socket: SOCK_STREAM/IPPROTO_TCP when `tcp` is true,
// SOCK_DGRAM/IPPROTO_UDP otherwise. Throws if socket() fails.
//
SOCKET CreateSocket(bool tcp)
{
	const auto [type, protocol] = tcp
		? std::make_pair(SOCK_STREAM, IPPROTO_TCP)
		: std::make_pair(SOCK_DGRAM, IPPROTO_UDP);

	auto s = socket(AF_INET, type, protocol);

	if (INVALID_SOCKET == s)
	{
		// Report the Winsock error code, like the other failure paths in this file.
		const auto err = std::string("Failed to create socket: ")
			.append(FormatWsaError(WSAGetLastError()));

		THROW_ERROR(err.c_str());
	}

	return s;
}

//
// Shuts down both directions of `s`, closes it and resets the caller's
// variable to INVALID_SOCKET. Does nothing if `s` is already INVALID_SOCKET.
//
void ShutdownSocket(SOCKET &s)
{
	if (INVALID_SOCKET == s)
	{
		return;
	}

	shutdown(s, SD_BOTH);
	closesocket(s);

	s = INVALID_SOCKET;
}

//
// Connects `s` to ip:port. Throws a Windows error carrying the Winsock error
// code if connect() fails.
//
void ConnectSocket(SOCKET s, const IN_ADDR &ip, uint16_t port)
{
	sockaddr_in peer = {};

	peer.sin_family = AF_INET;
	peer.sin_port = htons(port);
	peer.sin_addr = ip;

	if (SOCKET_ERROR == connect(s, reinterpret_cast<const sockaddr *>(&peer), sizeof(peer)))
	{
		const auto lastError = WSAGetLastError();

		const auto err = std::string("Failed to connect socket: ")
			.append(FormatWsaError(lastError));

		THROW_WINDOWS_ERROR(lastError, err.c_str());
	}
}

//
// Overload taking the address as text; the string is converted with ParseIpv4().
//
void ConnectSocket(SOCKET s, const std::wstring &ip, uint16_t port)
{
	ConnectSocket(s, ParseIpv4(ip), port);
}

//
// Sends `sendBuffer` with a single send() call and then performs a single
// recv() of up to 1024 bytes.
//
// Returns the received bytes; the vector is sized to the byte count reported
// by recv(). Throws if either call fails or if send() reports fewer bytes
// than requested.
//
std::vector<uint8_t> SendRecvSocket(SOCKET s, const std::vector<uint8_t> &sendBuffer)
{
	// data() is used instead of &sendBuffer[0] so that an empty buffer is well-defined.
	auto status = send
	(
		s,
		reinterpret_cast<const char *>(sendBuffer.data()),
		static_cast<int>(sendBuffer.size()),
		0
	);

	if (SOCKET_ERROR == status)
	{
		const auto err = std::string("Failed to send on socket: ")
			.append(FormatWsaError(WSAGetLastError()));

		THROW_ERROR(err.c_str());
	}

	if (static_cast<size_t>(status) != sendBuffer.size())
	{
		std::stringstream ss;

		ss << "Failed to send() on socket. Sent " << status << " of " << sendBuffer.size() << " bytes";

		THROW_ERROR(ss.str().c_str());
	}

	std::vector<uint8_t> receiveBuffer(1024, 0);

	status = recv
	(
		s,
		reinterpret_cast<char *>(receiveBuffer.data()),
		static_cast<int>(receiveBuffer.size()),
		0
	);

	if (SOCKET_ERROR == status)
	{
		const auto err = std::string("Failed to receive on socket: ")
			.append(FormatWsaError(WSAGetLastError()));

		THROW_ERROR(err.c_str());
	}

	// SOCKET_ERROR was handled above, so `status` is the received byte count here.
	receiveBuffer.resize(static_cast<size_t>(status));

	return receiveBuffer;
}

//
// Sends `sendBuffer` via SendRecvSocket() and throws unless the response has
// the same length and the same contents.
//
void SendRecvValidateEcho(SOCKET s, const std::vector<uint8_t> &sendBuffer)
{
	const auto receiveBuffer = SendRecvSocket(s, sendBuffer);

	// Vector equality compares size and contents, and is safe for empty buffers.
	if (receiveBuffer != sendBuffer)
	{
		THROW_ERROR("Invalid echo response");
	}
}

//
// Returns the local address of `s` as reported by getsockname().
// Throws if the call fails or returns an address of unexpected size.
//
sockaddr_in QueryBind(SOCKET s)
{
	sockaddr_in boundAddress = {};
	int addressSize = sizeof(boundAddress);

	if (SOCKET_ERROR == getsockname(s, reinterpret_cast<sockaddr *>(&boundAddress), &addressSize))
	{
		const auto err = std::string("Failed to query bind: ")
			.append(FormatWsaError(WSAGetLastError()));

		THROW_ERROR(err.c_str());
	}

	if (addressSize != static_cast<int>(sizeof(boundAddress)))
	{
		THROW_ERROR("Invalid data returned for bind query");
	}

	return boundAddress;
}

//
// Throws unless the local address of `s` equals `ip`. Only the address is
// compared, not the port.
//
void ValidateBind(SOCKET s, const IN_ADDR &ip)
{
	const auto actualBind = QueryBind(s);

	if (actualBind.sin_addr.s_addr != ip.s_addr)
	{
		std::wstringstream ss;

		ss << L"Unexpected socket bind. Expected address " << IpToString(ip)
			<< L", Actual address " << IpToString(actualBind.sin_addr);

		THROW_ERROR(common::string::ToAnsi(ss.str()).c_str());
	}
}

//
// Sets SO_RCVTIMEO on `s`. The value of timeout.count() is converted to a
// DWORD and passed to setsockopt(). Throws if setsockopt() fails.
//
void SetSocketRecvTimeout(SOCKET s, std::chrono::milliseconds timeout)
{
	DWORD rawTimeout = static_cast<DWORD>(timeout.count());

	const auto status = setsockopt
	(
		s,
		SOL_SOCKET,
		SO_RCVTIMEO,
		reinterpret_cast<const char *>(&rawTimeout),
		sizeof(rawTimeout)
	);

	if (SOCKET_ERROR == status)
	{
		const auto errorCode = WSAGetLastError();

		const auto err = std::string("Failed to set socket recv timeout: ")
			.append(FormatWsaError(errorCode));

		THROW_ERROR(err.c_str());
	}
}

SOCKET CreateBindOverlappedSocket(const IN_ADDR &ip, uint16_t port, bool tcp)
{
	//
	// Turns out all sockets on Windows support overlapped operations.
	//
	return CreateBindSocket(ip, port, tcp);
}

//
// Overload taking the address as text; the string is converted with ParseIpv4().
//
SOCKET CreateBindOverlappedSocket(const std::wstring &ip, uint16_t port, bool tcp)
{
	return CreateBindOverlappedSocket(ParseIpv4(ip), port, tcp);
}

//
// Allocates a WinsockOverlapped context with a zeroed OVERLAPPED and WSABUF,
// a newly created event and no pending operation.
//
// The caller releases the context with DeleteWinsockOverlapped().
// Throws if the event cannot be created; the context is freed in that case.
//
WinsockOverlapped *AllocateWinsockOverlapped()
{
	// Owned by a unique_ptr until returned, so a throw below does not leak it.
	auto ctx = std::make_unique<WinsockOverlapped>();

	ZeroMemory(&ctx->overlapped, sizeof(ctx->overlapped));
	ZeroMemory(&ctx->winsockBuffer, sizeof(ctx->winsockBuffer));

	ctx->pendingOperation = false;

	ctx->overlapped.hEvent = WSACreateEvent();

	if (WSA_INVALID_EVENT == ctx->overlapped.hEvent)
	{
		THROW_WINDOWS_ERROR(WSAGetLastError(), "WSACreateEvent");
	}

	return ctx.release();
}

//
// Destroys a context created by AllocateWinsockOverlapped() and sets the
// caller's pointer to nullptr.
//
// If an operation is still marked as pending, this blocks until the context's
// event is signalled. Passing nullptr, or a pointer to nullptr, is a no-op.
//
void DeleteWinsockOverlapped(WinsockOverlapped **ctx)
{
	if (nullptr == ctx || nullptr == *ctx)
	{
		return;
	}

	auto *instance = *ctx;

	if (instance->pendingOperation)
	{
		WaitForSingleObject(instance->overlapped.hEvent, INFINITE);
	}

	WSACloseEvent(instance->overlapped.hEvent);

	delete instance;

	*ctx = nullptr;
}

//
// Swaps `buffer` into the context and points the WSABUF at the new storage.
// Must not be called while an operation is pending.
//
void AssignOverlappedBuffer(WinsockOverlapped &ctx, std::vector<uint8_t> &&buffer)
{
	assert(!ctx.pendingOperation);

	ctx.buffer.swap(buffer);

	// data() is used instead of &ctx.buffer[0] so that an empty buffer is well-defined.
	ctx.winsockBuffer.buf = reinterpret_cast<char *>(ctx.buffer.data());
	ctx.winsockBuffer.len = static_cast<ULONG>(ctx.buffer.size());
}

//
// Starts an overlapped WSASend() of the context's WSABUF.
//
// The operation is marked as pending both when WSASend() returns 0 and when
// it reports WSA_IO_PENDING; the result is collected with PollOverlappedSend().
// Any other outcome throws.
//
void SendOverlappedSocket(SOCKET s, WinsockOverlapped &ctx)
{
	assert(!ctx.pendingOperation);

	WSAResetEvent(ctx.overlapped.hEvent);

	const auto status = WSASend(s, &ctx.winsockBuffer, 1, nullptr, 0, &ctx.overlapped, nullptr);

	if (0 == status)
	{
		ctx.pendingOperation = true;
		return;
	}

	// Read the error code once so the value tested is the value reported.
	const auto lastError = WSAGetLastError();

	if (SOCKET_ERROR == status && WSA_IO_PENDING == lastError)
	{
		ctx.pendingOperation = true;
		return;
	}

	THROW_WINDOWS_ERROR(lastError, "WSASend");
}

//
// Starts an overlapped WSARecv() for up to `bytes` bytes.
//
// The context's buffer is grown if it is smaller than `bytes` (it is never
// shrunk), and the WSABUF is updated to describe `bytes` bytes of it.
// The operation is marked as pending both when WSARecv() returns 0 and when
// it reports WSA_IO_PENDING; the result is collected with PollOverlappedRecv().
// Any other outcome throws.
//
void RecvOverlappedSocket(SOCKET s, WinsockOverlapped &ctx, size_t bytes)
{
	assert(!ctx.pendingOperation);

	WSAResetEvent(ctx.overlapped.hEvent);

	if (ctx.buffer.size() < bytes)
	{
		ctx.buffer.resize(bytes);
	}

	// Always refresh the descriptor; resize() may have moved the storage.
	ctx.winsockBuffer.buf = reinterpret_cast<char *>(ctx.buffer.data());
	ctx.winsockBuffer.len = static_cast<ULONG>(bytes);

	DWORD flags = 0;

	const auto status = WSARecv(s, &ctx.winsockBuffer, 1, nullptr, &flags, &ctx.overlapped, nullptr);

	if (0 == status)
	{
		ctx.pendingOperation = true;
		return;
	}

	// Read the error code once so the value tested is the value reported.
	const auto lastError = WSAGetLastError();

	if (SOCKET_ERROR == status && WSA_IO_PENDING == lastError)
	{
		ctx.pendingOperation = true;
		return;
	}

	THROW_WINDOWS_ERROR(lastError, "WSARecv");
}

namespace
{

//
// Shared non-blocking poll for PollOverlappedSend() and PollOverlappedRecv().
//
// Returns false if WSAGetOverlappedResult() reports WSA_IO_INCOMPLETE.
// Otherwise clears the pending flag, stores the transferred byte count in
// `bytesTransferred` and returns true. Any other failure throws a Windows
// error labelled with `operation`; the pending flag is left set in that case.
//
bool PollOverlappedResult(SOCKET s, WinsockOverlapped &ctx, const char *operation, DWORD &bytesTransferred)
{
	assert(ctx.pendingOperation);

	bytesTransferred = 0;
	DWORD flags = 0;

	const auto status = WSAGetOverlappedResult(s, &ctx.overlapped, &bytesTransferred, FALSE, &flags);

	if (FALSE == status)
	{
		const auto lastError = WSAGetLastError();

		if (WSA_IO_INCOMPLETE == lastError)
		{
			return false;
		}

		THROW_WINDOWS_ERROR(lastError, operation);
	}

	ctx.pendingOperation = false;

	return true;
}

}

//
// Polls a send started with SendOverlappedSocket() without blocking.
//
// Returns false while the send is still in progress and true once it has
// completed. Throws if the send failed, or if it completed having transferred
// a different number of bytes than the WSABUF length.
//
bool PollOverlappedSend(SOCKET s, WinsockOverlapped &ctx)
{
	DWORD bytesTransferred = 0;

	if (!PollOverlappedResult(s, ctx, "Overlapped send", bytesTransferred))
	{
		return false;
	}

	if (bytesTransferred != ctx.winsockBuffer.len)
	{
		THROW_ERROR("Overlapped send completed but did not transfer all bytes");
	}

	return true;
}

//
// Polls a receive started with RecvOverlappedSocket() without blocking.
//
// Returns false while the receive is still in progress. On completion, stores
// the number of bytes received in ctx.winsockBuffer.len and returns true.
// Throws if the receive failed.
//
bool PollOverlappedRecv(SOCKET s, WinsockOverlapped &ctx)
{
	DWORD bytesTransferred = 0;

	if (!PollOverlappedResult(s, ctx, "Overlapped receive", bytesTransferred))
	{
		return false;
	}

	ctx.winsockBuffer.len = bytesTransferred;

	return true;
}