wsd: new NetUtil file for network utilities

Move the connect function into the NetUtil
translation unit to aid using it for the
upcoming async socket logic.

The NetUtil should also come in handy for
the miscellaneous network helpers we have.

Change-Id: I2ee0c6e3e1769fd87572d7407d3b4979b59ffe6a
Signed-off-by: Ashod Nakashian <ashod.nakashian@collabora.co.uk>
This commit is contained in:
Ashod Nakashian 2021-01-12 22:48:26 -05:00 committed by Ashod Nakashian
parent 18446d63cf
commit ce3dd02ef3
5 changed files with 126 additions and 57 deletions

View file

@ -103,6 +103,7 @@ shared_sources = common/FileUtil.cpp \
common/Authorization.cpp \
net/DelaySocket.cpp \
net/HttpHelper.cpp \
net/NetUtil.cpp \
net/Socket.cpp
if ENABLE_SSL
shared_sources += net/Ssl.cpp
@ -266,6 +267,7 @@ shared_headers = common/Common.hpp \
net/DelaySocket.hpp \
net/FakeSocket.hpp \
net/HttpHelper.hpp \
net/NetUtil.hpp \
net/ServerSocket.hpp \
net/Socket.hpp \
net/WebSocketHandler.hpp \

88
net/NetUtil.cpp Normal file
View file

@ -0,0 +1,88 @@
/* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; fill-column: 100 -*- */
/*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
#include <config.h>
#include "NetUtil.hpp"
#include "Socket.hpp"
#if ENABLE_SSL && !MOBILEAPP
#include "SslSocket.hpp"
#endif
#include <netdb.h>
namespace net
{
std::shared_ptr<StreamSocket>
connect(const std::string& host, const std::string& port, const bool isSSL,
const std::shared_ptr<ProtocolHandlerInterface>& protocolHandler)
{
LOG_DBG("Connecting to " << host << ':' << port << " (" << (isSSL ? "SSL" : "Unencrypted")
<< ')');
std::shared_ptr<StreamSocket> socket;
#if !ENABLE_SSL
if (isSSL)
{
LOG_ERR("Error: isSSL socket requested but SSL is not compiled in.");
return socket;
}
#endif
// FIXME: store the address?
struct addrinfo* ainfo = nullptr;
struct addrinfo hints;
std::memset(&hints, 0, sizeof(hints));
const int rc = getaddrinfo(host.c_str(), port.c_str(), &hints, &ainfo);
if (!rc && ainfo)
{
for (struct addrinfo* ai = ainfo; ai; ai = ai->ai_next)
{
std::string canonicalName;
if (ai->ai_canonname)
canonicalName = ai->ai_canonname;
if (ai->ai_addrlen && ai->ai_addr)
{
int fd = ::socket(ai->ai_addr->sa_family, SOCK_STREAM | SOCK_NONBLOCK, 0);
int res = ::connect(fd, ai->ai_addr, ai->ai_addrlen);
if (fd < 0 || (res < 0 && errno != EINPROGRESS))
{
LOG_SYS("Failed to connect to " << host);
::close(fd);
}
else
{
#if ENABLE_SSL
if (isSSL)
socket = StreamSocket::create<SslStreamSocket>(fd, true, protocolHandler);
#endif
if (!socket && !isSSL)
socket = StreamSocket::create<StreamSocket>(fd, true, protocolHandler);
if (socket)
break;
LOG_ERR("Failed to allocate socket for client websocket " << host);
::close(fd);
break;
}
}
}
freeaddrinfo(ainfo);
}
else
LOG_ERR("Failed to lookup host [" << host << "]. Skipping.");
return socket;
}
} // namespace net

26
net/NetUtil.hpp Normal file
View file

@ -0,0 +1,26 @@
/* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; fill-column: 100 -*- */
/*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
#pragma once
#include <string>
#include <memory>
// This file hosts network related common functionality
// and helper/utility functions and classes.
// HTTP-specific helpers are in HttpHeler.hpp.
class StreamSocket;
class ProtocolHandlerInterface;
namespace net
{
/// Connect to an end-point at the given host and port and return StreamSocket.
std::shared_ptr<StreamSocket>
connect(const std::string& host, const std::string& port, const bool isSSL,
const std::shared_ptr<ProtocolHandlerInterface>& protocolHandler);
} // namespace net

View file

@ -7,12 +7,14 @@
#include <config.h>
#include "NetUtil.hpp"
#include "Socket.hpp"
#include <cstring>
#include <ctype.h>
#include <iomanip>
#include <stdio.h>
#include <string>
#include <unistd.h>
#include <sys/stat.h>
#include <sys/types.h>
@ -382,18 +384,7 @@ void SocketPoll::insertNewWebSocketSync(
const Poco::URI &uri,
const std::shared_ptr<ProtocolHandlerInterface>& websocketHandler)
{
LOG_INF("Connecting to " << uri.getHost() << " : " << uri.getPort() << " : " << uri.getPath());
// FIXME: put this in a ClientSocket class ?
// FIXME: store the address there - and ... (so on) ...
struct addrinfo* ainfo = nullptr;
struct addrinfo hints;
std::memset(&hints, 0, sizeof(hints));
int rc = getaddrinfo(uri.getHost().c_str(),
std::to_string(uri.getPort()).c_str(),
&hints, &ainfo);
std::string canonicalName;
bool isSSL = uri.getScheme() != "ws";
const bool isSSL = uri.getScheme() != "ws";
#if !ENABLE_SSL
if (isSSL)
{
@ -402,53 +393,14 @@ void SocketPoll::insertNewWebSocketSync(
}
#endif
if (!rc && ainfo)
std::shared_ptr<StreamSocket> socket
= net::connect(uri.getHost(), std::to_string(uri.getPort()), isSSL, websocketHandler);
if (socket)
{
for (struct addrinfo* ai = ainfo; ai; ai = ai->ai_next)
{
if (ai->ai_canonname)
canonicalName = ai->ai_canonname;
if (ai->ai_addrlen && ai->ai_addr)
{
int fd = socket(ai->ai_addr->sa_family, SOCK_STREAM | SOCK_NONBLOCK, 0);
int res = connect(fd, ai->ai_addr, ai->ai_addrlen);
if (fd < 0 || (res < 0 && errno != EINPROGRESS))
{
LOG_ERR("Failed to connect to " << uri.getHost());
::close(fd);
}
else
{
std::shared_ptr<StreamSocket> socket;
#if ENABLE_SSL
if (isSSL)
socket = StreamSocket::create<SslStreamSocket>(fd, true, websocketHandler);
#endif
if (!socket && !isSSL)
socket = StreamSocket::create<StreamSocket>(fd, true, websocketHandler);
if (socket)
{
LOG_DBG("Connected to client websocket " << uri.getHost() << " #" << socket->getFD());
clientRequestWebsocketUpgrade(socket, websocketHandler, uri.getPathAndQuery());
insertNewSocket(socket);
}
else
{
LOG_ERR("Failed to allocate socket for client websocket " << uri.getHost());
::close(fd);
}
break;
}
}
}
freeaddrinfo(ainfo);
LOG_DBG("Connected to client websocket " << uri.getHost() << " #" << socket->getFD());
clientRequestWebsocketUpgrade(socket, websocketHandler, uri.getPathAndQuery());
insertNewSocket(socket);
}
else
LOG_ERR("Failed to lookup client websocket host '" << uri.getHost() << "' skipping");
}
// should this be a static method in the WebsocketHandler(?)

View file

@ -89,6 +89,7 @@ unittest_SOURCES = \
../common/Unit.cpp \
../common/StringVector.cpp \
../net/Socket.cpp \
../net/NetUtil.cpp \
../wsd/Auth.cpp \
../wsd/TestStubs.cpp \
test.cpp