/* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; fill-column: 100 -*- */
/*
 * This file is part of the LibreOffice project.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 */

/*
 * Small stand-alone server exercising the socket code: it answers a toy HTTP
 * request, performs the WebSocket upgrade handshake, and then echoes /
 * increments WebSocket payloads. Listens on 127.0.0.1, plain or SSL.
 */

#include "config.h"

#include <atomic>
#include <cassert>
#include <cerrno>
#include <cstdint>
#include <cstring>
#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#include <Poco/MemoryStream.h>
#include <Poco/Net/HTTPRequest.h>
#include <Poco/Net/SocketAddress.h>
#include <Poco/StringTokenizer.h>

using Poco::MemoryInputStream;
using Poco::StringTokenizer;

#include "Socket.hpp"
#include "ServerSocket.hpp"
#include "SslSocket.hpp"

constexpr int HttpPortNumber = 9191;
constexpr int SslPortNumber = 9193;

static std::string computeAccept(const std::string &key);

/// Socket handler that speaks HTTP until the client requests a WebSocket
/// upgrade, then decodes incoming WebSocket frames and passes each one to
/// handleMessage(), which subclasses implement.
class WebSocketHandler : public SocketHandlerInterface
{
    /// The socket this handler reads from and writes to. Non-owning: the
    /// socket factories in main() create sockets with std::make_shared, so
    /// deleting one from here (as a unique_ptr would) is a double free.
    /// NOTE(review): assumes the StreamSocket owns its handler and outlives
    /// it - confirm in Socket.hpp.
    StreamSocket* _socket;
    int _wsVersion;
    std::string _wsKey;
    std::string _wsProtocol;
    /// Payload of the frame currently being dispatched.
    std::vector<char> _wsPayload;
    enum { HTTP, WEBSOCKET } _wsState;

public:
    WebSocketHandler() :
        _socket(nullptr),
        _wsVersion(0),
        _wsState(HTTP)
    {
    }

    /// Implementation of the SocketHandlerInterface.
    virtual void setSocket(StreamSocket* socket) override
    {
        _socket = socket;
    }

    /// Handles the initial HTTP request buffered on the connection.
    ///
    /// The request URI is split on '/' and '?':
    /// - four tokens: reply with a plain-text HTTP 200 whose body is the last
    ///   token, parsed as an integer, plus one;
    /// - two tokens with the second equal to "ws": answer the WebSocket
    ///   handshake (101 Switching Protocols) and switch to WEBSOCKET state;
    /// - anything else is logged and ignored.
    /// If the request cannot be parsed yet, the buffered bytes are kept so
    /// that parsing is retried when more data arrives.
    void handleWebsocketUpgrade()
    {
        auto& in = _socket->_inBuffer;
        if (in.empty())
            return; // Nothing to parse yet (and &in[0] would be invalid).

        MemoryInputStream message(&in[0], in.size());
        Poco::Net::HTTPRequest req;
        try
        {
            req.read(message);
        }
        catch (const std::exception& exc)
        {
            // Most likely the request has not fully arrived: keep what we
            // have and wait for more data.
            // FIXME: a request that never parses makes the buffer grow unbounded.
            std::cerr << "failed to parse request, waiting for more data: " << exc.what() << "\n";
            return;
        }

        // if we succeeded - remove that from our input buffer
        // FIXME: We should check if this is GET or POST. For GET, we only
        // can have a single request (headers only). For POST, we can/should
        // use Poco HTMLForm to parse the post message properly.
        in.clear();

        StringTokenizer tokens(req.getURI(), "/?");
        if (tokens.count() == 4)
        {
            int number = 0;
            try
            {
                number = std::stoi(tokens[3]);
            }
            catch (const std::exception&)
            {
                std::cerr << " invalid number '" << tokens[3] << "'\n";
                return;
            }

            // complex algorithmic core (widened so INT_MAX + 1 cannot overflow):
            const std::string numberString = std::to_string(static_cast<long long>(number) + 1);

            std::ostringstream oss;
            oss << "HTTP/1.1 200 OK\r\n"
                << "Date: Once, Upon a time GMT\r\n" // Mon, 27 Jul 2009 12:28:53 GMT
                << "Server: madeup string (Linux)\r\n"
                << "Content-Length: " << numberString.size() << "\r\n"
                << "Content-Type: text/plain\r\n"
                // "close" is the HTTP connection option that ends the connection.
                << "Connection: close\r\n"
                << "\r\n"
                << numberString;
            const std::string str = oss.str();
            _socket->_outBuffer.insert(_socket->_outBuffer.end(), str.begin(), str.end());
        }
        else if (tokens.count() == 2 && tokens[1] == "ws")
        {
            // create our websocket goodness ...
            try
            {
                _wsVersion = std::stoi(req.get("Sec-WebSocket-Version", "13"));
            }
            catch (const std::exception&)
            {
                std::cerr << " invalid Sec-WebSocket-Version\n";
                return;
            }
            _wsKey = req.get("Sec-WebSocket-Key", "");
            _wsProtocol = req.get("Sec-WebSocket-Protocol", "chat");
            std::cerr << "version " << _wsVersion << " key '" << _wsKey << "\n";
            // FIXME: other sanity checks ...

            std::ostringstream oss;
            oss << "HTTP/1.1 101 Switching Protocols\r\n"
                << "Upgrade: websocket\r\n"
                << "Connection: Upgrade\r\n"
                << "Sec-Websocket-Accept: " << computeAccept(_wsKey) << "\r\n"
                << "\r\n";
            const std::string str = oss.str();
            _socket->_outBuffer.insert(_socket->_outBuffer.end(), str.begin(), str.end());
            _wsState = WEBSOCKET;
        }
        else
            std::cerr << " unknown tokens " << tokens.count() << std::endl;
    }

    /// WebSocket frame opcodes (low nibble of the first header byte).
    enum WSOpCode {
        Continuation, // 0x0
        Text,         // 0x1
        Binary,       // 0x2
        Reserved1,    // 0x3
        Reserved2,    // 0x4
        Reserved3,    // 0x5
        Reserved4,    // 0x6
        Reserved5,    // 0x7
        Close,        // 0x8
        Ping,         // 0x9
        Pong          // 0xa
        // ... reserved
    };

    /// Implementation of the SocketHandlerInterface.
    /// Before the upgrade, the buffered bytes are treated as an HTTP request;
    /// afterwards every complete WebSocket frame in the buffer is dispatched.
    virtual void handleIncomingMessage() override
    {
        std::cerr << "incoming message with buffer size " << _socket->_inBuffer.size() << "\n";
        if (_wsState == HTTP)
        {
            handleWebsocketUpgrade();
            return;
        }

        // websocket fun ! Several frames can arrive in a single read, so keep
        // going until only a partial frame (or nothing) is left.
        while (handleFrame())
        {
        }
    }

    /// Appends @data to the output buffer as one complete, unmasked frame
    /// with opcode @code.
    void sendMessage(const std::vector<char> &data, WSOpCode code = WSOpCode::Binary)
    {
        // 64-bit so the 8-byte length encoding below never shifts a 32-bit
        // size_t by 32 bits or more.
        const uint64_t len = data.size();
        // Each call sends a whole message in a single frame, so FIN must be
        // set; a cleared FIN makes the peer wait for continuation frames.
        const bool fin = true;
        // Server-to-client frames are sent unmasked.
        const bool mask = false;

        auto& out = _socket->_outBuffer;
        out.push_back(static_cast<char>((fin ? 0x80 : 0) | static_cast<unsigned char>(code)));

        const unsigned char maskBit = mask ? 0x80 : 0;
        if (len < 126)
        {
            out.push_back(static_cast<char>(maskBit | len));
        }
        else if (len <= 0xffff)
        {
            out.push_back(static_cast<char>(maskBit | 126));
            for (int shift = 8; shift >= 0; shift -= 8) // big-endian
                out.push_back(static_cast<char>((len >> shift) & 0xff));
        }
        else
        {
            out.push_back(static_cast<char>(maskBit | 127));
            for (int shift = 56; shift >= 0; shift -= 8) // big-endian
                out.push_back(static_cast<char>((len >> shift) & 0xff));
        }

        // FIXME: pick random number and mask in the outbuffer etc.
        assert (!mask);

        out.insert(out.end(), data.begin(), data.end());
    }

    /// Called once per received frame with its FIN bit, opcode and
    /// (already un-masked) payload.
    virtual void handleMessage(bool fin, WSOpCode code, std::vector<char> &data) = 0;

private:
    /// Decodes one frame from the front of the input buffer, removes it from
    /// the buffer and passes it to handleMessage().
    /// @return false, consuming nothing, if no complete frame is buffered.
    bool handleFrame()
    {
        auto& in = _socket->_inBuffer;
        const size_t len = in.size();
        if (len < 2) // partial read
            return false;

        unsigned char* p = reinterpret_cast<unsigned char*>(&in[0]);
        const bool fin = p[0] & 0x80;
        const WSOpCode code = static_cast<WSOpCode>(p[0] & 0x0f);
        const bool hasMask = p[1] & 0x80;
        uint64_t payloadLen = p[1] & 0x7f;
        size_t headerLen = 2; // normally - 7 bit length.

        if (payloadLen == 126) // 2 byte length
        {
            if (len < 2 + 2)
                return false;
            payloadLen = (static_cast<uint64_t>(p[2]) << 8) | p[3];
            headerLen += 2;
        }
        else if (payloadLen == 127) // 8 byte length, big-endian
        {
            if (len < 2 + 8)
                return false;
            payloadLen = 0;
            for (size_t i = 2; i < 2 + 8; ++i)
                payloadLen = (payloadLen << 8) | p[i];
            // FIXME: crop read length to remove top / sign bits.
            headerLen += 8;
        }

        const unsigned char* mask = nullptr;
        if (hasMask)
        {
            mask = p + headerLen;
            headerLen += 4;
        }

        // Compare against the bytes remaining instead of computing
        // payloadLen + headerLen: a hostile 64-bit length could wrap that sum
        // to a small value and turn the un-masking below into an
        // out-of-bounds write.
        if (headerLen > len || payloadLen > len - headerLen)
            return false; // partial read wait for more data.

        const size_t dataLen = static_cast<size_t>(payloadLen);
        unsigned char* data = p + headerLen;
        if (hasMask)
        {
            // FIXME: copy and un-mask at the same time ...
            for (size_t i = 0; i < dataLen; ++i)
                data[i] ^= mask[i % 4];
        }
        _wsPayload.insert(_wsPayload.end(), data, data + dataLen);

        in.erase(in.begin(), in.begin() + headerLen + dataLen);

        // FIXME: fin, aggregating payloads into _wsPayload etc.
        handleMessage(fin, code, _wsPayload);
        _wsPayload.clear();
        return true;
    }
};

/// Test handler: logs text messages without replying; for any other opcode,
/// a payload of exactly sizeof(size_t) bytes is answered with that counter
/// incremented, and any other payload is echoed back unchanged.
class SimpleResponseClient : public WebSocketHandler
{
public:
    SimpleResponseClient() : WebSocketHandler()
    {
    }

    virtual void handleMessage(bool fin, WSOpCode code, std::vector<char> &data) override
    {
        std::cerr << "Message: fin? " << fin << " code " << code << " data size " << data.size();
        if (code == WSOpCode::Text)
        {
            std::string text(data.begin(), data.end());
            std::cerr << " text is '" << text << "'\n";
            return;
        }
        else
            std::cerr << " binary\n";

        std::vector<char> reply;
        if (data.size() == sizeof(size_t))
        {
            // ping pong test: the payload is a native-endian size_t counter.
            // memcpy avoids the unaligned, type-punned read a
            // reinterpret_cast<size_t*> would perform.
            size_t count = 0;
            std::memcpy(&count, data.data(), sizeof(count));
            count++;
            std::cerr << "count is " << count << "\n";
            const char* bytes = reinterpret_cast<const char*>(&count);
            reply.insert(reply.end(), bytes, bytes + sizeof(count));
        }
        else
        {
            // echo tests
            reply.insert(reply.end(), data.begin(), data.end());
        }

        sendMessage(reply);
    }
};

// FIXME: use Poco Thread instead (?)

/// Generic thread class: runs @cb on a new thread, passing it a stop flag
/// that stop() raises. The destructor raises the flag and joins the thread.
class Thread
{
public:
    explicit Thread(const std::function<void(std::atomic<bool>&)>& cb) :
        _cb(cb),
        _stop(false)
    {
        _thread = std::thread([this]() { _cb(_stop); });
    }

    Thread(Thread&& other) = delete;
    const Thread& operator=(Thread&& other) = delete;

    ~Thread()
    {
        stop();
        if (_thread.joinable())
        {
            _thread.join();
        }
    }

    void stop()
    {
        _stop = true;
    }

private:
    const std::function<void(std::atomic<bool>&)> _cb;
    std::atomic<bool> _stop;
    std::thread _thread;
};

Poco::Net::SocketAddress addrHttp("127.0.0.1", HttpPortNumber);
Poco::Net::SocketAddress addrSsl("127.0.0.1", SslPortNumber);

/// Binds and listens on @addr with a ServerSocket constructed from
/// @clientPoller and @sockFactory (presumably so accepted connections are
/// created by the factory and registered with @clientPoller), then polls the
/// listening socket forever.
/// @throws std::runtime_error if bind() or listen() fails.
void server(const Poco::Net::SocketAddress& addr, SocketPoll& clientPoller,
            std::unique_ptr<SocketFactory> sockFactory)
{
    // Start server.
    auto server = std::make_shared<ServerSocket>(clientPoller, std::move(sockFactory));
    if (!server->bind(addr))
    {
        const std::string msg = "Failed to bind. (errno: ";
        throw std::runtime_error(msg + std::strerror(errno) + ")");
    }

    if (!server->listen())
    {
        const std::string msg = "Failed to listen. (errno: ";
        throw std::runtime_error(msg + std::strerror(errno) + ")");
    }

    SocketPoll serverPoll;

    serverPoll.insertNewSocket(server);

    std::cout << "Listening." << std::endl;
    for (;;)
    {
        serverPoll.poll(30000);
    }
}

int main(int argc, const char**argv)
{
    // TODO: These would normally come from config.
    SslContext::initialize("/etc/loolwsd/cert.pem",
                           "/etc/loolwsd/key.pem",
                           "/etc/loolwsd/ca-chain.cert.pem");

    // Used to poll client sockets.
    SocketPoll poller;

    // Start the client polling thread.
    Thread threadPoll([&poller](std::atomic<bool>& stop)
        {
            while (!stop)
            {
                poller.poll(5000);
            }
        });

    class PlainSocketFactory : public SocketFactory
    {
        std::shared_ptr<Socket> create(const int fd) override
        {
            return std::make_shared<StreamSocket>(fd, new SimpleResponseClient());
        }
    };

    class SslSocketFactory : public SocketFactory
    {
        std::shared_ptr<Socket> create(const int fd) override
        {
            return std::make_shared<SslStreamSocket>(fd, new SimpleResponseClient());
        }
    };

    // Start the server: "ssl" as the last command-line argument selects TLS.
    // argc is checked first: argv[argc - 1] is out of range when argc == 0 and
    // is the program name, not an option, when argc == 1.
    const bool useSsl = argc > 1 && std::strcmp(argv[argc - 1], "ssl") == 0;
    if (useSsl)
        server(addrSsl, poller, std::unique_ptr<SocketFactory>{new SslSocketFactory});
    else
        server(addrHttp, poller, std::unique_ptr<SocketFactory>{new PlainSocketFactory});

    // NOTE: server() polls forever and only leaves by throwing, so the
    // shutdown below is currently unreachable.
    std::cout << "Shutting down server." << std::endl;

    threadPoll.stop();

    SslContext::uninitialize();
    return 0;
}

// Saves writing this ourselves:
#include <Poco/Net/WebSocket.h>

namespace {
    /// Derives from Poco's WebSocket purely to reach its non-public static
    /// computeAccept() helper.
    struct Puncture : private Poco::Net::WebSocket {
        static std::string doComputeAccept(const std::string &key)
        {
            return computeAccept(key);
        }
    };
}

/// Returns the Sec-WebSocket-Accept value for the client's Sec-WebSocket-Key.
static std::string computeAccept(const std::string &key)
{
    return Puncture::doComputeAccept(key);
}

/* vim:set shiftwidth=4 softtabstop=4 expandtab: */