diff --git a/wish/cpp/benchmark/benchmark.py b/wish/cpp/benchmark/benchmark.py index 5eb7927..0cb8cea 100644 --- a/wish/cpp/benchmark/benchmark.py +++ b/wish/cpp/benchmark/benchmark.py @@ -6,10 +6,13 @@ import signal BUILD_DIR = "./build" -SERVER_BINARY_NAME = "examples/echo_server" -CLIENT_BINARY_NAME = "benchmark/benchmark_client" -SERVER_BINARY_PATH = os.path.join(BUILD_DIR, SERVER_BINARY_NAME) -CLIENT_BINARY_PATH = os.path.join(BUILD_DIR, CLIENT_BINARY_NAME) +CERTS_DIR = "./certs" + +TLS_SERVER_BINARY_NAME = "examples/tls_echo_server" +TLS_CLIENT_BINARY_NAME = "benchmark/tls_benchmark_client" + +PLAIN_SERVER_BINARY_NAME = "examples/echo_server" +PLAIN_CLIENT_BINARY_NAME = "benchmark/benchmark_client" def _client_host_from_remote_target(remote_host): @@ -18,32 +21,48 @@ def _client_host_from_remote_target(remote_host): return remote_host -def _start_server(remote_host): +def _start_server(remote_host, server_binary_path, server_binary_name, certs_dir=None): if not remote_host: - return subprocess.Popen([SERVER_BINARY_PATH]) + return subprocess.Popen([server_binary_path]) - remote_binary_path = f"/tmp/{SERVER_BINARY_NAME}" + remote_binary_path = f"/tmp/{os.path.basename(server_binary_name)}" subprocess.run( ["ssh", remote_host, f"rm -f {shlex.quote(remote_binary_path)}"], check=True, ) subprocess.run( - ["scp", SERVER_BINARY_PATH, f"{remote_host}:{remote_binary_path}"], + ["scp", server_binary_path, f"{remote_host}:{remote_binary_path}"], check=True, ) + if certs_dir: + remote_certs_dir = "/tmp/certs" + subprocess.run( + ["ssh", remote_host, f"rm -rf {shlex.quote(remote_certs_dir)}"], + check=True, + ) + subprocess.run( + ["scp", "-r", certs_dir, f"{remote_host}:{remote_certs_dir}"], + check=True, + ) + remote_command = ( f"chmod +x {shlex.quote(remote_binary_path)} && " - f"{shlex.quote(remote_binary_path)}" + f"cd /tmp && {shlex.quote(remote_binary_path)}" ) # Pass the `-t` option multiple times to ensure a pseudo-terminal is allocated, so that we can send a SIGTERM signal to the remote process, not the `ssh` process itself. This allows us to properly terminate the remote server when the benchmark is done. return subprocess.Popen(["ssh", "-tt", remote_host, remote_command], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0) -def run_benchmark(remote_host=None): - print("Running benchmark ...") +def run_benchmark(remote_host=None, tls=True): + server_binary_name = TLS_SERVER_BINARY_NAME if tls else PLAIN_SERVER_BINARY_NAME + client_binary_name = TLS_CLIENT_BINARY_NAME if tls else PLAIN_CLIENT_BINARY_NAME + server_binary_path = os.path.join(BUILD_DIR, server_binary_name) + client_binary_path = os.path.join(BUILD_DIR, client_binary_name) + + print(f"Running {'TLS' if tls else 'plain'} benchmark ...") client_host = "127.0.0.1" if remote_host: @@ -52,15 +71,15 @@ def run_benchmark(remote_host=None): else: print("Starting local server ...") - server_process = _start_server(remote_host) + server_process = _start_server(remote_host, server_binary_path, server_binary_name, certs_dir=CERTS_DIR if tls else None) time.sleep(2) # wait for server to start if server_process.poll() is not None: - raise RuntimeError(f"{SERVER_BINARY_NAME} failed to start") + raise RuntimeError(f"{server_binary_name} failed to start") try: subprocess.run([ - CLIENT_BINARY_PATH, + client_binary_path, "--stderrthreshold=0", "--benchmark_counters_tabular=true", "--benchmark_min_time=5.0s", @@ -82,19 +101,30 @@ def run_benchmark(remote_host=None): "--remote-host", help="SSH target for remote server, e.g. user@10.0.0.5", ) + parser.add_argument( + "--tls", + action=argparse.BooleanOptionalAction, + default=True, + help="Use TLS (default: enabled). Pass --no-tls for plain HTTP.", + ) args = parser.parse_args() + server_binary_name = TLS_SERVER_BINARY_NAME if args.tls else PLAIN_SERVER_BINARY_NAME + client_binary_name = TLS_CLIENT_BINARY_NAME if args.tls else PLAIN_CLIENT_BINARY_NAME + server_binary_path = os.path.join(BUILD_DIR, server_binary_name) + client_binary_path = os.path.join(BUILD_DIR, client_binary_name) + if not os.path.isdir(BUILD_DIR): print("Error: 'build' directory not found. Please compile the project first.") exit(1) - if not os.path.isfile(SERVER_BINARY_PATH): - print(f"Error: '{SERVER_BINARY_PATH}' not found. Please compile the project first.") + if not os.path.isfile(server_binary_path): + print(f"Error: '{server_binary_path}' not found. Please compile the project first.") exit(1) - if not os.path.isfile(CLIENT_BINARY_PATH): - print(f"Error: '{CLIENT_BINARY_PATH}' not found. Please compile the project first.") + if not os.path.isfile(client_binary_path): + print(f"Error: '{client_binary_path}' not found. Please compile the project first.") exit(1) - + print("Starting benchmarks...") - run_benchmark(remote_host=args.remote_host) + run_benchmark(remote_host=args.remote_host, tls=args.tls) diff --git a/wish/cpp/benchmark/tls_client.cc b/wish/cpp/benchmark/tls_client.cc index 0d30a18..ae58709 100644 --- a/wish/cpp/benchmark/tls_client.cc +++ b/wish/cpp/benchmark/tls_client.cc @@ -26,10 +26,10 @@ ABSL_FLAG(std::string, host, "127.0.0.1", "Server host to connect to"); ABSL_FLAG(int, port, 8080, "Server port to connect to"); -ABSL_FLAG(std::string, ca_cert, "../certs/ca.crt", "Path to CA certificate"); -ABSL_FLAG(std::string, client_cert, "../certs/client.crt", +ABSL_FLAG(std::string, ca_cert, "certs/ca.crt", "Path to CA certificate"); +ABSL_FLAG(std::string, client_cert, "certs/client.crt", "Path to client certificate"); -ABSL_FLAG(std::string, client_key, "../certs/client.key", +ABSL_FLAG(std::string, client_key, "certs/client.key", "Path to client private key"); namespace { @@ -101,7 +101,10 @@ bool InitConnection(ClientState* client) { const int fd = bufferevent_getfd(client->bev); if (fd >= 0) { int nodelay = 1; - setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay)); + int result = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay)); + if (result != 0) { + LOG(FATAL) << "setsockopt(TCP_NODELAY) failed"; + } } client->connected = true; diff --git a/wish/cpp/examples/CMakeLists.txt b/wish/cpp/examples/CMakeLists.txt index 8f3760f..a540bcb 100644 --- a/wish/cpp/examples/CMakeLists.txt +++ b/wish/cpp/examples/CMakeLists.txt @@ -15,6 +15,11 @@ add_executable(tls_echo_server ) target_link_libraries(tls_echo_server wish_handler + absl::flags + absl::flags_parse + absl::log + absl::log_initialize + "$" ) add_executable(tls_hello_client diff --git a/wish/cpp/examples/tls_echo_server.cc b/wish/cpp/examples/tls_echo_server.cc index 2d645eb..dfa6cf8 100644 --- a/wish/cpp/examples/tls_echo_server.cc +++ b/wish/cpp/examples/tls_echo_server.cc @@ -1,22 +1,35 @@ -#include #include #include "../src/tls_server.h" #include "../src/wish_handler.h" +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/log/initialize.h" +#include "absl/log/log.h" + +ABSL_FLAG(int, port, 8080, "Port to listen on"); +ABSL_FLAG(std::string, ca_cert, "certs/ca.crt", "Path to CA certificate file"); +ABSL_FLAG(std::string, server_cert, "certs/server.crt", "Path to server certificate file"); +ABSL_FLAG(std::string, server_key, "certs/server.key", "Path to server private key file"); int main(int argc, char** argv) { - int port = 8080; + absl::ParseCommandLine(argc, argv); + absl::InitializeLog(); + + const int port = absl::GetFlag(FLAGS_port); + const std::string ca_cert = absl::GetFlag(FLAGS_ca_cert); + const std::string server_cert = absl::GetFlag(FLAGS_server_cert); + const std::string server_key = absl::GetFlag(FLAGS_server_key); - TlsServer server("../certs/ca.crt", "../certs/server.crt", - "../certs/server.key", port); + TlsServer server(ca_cert, server_cert, server_key, port); if (!server.Init()) { - std::cerr << "Failed to initialize server" << std::endl; + LOG(ERROR) << "Failed to initialize server"; return 1; } server.SetOnConnection([](struct bufferevent* bev) { - std::cout << "Client connected." << std::endl; + LOG(INFO) << "Client connected."; WishHandler* handler = new WishHandler(bev, true); @@ -39,7 +52,7 @@ int main(int argc, char** argv) { type = "UNKNOWN(" + std::to_string(opcode) + ")"; break; } - std::cout << "Received [" << type << "]: " << msg << std::endl; + LOG(INFO) << "Received [" << type << "]: " << msg; // Echo back if (opcode == WISH_OPCODE_TEXT) @@ -51,7 +64,7 @@ int main(int argc, char** argv) { else if (opcode == WISH_OPCODE_BINARY_METADATA) handler->SendBinaryMetadata(msg); else { - std::cerr << "Unknown opcode, cannot echo." << std::endl; + LOG(WARNING) << "Unknown opcode, cannot echo."; } }); diff --git a/wish/cpp/examples/tls_hello_client.cc b/wish/cpp/examples/tls_hello_client.cc index 2b9ea94..706fc05 100644 --- a/wish/cpp/examples/tls_hello_client.cc +++ b/wish/cpp/examples/tls_hello_client.cc @@ -5,8 +5,8 @@ #include "../src/wish_handler.h" int main() { - TlsClient client("../certs/ca.crt", "../certs/client.crt", - "../certs/client.key", "127.0.0.1", 8080); + TlsClient client("certs/ca.crt", "certs/client.crt", + "certs/client.key", "127.0.0.1", 8080); if (!client.Init()) { std::cerr << "Failed to initialize client" << std::endl; diff --git a/wish/cpp/src/tls_server.cc b/wish/cpp/src/tls_server.cc index 314fafe..e13fe60 100644 --- a/wish/cpp/src/tls_server.cc +++ b/wish/cpp/src/tls_server.cc @@ -1,8 +1,10 @@ #include "tls_server.h" -#include -#include #include +#include + +#include +#include // To use BoringSSL #define EVENT__HAVE_OPENSSL 1 @@ -10,7 +12,7 @@ #include TlsServer::TlsServer(const std::string& ca_file, const std::string& cert_file, - const std::string& key_file, int port) + const std::string& key_file, int port) : ca_file_(ca_file), cert_file_(cert_file), key_file_(key_file), @@ -78,11 +80,16 @@ void TlsServer::SetOnConnection(ConnectCallback cb) { } void TlsServer::AcceptConnCb(struct evconnlistener* listener, - evutil_socket_t fd, struct sockaddr* address, - int socklen, void* ctx) { + evutil_socket_t fd, struct sockaddr* address, + int socklen, void* ctx) { struct event_base* base = evconnlistener_get_base(listener); TlsServer* server = static_cast(ctx); + int one = 1; + if (setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)) < 0) { + std::cerr << "Failed to set TCP_NODELAY: " << strerror(errno) << std::endl; + } + SSL* ssl = SSL_new(server->tls_ctx_.ssl_ctx()); struct bufferevent* bev = bufferevent_openssl_socket_new( base, fd, ssl, BUFFEREVENT_SSL_ACCEPTING, BEV_OPT_CLOSE_ON_FREE); diff --git a/wish/cpp/src/wish_handler.cc b/wish/cpp/src/wish_handler.cc index ff61a1a..ce430de 100644 --- a/wish/cpp/src/wish_handler.cc +++ b/wish/cpp/src/wish_handler.cc @@ -1,21 +1,21 @@ #include "wish_handler.h" +#include +#include +#include +#include + #include #include #include #include #include -#include -#include -#include -#include - WishHandler::WishHandler(struct bufferevent* bev, bool is_server) : bev_(bev), is_server_(is_server), ctx_(nullptr), state_(HANDSHAKE) { struct wslay_event_callbacks callbacks = { - RecvCallback, SendCallback, GenMaskCallback, NULL, - NULL, NULL, OnMsgRecvCallback}; + RecvCallback, SendCallback, GenMaskCallback, NULL, + NULL, NULL, OnMsgRecvCallback}; if (is_server_) { wslay_event_context_server_init(&ctx_, &callbacks, this);