diff --git a/wish/cpp/benchmark/CMakeLists.txt b/wish/cpp/benchmark/CMakeLists.txt index 9e352f4..e0a5a36 100644 --- a/wish/cpp/benchmark/CMakeLists.txt +++ b/wish/cpp/benchmark/CMakeLists.txt @@ -16,4 +16,10 @@ add_executable(tls_benchmark_client ) target_link_libraries(tls_benchmark_client wish_handler + absl::flags + absl::flags_parse + absl::log + absl::log_initialize + "$" + benchmark::benchmark ) \ No newline at end of file diff --git a/wish/cpp/benchmark/tls_client.cc b/wish/cpp/benchmark/tls_client.cc index ccf17a4..0d30a18 100644 --- a/wish/cpp/benchmark/tls_client.cc +++ b/wish/cpp/benchmark/tls_client.cc @@ -1,48 +1,228 @@ +#include +#include +#include +#include +#include +#include + +// To use BoringSSL +#define EVENT__HAVE_OPENSSL 1 +#include +#include + +#include #include -#include +#include #include #include -#include "../src/tls_client.h" +#include "../src/tls_context.h" #include "../src/wish_handler.h" +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/log/initialize.h" +#include "absl/log/log.h" +#include "benchmark/benchmark.h" -int main() { - TlsClient client("../certs/ca.crt", "../certs/client.crt", - "../certs/client.key", "127.0.0.1", 8080); +ABSL_FLAG(std::string, host, "127.0.0.1", "Server host to connect to"); +ABSL_FLAG(int, port, 8080, "Server port to connect to"); +ABSL_FLAG(std::string, ca_cert, "../certs/ca.crt", "Path to CA certificate"); +ABSL_FLAG(std::string, client_cert, "../certs/client.crt", + "Path to client certificate"); +ABSL_FLAG(std::string, client_key, "../certs/client.key", + "Path to client private key"); - if (!client.Init()) { - std::cerr << "Failed to initialize client" << std::endl; - return 1; +namespace { + +struct ClientState { + struct event_base* base = nullptr; + struct bufferevent* bev = nullptr; + WishHandler* handler = nullptr; + + bool connected = false; + bool awaiting_response = false; + std::chrono::steady_clock::time_point request_start; + std::vector latencies_us; +}; + +double PercentileFromSorted(const std::vector& values, double p) { + if (values.empty()) { + return 0.0; + } + const size_t idx = static_cast(p * (values.size() - 1)); + return values[idx]; +} + +// Global TLS context initialized once in main. +TlsContext* g_tls_ctx = nullptr; + +bool InitConnection(ClientState* client) { + client->base = event_base_new(); + if (!client->base) { + LOG(ERROR) << "event_base_new() failed"; + return false; + } + + struct sockaddr_in sin; + std::memset(&sin, 0, sizeof(sin)); + sin.sin_family = AF_INET; + sin.sin_port = htons(absl::GetFlag(FLAGS_port)); + + const std::string host = absl::GetFlag(FLAGS_host); + + if (inet_pton(AF_INET, host.c_str(), &sin.sin_addr) != 1) { + LOG(ERROR) << "Invalid IPv4 host: " << host; + return false; } - const int kTotalMessages = 1000; - int messages_received = 0; - auto start_time = std::chrono::high_resolution_clock::now(); + SSL* ssl = SSL_new(g_tls_ctx->ssl_ctx()); + if (!ssl) { + LOG(ERROR) << "SSL_new() failed"; + return false; + } - client.SetOnOpen([&start_time, kTotalMessages](WishHandler* handler) { - std::cout << "Connected! Starting benchmark..." << std::endl; + client->bev = bufferevent_openssl_socket_new( + client->base, -1, ssl, BUFFEREVENT_SSL_CONNECTING, BEV_OPT_CLOSE_ON_FREE); + if (!client->bev) { + LOG(ERROR) << "bufferevent_openssl_socket_new() failed"; + SSL_free(ssl); + return false; + } - start_time = std::chrono::high_resolution_clock::now(); + if (bufferevent_socket_connect(client->bev, + reinterpret_cast(&sin), + sizeof(sin)) < 0) { + LOG(ERROR) << "bufferevent_socket_connect() failed"; + return false; + } - for (int i = 0; i < kTotalMessages; ++i) { - handler->SendText("Benchmark message " + std::to_string(i)); + client->handler = new WishHandler(client->bev, false); + client->handler->SetOnOpen([client]() { + const int fd = bufferevent_getfd(client->bev); + if (fd >= 0) { + int nodelay = 1; + setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay)); } + + client->connected = true; + event_base_loopexit(client->base, nullptr); }); - client.SetOnMessage([&messages_received, kTotalMessages, - &start_time](uint8_t opcode, const std::string& msg) { - messages_received++; - if (messages_received == kTotalMessages) { - auto end_time = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast( - end_time - start_time); - std::cout << "Benchmark complete: Received " << kTotalMessages - << " messages in " << duration.count() << " ms." << std::endl; - exit(0); + client->handler->SetOnMessage( + [client](uint8_t opcode, const std::string& msg) { + (void)opcode; + (void)msg; + + if (!client->awaiting_response) { + return; + } + + const auto end = std::chrono::steady_clock::now(); + const double latency_us = + std::chrono::duration(end - + client->request_start) + .count(); + client->latencies_us.push_back(latency_us); + client->awaiting_response = false; + event_base_loopexit(client->base, nullptr); + }); + + client->handler->Start(); + + event_base_dispatch(client->base); + return client->connected; +} + +void CleanupConnection(ClientState* client) { + if (client->handler) { + delete client->handler; + client->handler = nullptr; + client->bev = nullptr; + } + + if (client->base) { + event_base_free(client->base); + client->base = nullptr; + } +} + +static void BM_TlsClient(benchmark::State& state) { + const int payload_size = static_cast(state.range(0)); + const std::string payload(payload_size, 'A'); + + ClientState client; + if (!InitConnection(&client)) { + state.SkipWithError("Failed to establish TLS WiSH connection"); + CleanupConnection(&client); + return; + } + + for (auto _ : state) { + (void)_; + client.awaiting_response = true; + client.request_start = std::chrono::steady_clock::now(); + + const int send_result = client.handler->SendBinary(payload); + if (send_result != 0) { + state.SkipWithError("WishHandler::SendBinary failed"); + break; } - }); - client.Run(); + event_base_dispatch(client.base); + if (client.awaiting_response) { + state.SkipWithError("Connection closed before response"); + break; + } + } + + if (!client.latencies_us.empty()) { + std::sort(client.latencies_us.begin(), client.latencies_us.end()); + + state.counters["p10_latency_us"] = + PercentileFromSorted(client.latencies_us, 0.10); + state.counters["p50_latency_us"] = + PercentileFromSorted(client.latencies_us, 0.50); + state.counters["p90_latency_us"] = + PercentileFromSorted(client.latencies_us, 0.90); + state.counters["p99_latency_us"] = + PercentileFromSorted(client.latencies_us, 0.99); + + state.SetItemsProcessed(client.latencies_us.size()); + } + + CleanupConnection(&client); +} + +BENCHMARK(BM_TlsClient) + ->UseRealTime() + ->Unit(benchmark::kMicrosecond) + ->Args({0}) + ->RangeMultiplier(2) + ->Range(1 << 10, 128 << 10); + +} // namespace + +int main(int argc, char** argv) { + benchmark::MaybeReenterWithoutASLR(argc, argv); + benchmark::Initialize(&argc, argv); + + absl::ParseCommandLine(argc, argv); + absl::InitializeLog(); + + TlsContext tls_ctx; + tls_ctx.set_ca_file(absl::GetFlag(FLAGS_ca_cert)); + tls_ctx.set_certificate_file(absl::GetFlag(FLAGS_client_cert)); + tls_ctx.set_private_key_file(absl::GetFlag(FLAGS_client_key)); + + if (!tls_ctx.Init(false)) { + LOG(ERROR) << "Failed to initialize TLS context"; + return 1; + } + g_tls_ctx = &tls_ctx; + + benchmark::RunSpecifiedBenchmarks(); + benchmark::Shutdown(); + g_tls_ctx = nullptr; return 0; }