diff --git a/bricks/net/tcp/impl/posix.h b/bricks/net/tcp/impl/posix.h index 025d2829..7d93ae86 100644 --- a/bricks/net/tcp/impl/posix.h +++ b/bricks/net/tcp/impl/posix.h @@ -81,6 +81,7 @@ typedef int SOCKET; namespace current { namespace net { +const timeval DefaultSocketTimeout = {0, 0}; enum class NagleAlgorithm : bool { Disable, Keep }; const NagleAlgorithm kDefaultNagleAlgorithmPolicy = NagleAlgorithm::Keep; @@ -394,9 +395,8 @@ class ReserveLocalPortImpl final { current::net::BarePort(port), nagle_algorithm_policy, max_connections); - return current::net::ReservedLocalPort(current::net::ReservedLocalPort::Construct(), - port, - std::move(hold_port_or_throw)); + return current::net::ReservedLocalPort( + current::net::ReservedLocalPort::Construct(), port, std::move(hold_port_or_throw)); } class Connection : public SocketHandle { @@ -673,12 +673,17 @@ inline std::string ResolveIPFromHostname(const std::string& hostname) { // POSIX allows numeric ports, as well as strings like "http". template -inline Connection ClientSocket(const std::string& host, T port_or_serv) { +inline Connection ClientSocket(const std::string& host, + T port_or_serv, + const timeval& read_timeout = DefaultSocketTimeout, + const timeval& write_timeout = DefaultSocketTimeout) { class ClientSocket final : public SocketHandle { public: explicit ClientSocket(const std::string& host, const std::string& serv, - NagleAlgorithm nagle_algorithm_policy = kDefaultNagleAlgorithmPolicy) + NagleAlgorithm nagle_algorithm_policy = kDefaultNagleAlgorithmPolicy, + const timeval& read_timeout = DefaultSocketTimeout, + const timeval& write_timeout = DefaultSocketTimeout) : SocketHandle(SocketHandle::DoNotBind(), nagle_algorithm_policy) { CURRENT_BRICKS_NET_LOG("S%05d ", static_cast(socket)); // Deliberately left non-const because of possible Windows issues. -- M.Z. @@ -693,6 +698,15 @@ inline Connection ClientSocket(const std::string& host, T port_or_serv) { remote_ip_and_port.port = htons(p_addr_in->sin_port); CURRENT_BRICKS_NET_LOG("S%05d connect() ...\n", static_cast(socket)); + + // Set socket timeout if it's not default 0,0 + if (read_timeout.tv_sec > 0 || read_timeout.tv_usec > 0) { + setsockopt(socket, SOL_SOCKET, SO_RCVTIMEO, &read_timeout, sizeof(read_timeout)); + } + if (write_timeout.tv_sec > 0 || write_timeout.tv_usec > 0) { + setsockopt(socket, SOL_SOCKET, SO_SNDTIMEO, &write_timeout, sizeof(write_timeout)); + } + const int retval = ::connect(socket, p_addr, sizeof(*p_addr)); if (retval) { CURRENT_THROW(SocketConnectException()); // LCOV_EXCL_LINE -- Not covered by the unit tests. @@ -715,7 +729,8 @@ inline Connection ClientSocket(const std::string& host, T port_or_serv) { IPAndPort local_ip_and_port; IPAndPort remote_ip_and_port; }; - auto client_socket = ClientSocket(host, std::to_string(port_or_serv)); + auto client_socket = + ClientSocket(host, std::to_string(port_or_serv), kDefaultNagleAlgorithmPolicy, read_timeout, write_timeout); IPAndPort local_ip_and_port(std::move(client_socket.local_ip_and_port)); IPAndPort remote_ip_and_port(std::move(client_socket.remote_ip_and_port)); return Connection(std::move(client_socket), std::move(local_ip_and_port), std::move(remote_ip_and_port)); diff --git a/bricks/net/tcp/test.cc b/bricks/net/tcp/test.cc index 9b9ee93d..df5171c1 100644 --- a/bricks/net/tcp/test.cc +++ b/bricks/net/tcp/test.cc @@ -52,6 +52,7 @@ using current::net::Socket; using current::strings::Printf; using current::net::AttemptedToUseMovedAwayConnection; +using current::net::EmptySocketReadException; using current::net::SocketBindException; using current::net::SocketException; using current::net::SocketResolveAddressException; @@ -155,6 +156,45 @@ TEST(TCPTest, ReceiveDelayedMessage) { server.join(); } +TEST(TCPTest, SocketReadTimeoutFailed) { + current::net::ReservedLocalPort port_reservation = ReserveLocalPort(); + const uint16_t port_number = port_reservation; + const timeval timeout = {1, 0}; + std::thread server( + [](Socket socket) { + Connection connection = socket.Accept(); + // Socket timeout is 1 sec + // client should not receive anything + sleep_for(milliseconds(1200)); + connection.BlockingWrite("TEST", false); + }, + std::move(port_reservation)); + Connection client(ClientSocket("localhost", port_number, timeout)); + char response[5] = "????"; + ASSERT_THROW(client.BlockingRead(response, 4, Connection::FillFullBuffer), EmptySocketReadException); + server.join(); +} + +TEST(TCPTest, SocketReadTimeoutOK) { + current::net::ReservedLocalPort port_reservation = ReserveLocalPort(); + const uint16_t port_number = port_reservation; + const timeval timeout = {1, 0}; + std::thread server( + [](Socket socket) { + Connection connection = socket.Accept(); + // Default socket timeout is 100 ms + // client should receive the string + sleep_for(milliseconds(10)); + connection.BlockingWrite("TEST", false); + }, + std::move(port_reservation)); + Connection client(ClientSocket("localhost", port_number, timeout)); + char response[5] = "????"; + ASSERT_EQ(4u, client.BlockingRead(response, 4, Connection::FillFullBuffer)); + EXPECT_EQ("TEST", std::string(response)); + server.join(); +} + TEST(TCPTest, ReceiveMessageAsResponse) { current::net::ReservedLocalPort port_reservation = ReserveLocalPort(); const uint16_t port_number = port_reservation;