Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions bricks/net/tcp/impl/posix.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ typedef int SOCKET;
namespace current {
namespace net {

const timeval DefaultSocketTimeout = {0, 0};
enum class NagleAlgorithm : bool { Disable, Keep };
const NagleAlgorithm kDefaultNagleAlgorithmPolicy = NagleAlgorithm::Keep;

Expand Down Expand Up @@ -394,9 +395,8 @@ class ReserveLocalPortImpl final {
current::net::BarePort(port),
nagle_algorithm_policy,
max_connections);
return current::net::ReservedLocalPort(current::net::ReservedLocalPort::Construct(),
port,
std::move(hold_port_or_throw));
return current::net::ReservedLocalPort(
current::net::ReservedLocalPort::Construct(), port, std::move(hold_port_or_throw));
}

class Connection : public SocketHandle {
Expand Down Expand Up @@ -673,12 +673,17 @@ inline std::string ResolveIPFromHostname(const std::string& hostname) {

// POSIX allows numeric ports, as well as strings like "http".
template <typename T>
inline Connection ClientSocket(const std::string& host, T port_or_serv) {
inline Connection ClientSocket(const std::string& host,
T port_or_serv,
const timeval& read_timeout = DefaultSocketTimeout,
const timeval& write_timeout = DefaultSocketTimeout) {
class ClientSocket final : public SocketHandle {
public:
explicit ClientSocket(const std::string& host,
const std::string& serv,
NagleAlgorithm nagle_algorithm_policy = kDefaultNagleAlgorithmPolicy)
NagleAlgorithm nagle_algorithm_policy = kDefaultNagleAlgorithmPolicy,
const timeval& read_timeout = DefaultSocketTimeout,
const timeval& write_timeout = DefaultSocketTimeout)
: SocketHandle(SocketHandle::DoNotBind(), nagle_algorithm_policy) {
CURRENT_BRICKS_NET_LOG("S%05d ", static_cast<SOCKET>(socket));
// Deliberately left non-const because of possible Windows issues. -- M.Z.
Expand All @@ -693,6 +698,15 @@ inline Connection ClientSocket(const std::string& host, T port_or_serv) {
remote_ip_and_port.port = htons(p_addr_in->sin_port);

CURRENT_BRICKS_NET_LOG("S%05d connect() ...\n", static_cast<SOCKET>(socket));

// Set socket timeout if it's not default 0,0
if (read_timeout.tv_sec > 0 || read_timeout.tv_usec > 0) {
setsockopt(socket, SOL_SOCKET, SO_RCVTIMEO, &read_timeout, sizeof(read_timeout));
}
if (write_timeout.tv_sec > 0 || write_timeout.tv_usec > 0) {
setsockopt(socket, SOL_SOCKET, SO_SNDTIMEO, &write_timeout, sizeof(write_timeout));
}

const int retval = ::connect(socket, p_addr, sizeof(*p_addr));
if (retval) {
CURRENT_THROW(SocketConnectException()); // LCOV_EXCL_LINE -- Not covered by the unit tests.
Expand All @@ -715,7 +729,8 @@ inline Connection ClientSocket(const std::string& host, T port_or_serv) {
IPAndPort local_ip_and_port;
IPAndPort remote_ip_and_port;
};
auto client_socket = ClientSocket(host, std::to_string(port_or_serv));
auto client_socket =
ClientSocket(host, std::to_string(port_or_serv), kDefaultNagleAlgorithmPolicy, read_timeout, write_timeout);
IPAndPort local_ip_and_port(std::move(client_socket.local_ip_and_port));
IPAndPort remote_ip_and_port(std::move(client_socket.remote_ip_and_port));
return Connection(std::move(client_socket), std::move(local_ip_and_port), std::move(remote_ip_and_port));
Expand Down
40 changes: 40 additions & 0 deletions bricks/net/tcp/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ using current::net::Socket;
using current::strings::Printf;

using current::net::AttemptedToUseMovedAwayConnection;
using current::net::EmptySocketReadException;
using current::net::SocketBindException;
using current::net::SocketException;
using current::net::SocketResolveAddressException;
Expand Down Expand Up @@ -155,6 +156,45 @@ TEST(TCPTest, ReceiveDelayedMessage) {
server.join();
}

TEST(TCPTest, SocketReadTimeoutFailed) {
current::net::ReservedLocalPort port_reservation = ReserveLocalPort();
const uint16_t port_number = port_reservation;
const timeval timeout = {1, 0};
std::thread server(
[](Socket socket) {
Connection connection = socket.Accept();
// Socket timeout is 1 sec
// client should not receive anything
sleep_for(milliseconds(1200));
connection.BlockingWrite("TEST", false);
},
std::move(port_reservation));
Connection client(ClientSocket("localhost", port_number, timeout));
char response[5] = "????";
ASSERT_THROW(client.BlockingRead(response, 4, Connection::FillFullBuffer), EmptySocketReadException);
server.join();
}

TEST(TCPTest, SocketReadTimeoutOK) {
current::net::ReservedLocalPort port_reservation = ReserveLocalPort();
const uint16_t port_number = port_reservation;
const timeval timeout = {1, 0};
std::thread server(
[](Socket socket) {
Connection connection = socket.Accept();
// Default socket timeout is 100 ms
// client should receive the string
sleep_for(milliseconds(10));
connection.BlockingWrite("TEST", false);
},
std::move(port_reservation));
Connection client(ClientSocket("localhost", port_number, timeout));
char response[5] = "????";
ASSERT_EQ(4u, client.BlockingRead(response, 4, Connection::FillFullBuffer));
EXPECT_EQ("TEST", std::string(response));
server.join();
}

TEST(TCPTest, ReceiveMessageAsResponse) {
current::net::ReservedLocalPort port_reservation = ReserveLocalPort();
const uint16_t port_number = port_reservation;
Expand Down