diff --git a/kit/Kit.cpp b/kit/Kit.cpp index 8ef9840bd..47cedaa8d 100644 --- a/kit/Kit.cpp +++ b/kit/Kit.cpp @@ -2009,7 +2009,7 @@ class KitWebSocketHandler final : public WebSocketHandler, public std::enable_sh public: KitWebSocketHandler(const std::string& socketName, const std::shared_ptr& loKit, const std::string& jailId, SocketPoll& socketPoll) : - WebSocketHandler(/* isClient = */ true), + WebSocketHandler(/* isClient = */ true, /* isMasking */ false), _queue(std::make_shared()), _socketName(socketName), _loKit(loKit), diff --git a/net/WebSocketHandler.hpp b/net/WebSocketHandler.hpp index 59814b6d5..5b9aa5460 100644 --- a/net/WebSocketHandler.hpp +++ b/net/WebSocketHandler.hpp @@ -36,6 +36,7 @@ protected: std::vector _wsPayload; std::atomic _shuttingDown; bool _isClient; + bool _isMasking; struct WSFrameMask { @@ -48,11 +49,12 @@ protected: public: /// Perform upgrade ourselves, or select a client web socket. - WebSocketHandler(bool isClient = false) : + WebSocketHandler(bool isClient = false, bool isMasking = true) : _lastPingSentTime(std::chrono::steady_clock::now()), _pingTimeUs(0), _shuttingDown(false), - _isClient(isClient) + _isClient(isClient), + _isMasking(isClient && isMasking) { } @@ -65,7 +67,8 @@ public: std::chrono::milliseconds(InitialPingDelayMs)), _pingTimeUs(0), _shuttingDown(false), - _isClient(false) + _isClient(false), + _isMasking(false) { upgradeToWebSocket(request); } @@ -381,14 +384,14 @@ public: return sendFrame(socket, data, len, WSFrameMask::Fin | static_cast(code), flush); } -protected: +private: /// Sends a WebSocket frame given the data, length, and flags. /// Returns the number of bytes written (including frame overhead) on success, /// 0 for closed/invalid socket, and -1 for other errors. - static int sendFrame(const std::shared_ptr& socket, - const char* data, const size_t len, - const unsigned char flags, const bool flush = true) + int sendFrame(const std::shared_ptr& socket, + const char* data, const size_t len, + unsigned char flags, const bool flush = true) const { if (!socket || data == nullptr || len == 0) return -1; @@ -402,19 +405,20 @@ protected: out.push_back(flags); + int maskFlag = _isMasking ? 0x80 : 0; if (len < 126) { - out.push_back((char)len); + out.push_back((char)(len | maskFlag)); } else if (len <= 0xffff) { - out.push_back((char)126); + out.push_back((char)(126 | maskFlag)); out.push_back(static_cast((len >> 8) & 0xff)); out.push_back(static_cast((len >> 0) & 0xff)); } else { - out.push_back((char)127); + out.push_back((char)(127 | maskFlag)); out.push_back(static_cast((len >> 56) & 0xff)); out.push_back(static_cast((len >> 48) & 0xff)); out.push_back(static_cast((len >> 40) & 0xff)); @@ -425,8 +429,27 @@ protected: out.push_back(static_cast((len >> 0) & 0xff)); } - // Copy the data. - out.insert(out.end(), data, data + len); + if (_isMasking) + { // flip some top bits - perhaps it helps. + size_t mask = out.size(); + + out.push_back(static_cast(0x81)); + out.push_back(static_cast(0x76)); + out.push_back(static_cast(0x81)); + out.push_back(static_cast(0x76)); + + // Copy the data. + out.insert(out.end(), data, data + len); + + // Mask it. + for (size_t i = 4; i < out.size() - mask; ++i) + out[mask + i] = out[mask + i] ^ out[mask + (i%4)]; + } + else + { + // Copy the data. + out.insert(out.end(), data, data + len); + } const size_t size = out.size() - oldSize; if (flush) @@ -435,6 +458,8 @@ protected: return size; } +protected: + /// To be overriden to handle the websocket messages the way you need. virtual void handleMessage(bool /*fin*/, WSOpCode /*code*/, std::vector &/*data*/) { diff --git a/test/data/hello-world.ods b/test/data/hello-world.ods index 3a44a7ba1..37ddf33f1 100644 Binary files a/test/data/hello-world.ods and b/test/data/hello-world.ods differ