libreoffice-online/net/WebSocketHandler.hpp
Tor Lillqvist b59d160a08 Intermediate commit of work in progress on an iOS app
The app is unimaginatively called "Mobile" for now.

Runs but crashes pretty quickly after loading the document by the LO
core. Will need some heavy changes to get a ClientSession object
created in there, too, to handle the (emulated) WebSocket messages
from the JavaScript. It would then handle some of these messages
itself, and forward some to the ChildSession, which in this case is
in the same process. Now the messages from the JavaScript go to a
ChildSession, which is wrong. As the assertion says, "Tile traffic
should go through the DocumentBroker-LoKit WS"
2018-09-12 18:32:05 +03:00

624 lines
21 KiB
C++

/* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; fill-column: 100 -*- */
/*
* This file is part of the LibreOffice project.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
#ifndef INCLUDED_WEBSOCKETHANDLER_HPP
#define INCLUDED_WEBSOCKETHANDLER_HPP
#include <chrono>
#include <memory>
#include <vector>
#include "common/Common.hpp"
#include "common/Log.hpp"
#include "common/Unit.hpp"
#include "Socket.hpp"
#include <Poco/MemoryStream.h>
#include <Poco/Net/HTTPRequest.h>
#include <Poco/Net/HTTPResponse.h>
#include <Poco/Net/WebSocket.h>
class WebSocketHandler : public SocketHandlerInterface
{
protected:
/// The socket that owns us (we can't own it).
std::weak_ptr<StreamSocket> _socket;
std::chrono::steady_clock::time_point _lastPingSentTime;
int _pingTimeUs;
std::vector<char> _wsPayload;
std::atomic<bool> _shuttingDown;
bool _isClient;
bool _isMasking;
struct WSFrameMask
{
static const unsigned char Fin = 0x80;
static const unsigned char Mask = 0x80;
};
static const int InitialPingDelayMs;
static const int PingFrequencyMs;
public:
/// Perform upgrade ourselves, or select a client web socket.
WebSocketHandler(bool isClient = false, bool isMasking = true) :
_lastPingSentTime(std::chrono::steady_clock::now()),
_pingTimeUs(0),
_shuttingDown(false),
_isClient(isClient),
_isMasking(isClient && isMasking)
{
}
/// Upgrades itself to a websocket directly.
WebSocketHandler(const std::weak_ptr<StreamSocket>& socket,
const Poco::Net::HTTPRequest& request) :
_socket(socket),
_lastPingSentTime(std::chrono::steady_clock::now() -
std::chrono::milliseconds(PingFrequencyMs) -
std::chrono::milliseconds(InitialPingDelayMs)),
_pingTimeUs(0),
_shuttingDown(false),
_isClient(false),
_isMasking(false)
{
upgradeToWebSocket(request);
}
/// Implementation of the SocketHandlerInterface.
void onConnect(const std::shared_ptr<StreamSocket>& socket) override
{
_socket = socket;
LOG_TRC("#" << socket->getFD() << " Connected to WS Handler 0x" << std::hex << this << std::dec);
}
/// Status codes sent to peer on shutdown.
enum class StatusCodes : unsigned short
{
NORMAL_CLOSE = 1000,
ENDPOINT_GOING_AWAY = 1001,
PROTOCOL_ERROR = 1002,
PAYLOAD_NOT_ACCEPTABLE = 1003,
RESERVED = 1004,
RESERVED_NO_STATUS_CODE = 1005,
RESERVED_ABNORMAL_CLOSE = 1006,
MALFORMED_PAYLOAD = 1007,
POLICY_VIOLATION = 1008,
PAYLOAD_TOO_BIG = 1009,
EXTENSION_REQUIRED = 1010,
UNEXPECTED_CONDITION = 1011,
RESERVED_TLS_FAILURE = 1015
};
/// Sends WS shutdown message to the peer.
void shutdown(const StatusCodes statusCode = StatusCodes::NORMAL_CLOSE, const std::string& statusMessage = "")
{
std::shared_ptr<StreamSocket> socket = _socket.lock();
if (socket == nullptr)
{
LOG_ERR("No socket associated with WebSocketHandler 0x" << std::hex << this << std::dec);
return;
}
LOG_TRC("#" << socket->getFD() << ": Shutdown websocket, code: " <<
static_cast<unsigned>(statusCode) << ", message: " << statusMessage);
_shuttingDown = true;
const size_t len = statusMessage.size();
std::vector<char> buf(2 + len);
buf[0] = ((((int)statusCode) >> 8) & 0xff);
buf[1] = ((((int)statusCode) >> 0) & 0xff);
std::copy(statusMessage.begin(), statusMessage.end(), buf.begin() + 2);
const unsigned char flags = WSFrameMask::Fin
| static_cast<char>(WSOpCode::Close);
sendFrame(socket, buf.data(), buf.size(), flags);
}
bool handleOneIncomingMessage(const std::shared_ptr<StreamSocket>& socket)
{
assert(socket && "Expected a valid socket instance.");
// websocket fun !
const size_t len = socket->_inBuffer.size();
if (len == 0)
return false; // avoid logging.
#ifndef MOBILEAPP
if (len < 2) // partial read
{
LOG_TRC("#" << socket->getFD() << ": Still incomplete WebSocket message, have " << len << " bytes");
return false;
}
unsigned char *p = reinterpret_cast<unsigned char*>(&socket->_inBuffer[0]);
const bool fin = p[0] & 0x80;
const WSOpCode code = static_cast<WSOpCode>(p[0] & 0x0f);
const bool hasMask = p[1] & 0x80;
size_t payloadLen = p[1] & 0x7f;
size_t headerLen = 2;
// normally - 7 bit length.
if (payloadLen == 126) // 2 byte length
{
if (len < 2 + 2)
{
LOG_TRC("#" << socket->getFD() << ": Still incomplete WebSocket message, have " << len << " bytes");
return false;
}
payloadLen = (((unsigned)p[2]) << 8) | ((unsigned)p[3]);
headerLen += 2;
}
else if (payloadLen == 127) // 8 byte length
{
if (len < 2 + 8)
{
LOG_TRC("#" << socket->getFD() << ": Still incomplete WebSocket message, have " << len << " bytes");
return false;
}
payloadLen = ((((uint64_t)p[9]) << 0) + (((uint64_t)p[8]) << 8) +
(((uint64_t)p[7]) << 16) + (((uint64_t)p[6]) << 24) +
(((uint64_t)p[5]) << 32) + (((uint64_t)p[4]) << 40) +
(((uint64_t)p[3]) << 48) + (((uint64_t)p[2]) << 56));
// FIXME: crop read length to remove top / sign bits.
headerLen += 8;
}
unsigned char *data, *mask;
if (hasMask)
{
mask = p + headerLen;
headerLen += 4;
}
if (payloadLen + headerLen > len)
{ // partial read wait for more data.
LOG_TRC("#" << socket->getFD() << ": Still incomplete WebSocket message, have " << len << " bytes, message is " << payloadLen + headerLen << " bytes");
return false;
}
LOG_TRC("#" << socket->getFD() << ": Incoming WebSocket data of " << len << " bytes: " << Util::stringifyHexLine(socket->_inBuffer, 0, std::min((size_t)32, len)));
data = p + headerLen;
if (hasMask)
{
const size_t end = _wsPayload.size();
_wsPayload.resize(end + payloadLen);
char* wsData = &_wsPayload[end];
for (size_t i = 0; i < payloadLen; ++i)
*wsData++ = data[i] ^ mask[i % 4];
} else
_wsPayload.insert(_wsPayload.end(), data, data + payloadLen);
#else
unsigned char * const p = reinterpret_cast<unsigned char*>(&socket->_inBuffer[0]);
_wsPayload.insert(_wsPayload.end(), p, p + len);
const size_t headerLen = 0;
const size_t payloadLen = len;
const bool hasMask = false;
#endif
assert(_wsPayload.size() >= payloadLen);
socket->_inBuffer.erase(socket->_inBuffer.begin(), socket->_inBuffer.begin() + headerLen + payloadLen);
#ifndef MOBILEAPP
// FIXME: fin, aggregating payloads into _wsPayload etc.
LOG_TRC("#" << socket->getFD() << ": Incoming WebSocket message code " << static_cast<unsigned>(code) <<
", fin? " << fin << ", mask? " << hasMask << ", payload length: " << _wsPayload.size() <<
", residual socket data: " << socket->_inBuffer.size() << " bytes.");
bool doClose = false;
switch (code)
{
case WSOpCode::Pong:
{
if (_isClient)
{
LOG_ERR("#" << socket->getFD() << ": Servers should not send pongs, only clients");
doClose = true;
break;
}
else
{
_pingTimeUs = std::chrono::duration_cast<std::chrono::microseconds>
(std::chrono::steady_clock::now() - _lastPingSentTime).count();
LOG_TRC("#" << socket->getFD() << ": Pong received: " << _pingTimeUs << " microseconds");
break;
}
}
case WSOpCode::Ping:
if (_isClient)
{
auto now = std::chrono::steady_clock::now();
_pingTimeUs = std::chrono::duration_cast<std::chrono::microseconds>
(now - _lastPingSentTime).count();
sendPong(now, &_wsPayload[0], payloadLen, socket);
break;
}
else
{
LOG_ERR("#" << socket->getFD() << ": Clients should not send pings, only servers");
doClose = true;
}
break;
case WSOpCode::Close:
doClose = true;
break;
default:
handleMessage(fin, code, _wsPayload);
break;
}
#else
handleMessage(true, WSOpCode::Binary, _wsPayload);
#endif
#ifndef MOBILEAPP
if (doClose)
{
if (!_shuttingDown)
{
// Peer-initiated shutdown must be echoed.
// Otherwise, this is the echo to _our_ shutdown message, which we should ignore.
const StatusCodes statusCode = static_cast<StatusCodes>((((uint64_t)(unsigned char)_wsPayload[0]) << 8) +
(((uint64_t)(unsigned char)_wsPayload[1]) << 0));
LOG_TRC("#" << socket->getFD() << ": Client initiated socket shutdown. Code: " << static_cast<int>(statusCode));
if (_wsPayload.size() > 2)
{
const std::string message(&_wsPayload[2], &_wsPayload[2] + _wsPayload.size() - 2);
shutdown(statusCode, message);
}
else
{
shutdown(statusCode);
}
}
else
{
LOG_TRC("#" << socket->getFD() << ": Client responded to our shutdown.");
}
// TCP Close.
socket->closeConnection();
}
#endif
_wsPayload.clear();
return true;
}
/// Implementation of the SocketHandlerInterface.
virtual void handleIncomingMessage(SocketDisposition&) override
{
std::shared_ptr<StreamSocket> socket = _socket.lock();
if (socket == nullptr)
{
LOG_ERR("No socket associated with WebSocketHandler 0x" << std::hex << this << std::dec);
}
else if (_isClient && !socket->isWebSocket())
handleClientUpgrade();
else
{
while (handleOneIncomingMessage(socket))
; // might have multiple messages in the accumulated buffer.
}
}
int getPollEvents(std::chrono::steady_clock::time_point now,
int & timeoutMaxMs) override
{
if (!_isClient)
{
const int timeSincePingMs =
std::chrono::duration_cast<std::chrono::milliseconds>(now - _lastPingSentTime).count();
timeoutMaxMs = std::min(timeoutMaxMs, PingFrequencyMs - timeSincePingMs);
}
return POLLIN;
}
/// Send a ping message
void sendPingOrPong(std::chrono::steady_clock::time_point now,
const char* data, const size_t len,
const WSOpCode code,
const std::shared_ptr<StreamSocket>& socket)
{
assert(socket && "Expected a valid socket instance.");
// Must not send this before we're upgraded.
if (!socket->isWebSocket())
{
LOG_WRN("Attempted ping on non-upgraded websocket! #" << socket->getFD());
_lastPingSentTime = now; // Pretend we sent it to avoid timing out immediately.
return;
}
LOG_TRC("#" << socket->getFD() << ": Sending " <<
(const char *)(code == WSOpCode::Ping ? " ping." : "pong."));
// FIXME: allow an empty payload.
sendMessage(data, len, code, false);
_lastPingSentTime = now;
}
void sendPing(std::chrono::steady_clock::time_point now,
const std::shared_ptr<StreamSocket>& socket)
{
assert(!_isClient);
sendPingOrPong(now, "", 1, WSOpCode::Ping, socket);
}
void sendPong(std::chrono::steady_clock::time_point now,
const char* data, const size_t len,
const std::shared_ptr<StreamSocket>& socket)
{
assert(_isClient);
sendPingOrPong(now, data, len, WSOpCode::Pong, socket);
}
/// Do we need to handle a timeout ?
void checkTimeout(std::chrono::steady_clock::time_point now) override
{
if (_isClient)
return;
const int timeSincePingMs =
std::chrono::duration_cast<std::chrono::milliseconds>(now - _lastPingSentTime).count();
if (timeSincePingMs >= PingFrequencyMs)
{
const std::shared_ptr<StreamSocket> socket = _socket.lock();
if (socket)
sendPing(now, socket);
}
}
/// By default rely on the socket buffer.
void performWrites() override {}
/// Sends a WebSocket Text message.
int sendMessage(const std::string& msg) const
{
return sendMessage(msg.data(), msg.size(), WSOpCode::Text);
}
/// Sends a WebSocket message of WPOpCode type.
/// Returns the number of bytes written (including frame overhead) on success,
/// 0 for closed/invalid socket, and -1 for other errors.
int sendMessage(const char* data, const size_t len, const WSOpCode code, const bool flush = true) const
{
#ifndef MOBILEAPP
int unitReturn = -1;
if (UnitBase::get().filterSendMessage(data, len, code, flush, unitReturn))
return unitReturn;
#endif
//TODO: Support fragmented messages.
std::shared_ptr<StreamSocket> socket = _socket.lock();
return sendFrame(socket, data, len, WSFrameMask::Fin | static_cast<unsigned char>(code), flush);
}
private:
/// Sends a WebSocket frame given the data, length, and flags.
/// Returns the number of bytes written (including frame overhead) on success,
/// 0 for closed/invalid socket, and -1 for other errors.
int sendFrame(const std::shared_ptr<StreamSocket>& socket,
const char* data, const size_t len,
unsigned char flags, const bool flush = true) const
{
if (!socket || data == nullptr || len == 0)
return -1;
if (socket->isClosed())
return 0;
socket->assertCorrectThread();
std::vector<char>& out = socket->_outBuffer;
const size_t oldSize = out.size();
out.push_back(flags);
int maskFlag = _isMasking ? 0x80 : 0;
if (len < 126)
{
out.push_back((char)(len | maskFlag));
}
else if (len <= 0xffff)
{
out.push_back((char)(126 | maskFlag));
out.push_back(static_cast<char>((len >> 8) & 0xff));
out.push_back(static_cast<char>((len >> 0) & 0xff));
}
else
{
out.push_back((char)(127 | maskFlag));
out.push_back(static_cast<char>((len >> 56) & 0xff));
out.push_back(static_cast<char>((len >> 48) & 0xff));
out.push_back(static_cast<char>((len >> 40) & 0xff));
out.push_back(static_cast<char>((len >> 32) & 0xff));
out.push_back(static_cast<char>((len >> 24) & 0xff));
out.push_back(static_cast<char>((len >> 16) & 0xff));
out.push_back(static_cast<char>((len >> 8) & 0xff));
out.push_back(static_cast<char>((len >> 0) & 0xff));
}
if (_isMasking)
{ // flip some top bits - perhaps it helps.
size_t mask = out.size();
out.push_back(static_cast<char>(0x81));
out.push_back(static_cast<char>(0x76));
out.push_back(static_cast<char>(0x81));
out.push_back(static_cast<char>(0x76));
// Copy the data.
out.insert(out.end(), data, data + len);
// Mask it.
for (size_t i = 4; i < out.size() - mask; ++i)
out[mask + i] = out[mask + i] ^ out[mask + (i%4)];
}
else
{
// Copy the data.
out.insert(out.end(), data, data + len);
}
const size_t size = out.size() - oldSize;
if (flush)
socket->writeOutgoingData();
return size;
}
protected:
/// To be overriden to handle the websocket messages the way you need.
virtual void handleMessage(bool /*fin*/, WSOpCode /*code*/, std::vector<char> &/*data*/)
{
}
void dumpState(std::ostream& os) override;
private:
/// To make the protected 'computeAccept' accessible.
class PublicComputeAccept : public Poco::Net::WebSocket
{
public:
static std::string doComputeAccept(const std::string &key)
{
return computeAccept(key);
}
};
protected:
/// Upgrade the http(s) connection to a websocket.
void upgradeToWebSocket(const Poco::Net::HTTPRequest& req)
{
std::shared_ptr<StreamSocket> socket = _socket.lock();
if (socket == nullptr)
throw std::runtime_error("Invalid socket while upgrading to WebSocket. Request: " + req.getURI());
LOG_TRC("#" << socket->getFD() << ": Upgrading to WebSocket.");
assert(!socket->isWebSocket());
// create our websocket goodness ...
const int wsVersion = std::stoi(req.get("Sec-WebSocket-Version", "13"));
const std::string wsKey = req.get("Sec-WebSocket-Key", "");
const std::string wsProtocol = req.get("Sec-WebSocket-Protocol", "chat");
// FIXME: other sanity checks ...
LOG_INF("#" << socket->getFD() << ": WebSocket version: " << wsVersion <<
", key: [" << wsKey << "], protocol: [" << wsProtocol << "].");
#if ENABLE_DEBUG
if (std::getenv("LOOL_ZERO_BUFFER_SIZE"))
socket->setSocketBufferSize(0);
#endif
std::ostringstream oss;
oss << "HTTP/1.1 101 Switching Protocols\r\n"
<< "Upgrade: websocket\r\n"
<< "Connection: Upgrade\r\n"
<< "Sec-WebSocket-Accept: " << PublicComputeAccept::doComputeAccept(wsKey) << "\r\n"
<< "\r\n";
const std::string res = oss.str();
LOG_TRC("#" << socket->getFD() << ": Sending WS Upgrade response: " << res);
socket->send(res);
setWebSocket();
}
// Handle incoming upgrade to full socket as client WS.
void handleClientUpgrade()
{
std::shared_ptr<StreamSocket> socket = _socket.lock();
LOG_TRC("Incoming client websocket upgrade response: " << std::string(&socket->_inBuffer[0], socket->_inBuffer.size()));
bool bOk = false;
size_t responseSize = 0;
try
{
Poco::MemoryInputStream message(&socket->_inBuffer[0], socket->_inBuffer.size());;
Poco::Net::HTTPResponse response;
response.read(message);
{
static const std::string marker("\r\n\r\n");
auto itBody = std::search(socket->_inBuffer.begin(),
socket->_inBuffer.end(),
marker.begin(), marker.end());
if (itBody != socket->_inBuffer.end())
responseSize = itBody - socket->_inBuffer.begin() + marker.size();
}
if (response.getStatus() == Poco::Net::HTTPResponse::HTTP_SWITCHING_PROTOCOLS &&
response.has("Upgrade") && Poco::icompare(response.get("Upgrade"), "websocket") == 0)
{
#if 0 // SAL_DEBUG ...
const std::string wsKey = response.get("Sec-WebSocket-Accept", "");
const std::string wsProtocol = response.get("Sec-WebSocket-Protocol", "");
if (Poco::icompare(wsProtocol, "chat") != 0)
LOG_ERR("Unknown websocket protocol " << wsProtocol);
else
#endif
{
LOG_TRC("Accepted incoming websocket response");
// FIXME: validate Sec-WebSocket-Accept vs. Sec-WebSocket-Key etc.
bOk = true;
}
}
}
catch (const Poco::Exception& exc)
{
LOG_DBG("handleClientUpgrade exception caught: " << exc.displayText());
}
catch (const std::exception& exc)
{
LOG_DBG("handleClientUpgrade exception caught.");
}
if (!bOk || responseSize == 0)
{
LOG_ERR("Bad websocker server response.");
socket->shutdown();
return;
}
setWebSocket();
socket->eraseFirstInputBytes(responseSize);
}
void setWebSocket()
{
std::shared_ptr<StreamSocket> socket = _socket.lock();
socket->setWebSocket();
// No need to ping right upon connection/upgrade,
// but do reset the time to avoid pinging immediately after.
_lastPingSentTime = std::chrono::steady_clock::now();
}
};
#endif
/* vim:set shiftwidth=4 softtabstop=4 expandtab: */