Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 153 additions & 5 deletions crates/openshell-router/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ struct ValidationProbe {
path: &'static str,
protocol: &'static str,
body: bytes::Bytes,
/// Alternate body to try when the primary probe fails with HTTP 400.
/// Used for OpenAI chat completions where newer models require
/// `max_completion_tokens` while legacy/self-hosted backends only
/// accept `max_tokens`.
fallback_body: Option<bytes::Bytes>,
}

/// Response from a proxied HTTP request to a backend (fully buffered).
Expand Down Expand Up @@ -163,12 +168,17 @@ fn validation_probe(route: &ResolvedRoute) -> Result<ValidationProbe, Validation
.iter()
.any(|protocol| protocol == "openai_chat_completions")
{
// Use max_completion_tokens (modern OpenAI parameter, required by GPT-5+)
// with max_tokens as fallback for legacy/self-hosted backends.
return Ok(ValidationProbe {
path: "/v1/chat/completions",
protocol: "openai_chat_completions",
body: bytes::Bytes::from_static(
br#"{"messages":[{"role":"user","content":"ping"}],"max_tokens":32}"#,
br#"{"messages":[{"role":"user","content":"ping"}],"max_completion_tokens":32}"#,
),
fallback_body: Some(bytes::Bytes::from_static(
br#"{"messages":[{"role":"user","content":"ping"}],"max_tokens":32}"#,
)),
});
}

Expand All @@ -183,6 +193,7 @@ fn validation_probe(route: &ResolvedRoute) -> Result<ValidationProbe, Validation
body: bytes::Bytes::from_static(
br#"{"messages":[{"role":"user","content":"ping"}],"max_tokens":32}"#,
),
fallback_body: None,
});
}

Expand All @@ -195,6 +206,7 @@ fn validation_probe(route: &ResolvedRoute) -> Result<ValidationProbe, Validation
path: "/v1/responses",
protocol: "openai_responses",
body: bytes::Bytes::from_static(br#"{"input":"ping","max_output_tokens":32}"#),
fallback_body: None,
});
}

Expand All @@ -207,6 +219,7 @@ fn validation_probe(route: &ResolvedRoute) -> Result<ValidationProbe, Validation
path: "/v1/completions",
protocol: "openai_completions",
body: bytes::Bytes::from_static(br#"{"prompt":"ping","max_tokens":32}"#),
fallback_body: None,
});
}

Expand All @@ -233,7 +246,47 @@ pub async fn verify_backend_endpoint(
});
}

let response = send_backend_request(client, route, "POST", probe.path, headers, probe.body)
let result = try_validation_request(
client,
route,
probe.path,
probe.protocol,
headers.clone(),
probe.body,
)
.await;

// If the primary probe failed with a request-shape error (HTTP 400) and
// there is a fallback body, retry with the alternate token parameter.
// This handles the split between `max_completion_tokens` (GPT-5+) and
// `max_tokens` (legacy/self-hosted backends).
if let (Err(err), Some(fallback_body)) = (&result, probe.fallback_body) {
if err.kind == ValidationFailureKind::RequestShape {
return try_validation_request(
client,
route,
probe.path,
probe.protocol,
headers,
fallback_body,
)
.await;
}
}

result
}

/// Send a single validation request and classify the response.
async fn try_validation_request(
client: &reqwest::Client,
route: &ResolvedRoute,
path: &str,
protocol: &str,
headers: Vec<(String, String)>,
body: bytes::Bytes,
) -> Result<ValidatedEndpoint, ValidationFailure> {
let response = send_backend_request(client, route, "POST", path, headers, body)
.await
.map_err(|err| match err {
RouterError::UpstreamUnavailable(details) => ValidationFailure {
Expand All @@ -253,12 +306,12 @@ pub async fn verify_backend_endpoint(
details,
},
})?;
let url = build_backend_url(&route.endpoint, probe.path);
let url = build_backend_url(&route.endpoint, path);

if response.status().is_success() {
return Ok(ValidatedEndpoint {
url,
protocol: probe.protocol.to_string(),
protocol: protocol.to_string(),
});
}

Expand Down Expand Up @@ -376,7 +429,7 @@ fn build_backend_url(endpoint: &str, path: &str) -> String {

#[cfg(test)]
mod tests {
use super::{build_backend_url, verify_backend_endpoint};
use super::{build_backend_url, verify_backend_endpoint, ValidationFailureKind};
use crate::config::ResolvedRoute;
use openshell_core::inference::AuthHeader;
use wiremock::matchers::{body_partial_json, header, method, path};
Expand Down Expand Up @@ -463,4 +516,99 @@ mod tests {
assert_eq!(validated.protocol, "openai_chat_completions");
assert_eq!(validated.url, "mock://test-backend/v1/chat/completions");
}

/// GPT-5+ models reject `max_tokens` — the primary probe uses
/// `max_completion_tokens` so validation should succeed directly.
#[tokio::test]
async fn verify_openai_chat_uses_max_completion_tokens() {
    let server = MockServer::start().await;

    // The backend accepts the modern parameter, so no fallback is needed.
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .and(body_partial_json(serde_json::json!({
            "max_completion_tokens": 32,
        })))
        .respond_with(
            ResponseTemplate::new(200).set_body_json(serde_json::json!({"id": "chatcmpl-1"})),
        )
        .mount(&server)
        .await;

    let route = test_route(
        &server.uri(),
        &["openai_chat_completions"],
        AuthHeader::Bearer,
    );
    let http_client = reqwest::Client::builder().build().unwrap();

    let validated = verify_backend_endpoint(&http_client, &route)
        .await
        .unwrap();

    assert_eq!(validated.protocol, "openai_chat_completions");
}

/// Legacy/self-hosted backends that reject `max_completion_tokens`
/// should succeed on the fallback probe using `max_tokens`.
#[tokio::test]
async fn verify_openai_chat_falls_back_to_max_tokens() {
    let server = MockServer::start().await;

    // Primary probe: the backend rejects the modern parameter with 400,
    // which must be classified as a request-shape failure and retried.
    let reject_primary = Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .and(body_partial_json(serde_json::json!({
            "max_completion_tokens": 32,
        })))
        .respond_with(ResponseTemplate::new(400).set_body_string(
            r#"{"error":{"message":"Unsupported parameter: 'max_completion_tokens'"}}"#,
        ))
        .expect(1);
    reject_primary.mount(&server).await;

    // Fallback probe: the legacy parameter is accepted.
    let accept_fallback = Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .and(body_partial_json(serde_json::json!({
            "max_tokens": 32,
        })))
        .respond_with(
            ResponseTemplate::new(200).set_body_json(serde_json::json!({"id": "chatcmpl-2"})),
        )
        .expect(1);
    accept_fallback.mount(&server).await;

    let route = test_route(
        &server.uri(),
        &["openai_chat_completions"],
        AuthHeader::Bearer,
    );
    let http_client = reqwest::Client::builder().build().unwrap();

    let validated = verify_backend_endpoint(&http_client, &route)
        .await
        .unwrap();

    assert_eq!(validated.protocol, "openai_chat_completions");
}

/// Non-chat-completions probes (e.g. anthropic_messages) should not
/// have a fallback — a 400 remains a hard failure.
#[tokio::test]
async fn verify_non_chat_completions_no_fallback() {
    let server = MockServer::start().await;

    // Anthropic messages probe gets a flat 400; with no fallback body
    // configured, validation must surface the failure directly.
    Mock::given(method("POST"))
        .and(path("/v1/messages"))
        .respond_with(ResponseTemplate::new(400).set_body_string("bad request"))
        .mount(&server)
        .await;

    let route = test_route(
        &server.uri(),
        &["anthropic_messages"],
        AuthHeader::Custom("x-api-key"),
    );
    let http_client = reqwest::Client::builder().build().unwrap();

    let outcome = verify_backend_endpoint(&http_client, &route).await;

    let failure = outcome.expect_err("400 without fallback should fail validation");
    assert_eq!(failure.kind, ValidationFailureKind::RequestShape);
}
}
2 changes: 1 addition & 1 deletion crates/openshell-server/src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,7 @@ mod tests {
.and(header("content-type", "application/json"))
.and(body_partial_json(serde_json::json!({
"model": "gpt-4o-mini",
"max_tokens": 32,
"max_completion_tokens": 32,
})))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"id": "chatcmpl-123",
Expand Down
Loading