diff --git a/Makefile.am b/Makefile.am index 348209688..c058b4917 100644 --- a/Makefile.am +++ b/Makefile.am @@ -103,6 +103,7 @@ shared_sources = common/FileUtil.cpp \ common/Authorization.cpp \ net/DelaySocket.cpp \ net/HttpHelper.cpp \ + net/NetUtil.cpp \ net/Socket.cpp if ENABLE_SSL shared_sources += net/Ssl.cpp @@ -266,6 +267,7 @@ shared_headers = common/Common.hpp \ net/DelaySocket.hpp \ net/FakeSocket.hpp \ net/HttpHelper.hpp \ + net/NetUtil.hpp \ net/ServerSocket.hpp \ net/Socket.hpp \ net/WebSocketHandler.hpp \ diff --git a/net/NetUtil.cpp b/net/NetUtil.cpp new file mode 100644 index 000000000..a03f1142d --- /dev/null +++ b/net/NetUtil.cpp @@ -0,0 +1,88 @@ +/* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; fill-column: 100 -*- */ +/* + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#include + +#include "NetUtil.hpp" + +#include "Socket.hpp" +#if ENABLE_SSL && !MOBILEAPP +#include "SslSocket.hpp" +#endif + +#include + +namespace net +{ +std::shared_ptr +connect(const std::string& host, const std::string& port, const bool isSSL, + const std::shared_ptr& protocolHandler) +{ + LOG_DBG("Connecting to " << host << ':' << port << " (" << (isSSL ? "SSL" : "Unencrypted") + << ')'); + + std::shared_ptr socket; + +#if !ENABLE_SSL + if (isSSL) + { + LOG_ERR("Error: isSSL socket requested but SSL is not compiled in."); + return socket; + } +#endif + + // FIXME: store the address? + struct addrinfo* ainfo = nullptr; + struct addrinfo hints; + std::memset(&hints, 0, sizeof(hints)); + const int rc = getaddrinfo(host.c_str(), port.c_str(), &hints, &ainfo); + + if (!rc && ainfo) + { + for (struct addrinfo* ai = ainfo; ai; ai = ai->ai_next) + { + std::string canonicalName; + if (ai->ai_canonname) + canonicalName = ai->ai_canonname; + + if (ai->ai_addrlen && ai->ai_addr) + { + int fd = ::socket(ai->ai_addr->sa_family, SOCK_STREAM | SOCK_NONBLOCK, 0); + int res = ::connect(fd, ai->ai_addr, ai->ai_addrlen); + if (fd < 0 || (res < 0 && errno != EINPROGRESS)) + { + LOG_SYS("Failed to connect to " << host); + ::close(fd); + } + else + { +#if ENABLE_SSL + if (isSSL) + socket = StreamSocket::create(fd, true, protocolHandler); +#endif + if (!socket && !isSSL) + socket = StreamSocket::create(fd, true, protocolHandler); + + if (socket) + break; + + LOG_ERR("Failed to allocate socket for client websocket " << host); + ::close(fd); + break; + } + } + } + + freeaddrinfo(ainfo); + } + else + LOG_ERR("Failed to lookup host [" << host << "]. Skipping."); + + return socket; +} + +} // namespace net \ No newline at end of file diff --git a/net/NetUtil.hpp b/net/NetUtil.hpp new file mode 100644 index 000000000..547f94a62 --- /dev/null +++ b/net/NetUtil.hpp @@ -0,0 +1,26 @@ +/* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; fill-column: 100 -*- */ +/* + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#pragma once + +#include +#include + +// This file hosts network related common functionality +// and helper/utility functions and classes. +// HTTP-specific helpers are in HttpHeler.hpp. + +class StreamSocket; +class ProtocolHandlerInterface; + +namespace net +{ +/// Connect to an end-point at the given host and port and return StreamSocket. +std::shared_ptr +connect(const std::string& host, const std::string& port, const bool isSSL, + const std::shared_ptr& protocolHandler); +} // namespace net \ No newline at end of file diff --git a/net/Socket.cpp b/net/Socket.cpp index 0ccdeb2bb..8cdcbc38c 100644 --- a/net/Socket.cpp +++ b/net/Socket.cpp @@ -7,12 +7,14 @@ #include +#include "NetUtil.hpp" #include "Socket.hpp" #include #include #include #include +#include #include #include #include @@ -382,18 +384,7 @@ void SocketPoll::insertNewWebSocketSync( const Poco::URI &uri, const std::shared_ptr& websocketHandler) { - LOG_INF("Connecting to " << uri.getHost() << " : " << uri.getPort() << " : " << uri.getPath()); - - // FIXME: put this in a ClientSocket class ? - // FIXME: store the address there - and ... (so on) ... - struct addrinfo* ainfo = nullptr; - struct addrinfo hints; - std::memset(&hints, 0, sizeof(hints)); - int rc = getaddrinfo(uri.getHost().c_str(), - std::to_string(uri.getPort()).c_str(), - &hints, &ainfo); - std::string canonicalName; - bool isSSL = uri.getScheme() != "ws"; + const bool isSSL = uri.getScheme() != "ws"; #if !ENABLE_SSL if (isSSL) { @@ -402,53 +393,14 @@ void SocketPoll::insertNewWebSocketSync( } #endif - if (!rc && ainfo) + std::shared_ptr socket + = net::connect(uri.getHost(), std::to_string(uri.getPort()), isSSL, websocketHandler); + if (socket) { - for (struct addrinfo* ai = ainfo; ai; ai = ai->ai_next) - { - if (ai->ai_canonname) - canonicalName = ai->ai_canonname; - - if (ai->ai_addrlen && ai->ai_addr) - { - int fd = socket(ai->ai_addr->sa_family, SOCK_STREAM | SOCK_NONBLOCK, 0); - int res = connect(fd, ai->ai_addr, ai->ai_addrlen); - if (fd < 0 || (res < 0 && errno != EINPROGRESS)) - { - LOG_ERR("Failed to connect to " << uri.getHost()); - ::close(fd); - } - else - { - std::shared_ptr socket; -#if ENABLE_SSL - if (isSSL) - socket = StreamSocket::create(fd, true, websocketHandler); -#endif - if (!socket && !isSSL) - socket = StreamSocket::create(fd, true, websocketHandler); - - if (socket) - { - LOG_DBG("Connected to client websocket " << uri.getHost() << " #" << socket->getFD()); - clientRequestWebsocketUpgrade(socket, websocketHandler, uri.getPathAndQuery()); - insertNewSocket(socket); - } - else - { - LOG_ERR("Failed to allocate socket for client websocket " << uri.getHost()); - ::close(fd); - } - - break; - } - } - } - - freeaddrinfo(ainfo); + LOG_DBG("Connected to client websocket " << uri.getHost() << " #" << socket->getFD()); + clientRequestWebsocketUpgrade(socket, websocketHandler, uri.getPathAndQuery()); + insertNewSocket(socket); } - else - LOG_ERR("Failed to lookup client websocket host '" << uri.getHost() << "' skipping"); } // should this be a static method in the WebsocketHandler(?) diff --git a/test/Makefile.am b/test/Makefile.am index 99ac47549..2ca6ef95d 100644 --- a/test/Makefile.am +++ b/test/Makefile.am @@ -89,6 +89,7 @@ unittest_SOURCES = \ ../common/Unit.cpp \ ../common/StringVector.cpp \ ../net/Socket.cpp \ + ../net/NetUtil.cpp \ ../wsd/Auth.cpp \ ../wsd/TestStubs.cpp \ test.cpp