Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions crates/openshell-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use tracing::{debug, error, info};

pub use grpc::OpenShellService;
pub use http::{health_router, http_router};
pub use multiplex::ALPN_H2;
pub use multiplex::{MultiplexService, MultiplexedService};
use persistence::Store;
use sandbox::{SandboxClient, spawn_sandbox_watcher, spawn_store_reconciler};
Expand Down Expand Up @@ -180,12 +181,20 @@ pub async fn run_server(config: Config, tracing_log_bus: TracingLogBus) -> Resul

// Build TLS acceptor when TLS is configured; otherwise serve plaintext.
let tls_acceptor = if let Some(tls) = &config.tls {
Some(TlsAcceptor::from_files(
let acceptor = TlsAcceptor::from_files(
&tls.cert_path,
&tls.key_path,
&tls.client_ca_path,
tls.allow_unauthenticated,
)?)
)?;
info!(
cert = %tls.cert_path.display(),
key = %tls.key_path.display(),
client_ca = %tls.client_ca_path.display(),
allow_unauthenticated = tls.allow_unauthenticated,
"TLS enabled — ALPN advertises h2 + http/1.1",
);
Some(acceptor)
} else {
info!("TLS disabled — accepting plaintext connections");
None
Expand All @@ -208,7 +217,20 @@ pub async fn run_server(config: Config, tracing_log_bus: TracingLogBus) -> Resul
tokio::spawn(async move {
match tls_acceptor.inner().accept(stream).await {
Ok(tls_stream) => {
if let Err(e) = service.serve(tls_stream).await {
// Use ALPN-negotiated protocol when available. This
// avoids the byte-sniffing auto-detection in
// `serve()`, which can misidentify h2 connections as
// HTTP/1.1 when the first read returns a partial
// preface.
let alpn = tls_stream.get_ref().1.alpn_protocol().unwrap_or_default();
let result = if alpn == ALPN_H2 {
debug!(client = %addr, "ALPN negotiated h2 — serving HTTP/2");
service.serve_h2(tls_stream).await
} else {
debug!(client = %addr, alpn = ?String::from_utf8_lossy(alpn), "ALPN fallback — auto-detecting protocol");
service.serve(tls_stream).await
};
if let Err(e) = result {
error!(error = %e, client = %addr, "Connection error");
}
}
Expand Down Expand Up @@ -255,4 +277,9 @@ mod tests {
assert!(!is_benign_tls_handshake_failure(&error));
}
}

#[test]
fn alpn_h2_constant_matches_standard_protocol_id() {
assert_eq!(super::ALPN_H2, b"h2");
}
}
38 changes: 38 additions & 0 deletions crates/openshell-server/src/multiplex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ use tower::ServiceExt;

use crate::{OpenShellService, ServerState, http_router, inference::InferenceService};

/// ALPN protocol identifier for HTTP/2.
pub const ALPN_H2: &[u8] = b"h2";

/// Maximum inbound gRPC message size (1 MB).
///
/// Replaces tonic's implicit 4 MB default with a conservative limit to
Expand All @@ -49,6 +52,11 @@ impl MultiplexService {
}

/// Serve a connection, routing to gRPC or HTTP based on content-type.
///
/// Uses hyper's auto-detection to determine whether the connection speaks
/// HTTP/1.1 or HTTP/2. For TLS connections where ALPN already negotiated
/// the protocol, prefer [`serve_h2`](Self::serve_h2) to skip the
/// auto-detection round-trip.
pub async fn serve<S>(&self, stream: S) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
Expand All @@ -68,6 +76,36 @@ impl MultiplexService {

Ok(())
}

/// Serve a connection that has already been identified as HTTP/2.
///
/// This is the preferred path for TLS connections where ALPN negotiated
/// `h2`. It avoids the byte-sniffing auto-detection in [`serve`](Self::serve)
/// and immediately starts the HTTP/2 state machine, which eliminates a
/// class of edge-case failures (partial reads, buffering delays) that can
/// cause the auto-detector to misidentify an h2 connection as HTTP/1.1.
pub async fn serve_h2<S>(
&self,
stream: S,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let openshell = OpenShellServer::new(OpenShellService::new(self.state.clone()))
.max_decoding_message_size(MAX_GRPC_DECODE_SIZE);
let inference = InferenceServer::new(InferenceService::new(self.state.clone()))
.max_decoding_message_size(MAX_GRPC_DECODE_SIZE);
let grpc_service = GrpcRouter::new(openshell, inference);
let http_service = http_router(self.state.clone());

let service = MultiplexedService::new(grpc_service, http_service);

hyper::server::conn::http2::Builder::new(TokioExecutor::new())
.serve_connection(TokioIo::new(stream), service)
.await?;

Ok(())
}
}

/// Combined gRPC service that routes between `OpenShell` and Inference services
Expand Down
29 changes: 29 additions & 0 deletions crates/openshell-server/src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,32 @@ fn load_key(path: &Path) -> Result<PrivateKeyDer<'static>> {

Err(Error::tls("no private key found in file"))
}

#[cfg(test)]
mod tests {
use super::*;
use std::path::Path;

#[test]
fn from_files_rejects_missing_cert() {
let result = TlsAcceptor::from_files(
Path::new("/nonexistent/cert.pem"),
Path::new("/nonexistent/key.pem"),
Path::new("/nonexistent/ca.pem"),
false,
);
assert!(result.is_err());
}

#[test]
fn load_certs_rejects_nonexistent_file() {
let result = load_certs(Path::new("/nonexistent/cert.pem"));
assert!(result.is_err());
}

#[test]
fn load_key_rejects_nonexistent_file() {
let result = load_key(Path::new("/nonexistent/key.pem"));
assert!(result.is_err());
}
}
Loading