diff --git a/local-server/src/dns_server.rs b/local-server/src/dns_server.rs new file mode 100644 index 0000000..9e9398a --- /dev/null +++ b/local-server/src/dns_server.rs @@ -0,0 +1,78 @@ +use hickory_server::{ + authority::Catalog, + proto::{rr::Name, xfer::Protocol}, + resolver::{ + config::{NameServerConfig, NameServerConfigGroup, ResolverOpts}, + name_server::TokioConnectionProvider, + }, + server::{RequestHandler, ResponseHandler, ResponseInfo}, + store::forwarder::{ForwardAuthority, ForwardConfig}, + ServerFuture, +}; +use std::sync::Arc; +use std::{net::SocketAddr, ops::Deref}; +use tokio::{net::UdpSocket, sync::RwLock}; + +#[derive(Clone)] +pub struct DnsCatalog(Arc>); + +impl DnsCatalog { + pub fn new() -> Self { + Self(Arc::new(RwLock::new(Catalog::new()))) + } +} + +impl Default for DnsCatalog { + fn default() -> Self { + Self::new() + } +} + +impl Deref for DnsCatalog { + type Target = Arc>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[async_trait::async_trait] +impl RequestHandler for DnsCatalog { + async fn handle_request( + &self, + request: &hickory_server::server::Request, + response_handle: R, + ) -> ResponseInfo { + let catalog = self.read().await; + + catalog.handle_request(request, response_handle).await + } +} + +pub async fn serve(dns_catalog: DnsCatalog) { + let cf_name_server = NameServerConfig::new("1.1.1.1:53".parse().unwrap(), Protocol::Udp); + let forward_config = ForwardConfig { + name_servers: NameServerConfigGroup::from(vec![cf_name_server]), + options: Some(ResolverOpts::default()), + }; + + let forwarder = + ForwardAuthority::builder_with_config(forward_config, TokioConnectionProvider::default()) + .with_origin(Name::root()) + .build() + .unwrap(); + + { + let mut catalog = dns_catalog.write().await; + catalog.upsert(Name::root().into(), vec![Arc::new(forwarder)]); + } + + let addr = SocketAddr::from(([0, 0, 0, 0], 8053)); + let sock = UdpSocket::bind(&addr).await.unwrap(); + + let mut server = 
ServerFuture::new(dns_catalog); + server.register_socket(sock); + + println!("DNS server listening on {addr}"); + server.block_until_done().await.unwrap(); +} diff --git a/local-server/src/lib.rs b/local-server/src/lib.rs index dc45c02..ee0a72f 100644 --- a/local-server/src/lib.rs +++ b/local-server/src/lib.rs @@ -1,145 +1,19 @@ -use axum::{ - body::Body, - extract::{DefaultBodyLimit, Json, Request}, - http::StatusCode, - response::{IntoResponse, Response}, - routing::{any, get, post}, - Extension, Router, -}; -use axum_server::tls_rustls::RustlsConfig; -use hickory_server::{ - authority::{Catalog, ZoneType}, - proto::{ - rr::{Name, RData, Record}, - xfer::Protocol, - }, - resolver::{ - config::{NameServerConfig, NameServerConfigGroup, ResolverOpts}, - name_server::TokioConnectionProvider, - }, - server::{RequestHandler, ResponseHandler, ResponseInfo}, - store::{ - forwarder::{ForwardAuthority, ForwardConfig}, - in_memory::InMemoryAuthority, - }, - ServerFuture, -}; -use http::{header::HeaderMap, HeaderName, HeaderValue, Uri}; +use axum::body::Body; use hyper_rustls::HttpsConnector; -use hyper_util::{ - client::legacy::{connect::HttpConnector, Client}, - rt::TokioExecutor, -}; -use linkup::{ - allow_all_cors, get_additional_headers, get_target_service, MemoryStringStore, NameKind, - Session, SessionAllocator, TargetService, UpdateSessionRequest, -}; -use rustls::ServerConfig; -use serde::Deserialize; -use std::{ - net::{Ipv4Addr, SocketAddr}, - ops::Deref, - path::PathBuf, - str::FromStr, -}; -use std::{path::Path, sync::Arc}; -use tokio::{net::UdpSocket, select, signal, sync::RwLock}; -use tokio_tungstenite::tungstenite::client::IntoClientRequest; -use tower::ServiceBuilder; -use tower_http::trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer}; +use hyper_util::client::legacy::{connect::HttpConnector, Client}; +use linkup::MemoryStringStore; +use std::path::{Path, PathBuf}; +use tokio::{select, signal}; pub mod certificates; +mod dns_server; +mod 
linkup_server; mod ws; -type HttpsClient = Client, Body>; - -const DISALLOWED_HEADERS: [HeaderName; 2] = [ - HeaderName::from_static("content-encoding"), - HeaderName::from_static("content-length"), -]; - -#[derive(Debug)] -struct ApiError { - message: String, - status_code: StatusCode, -} - -impl ApiError { - fn new(message: String, status_code: StatusCode) -> Self { - ApiError { - message, - status_code, - } - } -} - -impl IntoResponse for ApiError { - fn into_response(self) -> Response { - Response::builder() - .status(self.status_code) - .header("Content-Type", "text/plain") - .body(Body::from(self.message)) - .unwrap() - } -} - -#[derive(Clone)] -pub struct DnsCatalog(Arc>); - -impl DnsCatalog { - pub fn new() -> Self { - Self(Arc::new(RwLock::new(Catalog::new()))) - } -} +pub use dns_server::DnsCatalog; +pub use linkup_server::router; -impl Default for DnsCatalog { - fn default() -> Self { - Self::new() - } -} - -impl Deref for DnsCatalog { - type Target = Arc>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -#[async_trait::async_trait] -impl RequestHandler for DnsCatalog { - async fn handle_request( - &self, - request: &hickory_server::server::Request, - response_handle: R, - ) -> ResponseInfo { - let catalog = self.read().await; - - catalog.handle_request(request, response_handle).await - } -} - -pub fn linkup_router(config_store: MemoryStringStore, dns_catalog: DnsCatalog) -> Router { - let client = https_client(); - - Router::new() - .route("/linkup/local-session", post(linkup_config_handler)) - .route("/linkup/check", get(always_ok)) - .route("/linkup/dns/records", post(dns_create)) // TODO: Modify me - .fallback(any(linkup_request_handler)) - .layer(Extension(config_store)) - .layer(Extension(dns_catalog)) - .layer(Extension(client)) - .layer( - ServiceBuilder::new() - .layer(DefaultBodyLimit::max(1024 * 1024 * 100)) // Set max body size to 100MB - .layer( - TraceLayer::new_for_http() - .on_request(DefaultOnRequest::new()) // Log all incoming 
requests at INFO level - .on_response(DefaultOnResponse::new()), // Log all responses at INFO level - ), - ) -} +type HttpsClient = Client, Body>; pub async fn start(config_store: MemoryStringStore, certs_dir: &Path) { let dns_catalog = DnsCatalog::new(); @@ -149,13 +23,13 @@ pub async fn start(config_store: MemoryStringStore, certs_dir: &Path) { let https_certs_dir = PathBuf::from(certs_dir); select! { - () = start_server_http(http_config_store, dns_catalog.clone()) => { + () = linkup_server::serve_http(http_config_store, dns_catalog.clone()) => { println!("HTTP server shut down"); }, - () = start_server_https(https_config_store, &https_certs_dir, dns_catalog.clone()) => { + () = linkup_server::serve_https(https_config_store, &https_certs_dir, dns_catalog.clone()) => { println!("HTTPS server shut down"); }, - () = start_dns_server(dns_catalog.clone()) => { + () = dns_server::serve(dns_catalog.clone()) => { println!("DNS server shut down"); }, () = shutdown_signal() => { @@ -164,346 +38,6 @@ pub async fn start(config_store: MemoryStringStore, certs_dir: &Path) { } } -async fn start_server_https( - config_store: MemoryStringStore, - certs_dir: &Path, - dns_catalog: DnsCatalog, -) { - let _ = rustls::crypto::ring::default_provider().install_default(); - - let sni = match certificates::WildcardSniResolver::load_dir(certs_dir) { - Ok(sni) => sni, - Err(error) => { - eprintln!( - "Failed to load certificates from {:?} into SNI: {}", - certs_dir, error - ); - return; - } - }; - - let mut server_config = ServerConfig::builder() - .with_no_client_auth() - .with_cert_resolver(Arc::new(sni)); - server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - - let app = linkup_router(config_store, dns_catalog); - - let addr = SocketAddr::from(([0, 0, 0, 0], 443)); - println!("HTTPS listening on {}", &addr); - - axum_server::bind_rustls(addr, RustlsConfig::from_config(Arc::new(server_config))) - .serve(app.into_make_service()) - .await - .expect("failed to start 
HTTPS server"); -} - -async fn start_server_http(config_store: MemoryStringStore, dns_catalog: DnsCatalog) { - let app = linkup_router(config_store, dns_catalog); - - let addr = SocketAddr::from(([0, 0, 0, 0], 80)); - println!("HTTP listening on {}", &addr); - - let listener = tokio::net::TcpListener::bind(addr) - .await - .expect("failed to bind to address"); - - axum::serve(listener, app) - .await - .expect("failed to start HTTP server"); -} - -async fn start_dns_server(dns_catalog: DnsCatalog) { - let cf_name_server = NameServerConfig::new("1.1.1.1:53".parse().unwrap(), Protocol::Udp); - let forward_config = ForwardConfig { - name_servers: NameServerConfigGroup::from(vec![cf_name_server]), - options: Some(ResolverOpts::default()), - }; - - let forwarder = - ForwardAuthority::builder_with_config(forward_config, TokioConnectionProvider::default()) - .with_origin(Name::root()) - .build() - .unwrap(); - - { - let mut catalog = dns_catalog.write().await; - catalog.upsert(Name::root().into(), vec![Arc::new(forwarder)]); - } - - let addr = SocketAddr::from(([0, 0, 0, 0], 8053)); - let sock = UdpSocket::bind(&addr).await.unwrap(); - - let mut server = ServerFuture::new(dns_catalog); - server.register_socket(sock); - - println!("listening on {addr}"); - server.block_until_done().await.unwrap(); -} - -async fn linkup_request_handler( - Extension(store): Extension, - Extension(client): Extension, - ws: ws::ExtractOptionalWebSocketUpgrade, - req: Request, -) -> Response { - let sessions = SessionAllocator::new(&store); - - let headers: linkup::HeaderMap = req.headers().into(); - let url = if req.uri().scheme().is_some() { - req.uri().to_string() - } else { - format!( - "http://{}{}", - req.headers() - .get(http::header::HOST) - .and_then(|h| h.to_str().ok()) - .unwrap_or("localhost"), - req.uri() - ) - }; - - let (session_name, config) = match sessions.get_request_session(&url, &headers).await { - Ok(session) => session, - Err(_) => { - return ApiError::new( - "Linkup was 
unable to determine the session origin of the request. Ensure that your request includes a valid session identifier in the referer or tracestate headers. - Local Server".to_string(), - StatusCode::UNPROCESSABLE_ENTITY, - ) - .into_response() - } - }; - - let target_service = match get_target_service(&url, &headers, &config, &session_name) { - Some(result) => result, - None => { - return ApiError::new( - "The request belonged to a session, but there was no target for the request. Check that the routing rules in your linkup config have a match for this request. - Local Server".to_string(), - StatusCode::NOT_FOUND, - ) - .into_response() - } - }; - - let extra_headers = get_additional_headers(&url, &headers, &session_name, &target_service); - - match ws.0 { - Some(downstream_upgrade) => { - let mut url = target_service.url; - if url.starts_with("http://") { - url = url.replace("http://", "ws://"); - } else if url.starts_with("https://") { - url = url.replace("https://", "wss://"); - } - - let uri = url.parse::().unwrap(); - let host = uri.host().unwrap().to_string(); - let mut upstream_request = uri.into_client_request().unwrap(); - - // Copy over all headers from the incoming request - let mut cookie_values: Vec = Vec::new(); - for (key, value) in req.headers() { - if key == http::header::COOKIE { - if let Ok(cookie_value) = value.to_str().map(str::trim) { - if !cookie_value.is_empty() { - cookie_values.push(cookie_value.to_string()); - } - } - continue; - } - - upstream_request.headers_mut().insert(key, value.clone()); - } - - if !cookie_values.is_empty() { - let combined = cookie_values.join("; "); - if let Ok(cookie_header_value) = HeaderValue::from_str(&combined) { - upstream_request - .headers_mut() - .insert(http::header::COOKIE, cookie_header_value); - } - } - - linkup::normalize_cookie_header(upstream_request.headers_mut()); - - // add the extra headers that linkup wants - let extra_http_headers: HeaderMap = extra_headers.into(); - for (key, value) in 
extra_http_headers.iter() { - upstream_request.headers_mut().insert(key, value.clone()); - } - - // Overriding host header neccesary for tokio_tungstenite - upstream_request - .headers_mut() - .insert(http::header::HOST, HeaderValue::from_str(&host).unwrap()); - - let (upstream_ws_stream, upstream_response) = - match tokio_tungstenite::connect_async(upstream_request).await { - Ok(connection) => connection, - Err(error) => match error { - tokio_tungstenite::tungstenite::Error::Http(response) => { - let (parts, body) = response.into_parts(); - let body = body.unwrap_or_default(); - - return Response::from_parts(parts, Body::from(body)); - } - error => { - return Response::builder() - .status(StatusCode::BAD_GATEWAY) - .body(Body::from(error.to_string())) - .unwrap(); - } - }, - }; - - let mut downstream_upgrade_response = - downstream_upgrade.on_upgrade(ws::context_handle_socket(upstream_ws_stream)); - - let downstream_response_headers = downstream_upgrade_response.headers_mut(); - - // The headers from the upstream response are more important - trust the upstream server - for (upstream_key, upstream_value) in upstream_response.headers() { - // Except for content encoding headers, cloudflare does _not_ like them.. 
- if !DISALLOWED_HEADERS.contains(upstream_key) { - downstream_response_headers - .insert(upstream_key.clone(), upstream_value.clone()); - } - } - - downstream_response_headers.extend(allow_all_cors()); - - downstream_upgrade_response - } - None => handle_http_req(req, target_service, extra_headers, client).await, - } -} - -async fn handle_http_req( - mut req: Request, - target_service: TargetService, - extra_headers: linkup::HeaderMap, - client: HttpsClient, -) -> Response { - *req.uri_mut() = Uri::try_from(&target_service.url).unwrap(); - let extra_http_headers: HeaderMap = extra_headers.into(); - req.headers_mut().extend(extra_http_headers); - // Request uri and host headers should not conflict - req.headers_mut().remove(http::header::HOST); - linkup::normalize_cookie_header(req.headers_mut()); - - if target_service.url.starts_with("http://") { - *req.version_mut() = http::Version::HTTP_11; - } - - // Send the modified request to the target service. - let mut resp = match client.request(req).await { - Ok(resp) => resp, - Err(e) => { - return ApiError::new( - format!( - "Failed to proxy request - are all your servers started? 
{}", - e - ), - StatusCode::BAD_GATEWAY, - ) - .into_response() - } - }; - - resp.headers_mut().extend(allow_all_cors()); - - resp.into_response() -} - -async fn linkup_config_handler( - Extension(store): Extension, - Json(update_req): Json, -) -> impl IntoResponse { - let desired_name = update_req.desired_name.clone(); - let server_conf: Session = match update_req.try_into() { - Ok(conf) => conf, - Err(e) => { - return ApiError::new( - format!("Failed to parse server config: {} - local server", e), - StatusCode::BAD_REQUEST, - ) - .into_response() - } - }; - - let sessions = SessionAllocator::new(&store); - let session_name = sessions - .store_session(server_conf, NameKind::Animal, desired_name) - .await; - - let name = match session_name { - Ok(session_name) => session_name, - Err(e) => { - return ApiError::new( - format!("Failed to store server config: {}", e), - StatusCode::INTERNAL_SERVER_ERROR, - ) - .into_response() - } - }; - - (StatusCode::OK, name).into_response() -} - -async fn always_ok() -> &'static str { - "OK" -} - -#[derive(Deserialize)] -pub struct CreateDnsRecord { - pub domain: String, -} - -async fn dns_create( - Extension(dns_catalog): Extension, - Json(payload): Json, -) -> impl IntoResponse { - let mut catalog = dns_catalog.write().await; - - let record_name = Name::from_str(&format!("{}.", payload.domain)).unwrap(); - - let authority = InMemoryAuthority::empty(record_name.clone(), ZoneType::Primary, false); - - let record = Record::from_rdata( - record_name.clone(), - 3600, - RData::A(Ipv4Addr::new(127, 0, 0, 1).into()), - ); - - authority.upsert(record, 0).await; - - catalog.upsert(record_name.clone().into(), vec![Arc::new(authority)]); - - StatusCode::CREATED.into_response() -} - -fn https_client() -> HttpsClient { - let _ = rustls::crypto::ring::default_provider().install_default(); - - let mut roots = rustls::RootCertStore::empty(); - for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") { - 
roots.add(cert).unwrap(); - } - - let tls = rustls::ClientConfig::builder() - .with_root_certificates(roots) - .with_no_client_auth(); - - let https = hyper_rustls::HttpsConnectorBuilder::new() - .with_tls_config(tls) - .https_or_http() - .enable_http1() - .enable_http2() - .build(); - - Client::builder(TokioExecutor::new()).build(https) -} - async fn shutdown_signal() { let ctrl_c = async { signal::ctrl_c() diff --git a/local-server/src/linkup_server/handlers/dns.rs b/local-server/src/linkup_server/handlers/dns.rs new file mode 100644 index 0000000..6635844 --- /dev/null +++ b/local-server/src/linkup_server/handlers/dns.rs @@ -0,0 +1,41 @@ +use std::{net::Ipv4Addr, str::FromStr, sync::Arc}; + +use axum::{response::IntoResponse, Extension, Json}; +use hickory_server::{ + authority::ZoneType, + proto::rr::{RData, Record}, + resolver::Name, + store::in_memory::InMemoryAuthority, +}; +use http::StatusCode; +use serde::Deserialize; + +use crate::dns_server::DnsCatalog; + +#[derive(Deserialize)] +pub struct CreateDnsRecord { + pub domain: String, +} + +pub async fn handle_create( + Extension(dns_catalog): Extension, + Json(payload): Json, +) -> impl IntoResponse { + let mut catalog = dns_catalog.write().await; + + let record_name = Name::from_str(&format!("{}.", payload.domain)).unwrap(); + + let authority = InMemoryAuthority::empty(record_name.clone(), ZoneType::Primary, false); + + let record = Record::from_rdata( + record_name.clone(), + 3600, + RData::A(Ipv4Addr::new(127, 0, 0, 1).into()), + ); + + authority.upsert(record, 0).await; + + catalog.upsert(record_name.clone().into(), vec![Arc::new(authority)]); + + StatusCode::CREATED.into_response() +} diff --git a/local-server/src/linkup_server/handlers/local_session.rs b/local-server/src/linkup_server/handlers/local_session.rs new file mode 100644 index 0000000..2b41967 --- /dev/null +++ b/local-server/src/linkup_server/handlers/local_session.rs @@ -0,0 +1,40 @@ +use axum::{response::IntoResponse, Extension, Json}; 
+use http::StatusCode; +use linkup::{MemoryStringStore, NameKind, Session, SessionAllocator, UpdateSessionRequest}; + +use crate::linkup_server::ApiError; + +pub async fn handle_upsert( + Extension(store): Extension, + Json(update_req): Json, +) -> impl IntoResponse { + let desired_name = update_req.desired_name.clone(); + let server_conf: Session = match update_req.try_into() { + Ok(conf) => conf, + Err(e) => { + return ApiError::new( + format!("Failed to parse server config: {} - local server", e), + StatusCode::BAD_REQUEST, + ) + .into_response() + } + }; + + let sessions = SessionAllocator::new(&store); + let session_name = sessions + .store_session(server_conf, NameKind::Animal, desired_name) + .await; + + let name = match session_name { + Ok(session_name) => session_name, + Err(e) => { + return ApiError::new( + format!("Failed to store server config: {}", e), + StatusCode::INTERNAL_SERVER_ERROR, + ) + .into_response() + } + }; + + (StatusCode::OK, name).into_response() +} diff --git a/local-server/src/linkup_server/handlers/mod.rs b/local-server/src/linkup_server/handlers/mod.rs new file mode 100644 index 0000000..16a8df0 --- /dev/null +++ b/local-server/src/linkup_server/handlers/mod.rs @@ -0,0 +1,3 @@ +pub mod dns; +pub mod local_session; +pub mod proxy; diff --git a/local-server/src/linkup_server/handlers/proxy.rs b/local-server/src/linkup_server/handlers/proxy.rs new file mode 100644 index 0000000..ba75871 --- /dev/null +++ b/local-server/src/linkup_server/handlers/proxy.rs @@ -0,0 +1,193 @@ +use axum::{ + body::Body, + extract::Request, + response::{IntoResponse, Response}, + Extension, +}; +use http::{header::HeaderMap, HeaderName, HeaderValue, StatusCode, Uri}; +use linkup::{ + allow_all_cors, get_additional_headers, get_target_service, MemoryStringStore, + SessionAllocator, TargetService, +}; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; + +use crate::{linkup_server::ApiError, ws, HttpsClient}; + +const DISALLOWED_HEADERS: 
[HeaderName; 2] = [ + HeaderName::from_static("content-encoding"), + HeaderName::from_static("content-length"), +]; + +pub async fn handle( + Extension(store): Extension, + Extension(client): Extension, + ws: ws::ExtractOptionalWebSocketUpgrade, + req: Request, +) -> Response { + let sessions = SessionAllocator::new(&store); + + let headers: linkup::HeaderMap = req.headers().into(); + let url = if req.uri().scheme().is_some() { + req.uri().to_string() + } else { + format!( + "http://{}{}", + req.headers() + .get(http::header::HOST) + .and_then(|h| h.to_str().ok()) + .unwrap_or("localhost"), + req.uri() + ) + }; + + let (session_name, config) = match sessions.get_request_session(&url, &headers).await { + Ok(session) => session, + Err(_) => { + return ApiError::new( + "Linkup was unable to determine the session origin of the request. Ensure that your request includes a valid session identifier in the referer or tracestate headers. - Local Server".to_string(), + StatusCode::UNPROCESSABLE_ENTITY, + ) + .into_response() + } + }; + + let target_service = match get_target_service(&url, &headers, &config, &session_name) { + Some(result) => result, + None => { + return ApiError::new( + "The request belonged to a session, but there was no target for the request. Check that the routing rules in your linkup config have a match for this request. 
- Local Server".to_string(), + StatusCode::NOT_FOUND, + ) + .into_response() + } + }; + + let extra_headers = get_additional_headers(&url, &headers, &session_name, &target_service); + + match ws.0 { + Some(downstream_upgrade) => { + let mut url = target_service.url; + if url.starts_with("http://") { + url = url.replace("http://", "ws://"); + } else if url.starts_with("https://") { + url = url.replace("https://", "wss://"); + } + + let uri = url.parse::<Uri>().unwrap(); + let host = uri.host().unwrap().to_string(); + let mut upstream_request = uri.into_client_request().unwrap(); + + // Copy over all headers from the incoming request + let mut cookie_values: Vec<String> = Vec::new(); + for (key, value) in req.headers() { + if key == http::header::COOKIE { + if let Ok(cookie_value) = value.to_str().map(str::trim) { + if !cookie_value.is_empty() { + cookie_values.push(cookie_value.to_string()); + } + } + continue; + } + + upstream_request.headers_mut().insert(key, value.clone()); + } + + if !cookie_values.is_empty() { + let combined = cookie_values.join("; "); + if let Ok(cookie_header_value) = HeaderValue::from_str(&combined) { + upstream_request + .headers_mut() + .insert(http::header::COOKIE, cookie_header_value); + } + } + + linkup::normalize_cookie_header(upstream_request.headers_mut()); + + // add the extra headers that linkup wants + let extra_http_headers: HeaderMap = extra_headers.into(); + for (key, value) in extra_http_headers.iter() { + upstream_request.headers_mut().insert(key, value.clone()); + } + + // Overriding host header necessary for tokio_tungstenite + upstream_request + .headers_mut() + .insert(http::header::HOST, HeaderValue::from_str(&host).unwrap()); + + let (upstream_ws_stream, upstream_response) = + match tokio_tungstenite::connect_async(upstream_request).await { + Ok(connection) => connection, + Err(error) => match error { + tokio_tungstenite::tungstenite::Error::Http(response) => { + let (parts, body) = response.into_parts(); + let body = 
body.unwrap_or_default(); + + return Response::from_parts(parts, Body::from(body)); + } + error => { + return Response::builder() + .status(StatusCode::BAD_GATEWAY) + .body(Body::from(error.to_string())) + .unwrap(); + } + }, + }; + + let mut downstream_upgrade_response = + downstream_upgrade.on_upgrade(ws::context_handle_socket(upstream_ws_stream)); + + let downstream_response_headers = downstream_upgrade_response.headers_mut(); + + // The headers from the upstream response are more important - trust the upstream server + for (upstream_key, upstream_value) in upstream_response.headers() { + // Except for content encoding headers, cloudflare does _not_ like them.. + if !DISALLOWED_HEADERS.contains(upstream_key) { + downstream_response_headers + .insert(upstream_key.clone(), upstream_value.clone()); + } + } + + downstream_response_headers.extend(allow_all_cors()); + + downstream_upgrade_response + } + None => handle_http_req(req, target_service, extra_headers, client).await, + } +} + +async fn handle_http_req( + mut req: Request, + target_service: TargetService, + extra_headers: linkup::HeaderMap, + client: HttpsClient, +) -> Response { + *req.uri_mut() = Uri::try_from(&target_service.url).unwrap(); + let extra_http_headers: HeaderMap = extra_headers.into(); + req.headers_mut().extend(extra_http_headers); + // Request uri and host headers should not conflict + req.headers_mut().remove(http::header::HOST); + linkup::normalize_cookie_header(req.headers_mut()); + + if target_service.url.starts_with("http://") { + *req.version_mut() = http::Version::HTTP_11; + } + + // Send the modified request to the target service. + let mut resp = match client.request(req).await { + Ok(resp) => resp, + Err(e) => { + return ApiError::new( + format!( + "Failed to proxy request - are all your servers started? 
{}", + e + ), + StatusCode::BAD_GATEWAY, + ) + .into_response() + } + }; + + resp.headers_mut().extend(allow_all_cors()); + + resp.into_response() +} diff --git a/local-server/src/linkup_server/mod.rs b/local-server/src/linkup_server/mod.rs new file mode 100644 index 0000000..6b77c3f --- /dev/null +++ b/local-server/src/linkup_server/mod.rs @@ -0,0 +1,141 @@ +use std::{net::SocketAddr, path::Path, sync::Arc}; + +use axum::{ + body::Body, + extract::DefaultBodyLimit, + response::{IntoResponse, Response}, + routing::{any, get, post}, + Extension, Router, +}; +use axum_server::tls_rustls::RustlsConfig; +use http::StatusCode; +use hyper_util::{client::legacy::Client, rt::TokioExecutor}; +use linkup::MemoryStringStore; +use rustls::ServerConfig; +use tower::ServiceBuilder; +use tower_http::trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer}; + +use crate::{certificates, dns_server::DnsCatalog, HttpsClient}; + +mod handlers; + +#[derive(Debug)] +struct ApiError { + message: String, + status_code: StatusCode, +} + +impl ApiError { + fn new(message: String, status_code: StatusCode) -> Self { + ApiError { + message, + status_code, + } + } +} + +impl IntoResponse for ApiError { + fn into_response(self) -> Response { + Response::builder() + .status(self.status_code) + .header("Content-Type", "text/plain") + .body(Body::from(self.message)) + .unwrap() + } +} + +pub async fn serve_http(config_store: MemoryStringStore, dns_catalog: DnsCatalog) { + let app = router(config_store, dns_catalog); + + let addr = SocketAddr::from(([0, 0, 0, 0], 80)); + println!("HTTP listening on {}", &addr); + + let listener = tokio::net::TcpListener::bind(addr) + .await + .expect("failed to bind to address"); + + axum::serve(listener, app) + .await + .expect("failed to start HTTP server"); +} + +pub async fn serve_https( + config_store: MemoryStringStore, + certs_dir: &Path, + dns_catalog: DnsCatalog, +) { + let _ = rustls::crypto::ring::default_provider().install_default(); + + let sni = match 
certificates::WildcardSniResolver::load_dir(certs_dir) { + Ok(sni) => sni, + Err(error) => { + eprintln!( + "Failed to load certificates from {:?} into SNI: {}", + certs_dir, error + ); + return; + } + }; + + let mut server_config = ServerConfig::builder() + .with_no_client_auth() + .with_cert_resolver(Arc::new(sni)); + server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + + let app = router(config_store, dns_catalog); + + let addr = SocketAddr::from(([0, 0, 0, 0], 443)); + println!("HTTPS listening on {}", &addr); + + axum_server::bind_rustls(addr, RustlsConfig::from_config(Arc::new(server_config))) + .serve(app.into_make_service()) + .await + .expect("failed to start HTTPS server"); +} + +pub fn router(config_store: MemoryStringStore, dns_catalog: DnsCatalog) -> Router { + let client = https_client(); + + Router::new() + .route( + "/linkup/local-session", + post(handlers::local_session::handle_upsert), + ) + .route("/linkup/check", get(async || "OK")) + .route("/linkup/dns/records", post(handlers::dns::handle_create)) + .fallback(any(handlers::proxy::handle)) + .layer(Extension(config_store)) + .layer(Extension(dns_catalog)) + .layer(Extension(client)) + .layer( + ServiceBuilder::new() + .layer(DefaultBodyLimit::max(1024 * 1024 * 100)) // Set max body size to 100MB + .layer( + TraceLayer::new_for_http() + .on_request(DefaultOnRequest::new()) // Log all incoming requests at INFO level + .on_response(DefaultOnResponse::new()), // Log all responses at INFO level + ), + ) +} + +fn https_client() -> HttpsClient { + let _ = rustls::crypto::ring::default_provider().install_default(); + + let mut roots = rustls::RootCertStore::empty(); + for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") { + roots.add(cert).unwrap(); + } + + let tls = rustls::ClientConfig::builder() + .with_root_certificates(roots) + .with_no_client_auth(); + + let https = hyper_rustls::HttpsConnectorBuilder::new() + .with_tls_config(tls) + 
.https_or_http() + .enable_http1() + .enable_http2() + .build(); + + Client::builder(TokioExecutor::new()).build(https) +} diff --git a/server-tests/tests/helpers.rs b/server-tests/tests/helpers.rs index b1a39c3..ea46c71 100644 --- a/server-tests/tests/helpers.rs +++ b/server-tests/tests/helpers.rs @@ -1,7 +1,8 @@ use std::process::Command; use linkup::{Domain, MemoryStringStore, SessionService, UpdateSessionRequest}; -use linkup_local_server::{linkup_router, DnsCatalog}; + +use linkup_local_server::{router, DnsCatalog}; use reqwest::Url; use tokio::net::TcpListener; @@ -14,7 +15,7 @@ pub enum ServerKind { pub async fn setup_server(kind: ServerKind) -> String { match kind { ServerKind::Local => { - let app = linkup_router(MemoryStringStore::default(), DnsCatalog::new()); + let app = router(MemoryStringStore::default(), DnsCatalog::new()); // Bind to a random port assigned by the OS let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();