diff --git a/Cargo.lock b/Cargo.lock index 0a9e381b..d0f65b49 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -156,9 +156,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.88" +version = "0.1.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", @@ -1759,6 +1759,7 @@ dependencies = [ name = "linkup-local-server" version = "0.1.0" dependencies = [ + "async-trait", "axum 0.8.1", "axum-server", "futures", @@ -1772,6 +1773,8 @@ dependencies = [ "rustls", "rustls-native-certs", "rustls-pemfile", + "serde", + "serde_json", "thiserror 2.0.11", "tokio", "tokio-tungstenite 0.28.0", @@ -2149,9 +2152,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.68" +version = "0.10.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" +checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" dependencies = [ "bitflags", "cfg-if", @@ -2181,9 +2184,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.104" +version = "0.9.111" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" +checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" dependencies = [ "cc", "libc", diff --git a/linkup-cli/Cargo.toml b/linkup-cli/Cargo.toml index 98242419..963b0452 100644 --- a/linkup-cli/Cargo.toml +++ b/linkup-cli/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "linkup-cli" version = "3.5.0" -edition = "2021" +edition = "2024" build = "build.rs" [[bin]] diff --git a/linkup-cli/src/commands/completion.rs b/linkup-cli/src/commands/completion.rs index 7db1fe39..ea1be4b9 100644 --- a/linkup-cli/src/commands/completion.rs +++ b/linkup-cli/src/commands/completion.rs @@ -1,7 +1,7 @@ use std::io::stdout; use clap::{Command, CommandFactory}; -use clap_complete::{generate, Generator, Shell}; +use clap_complete::{Generator, Shell, generate}; use crate::{Cli, Result}; @@ -19,6 +19,11 @@ pub fn completion(args: &Args) -> Result<()> { Ok(()) } -fn print_completions(gen: &G, cmd: &mut Command) { - generate(gen.clone(), cmd, cmd.get_name().to_string(), &mut stdout()); +fn print_completions(generator: &G, cmd: &mut Command) { + generate( + generator.clone(), + cmd, + cmd.get_name().to_string(), + &mut stdout(), + ); } diff --git a/linkup-cli/src/commands/deploy/api.rs b/linkup-cli/src/commands/deploy/api.rs index a92c1cd5..3c80bd85 100644 --- a/linkup-cli/src/commands/deploy/api.rs +++ b/linkup-cli/src/commands/deploy/api.rs @@ -1,11 +1,11 @@ -use reqwest::{multipart, Client}; +use reqwest::{Client, multipart}; use serde::{Deserialize, Serialize}; use serde_json::json; use super::{ + DeployError, auth::CloudflareApiAuth, resources::{DNSRecord, Rule, WorkerMetadata, WorkerScriptInfo, WorkerScriptPart}, - DeployError, }; pub trait CloudflareApi { @@ -795,20 +795,19 @@ impl CloudflareApi for AccountCloudflareApi { return Err(DeployError::OtherError); } - if let Some(records) = data.result { - if let Some(r) = records + if let Some(records) = data.result + && let Some(r) = records .into_iter() .find(|r| r.comment == Some(comment.clone())) - { - return Ok(Some(DNSRecord { - id: r.id, - name: r.name, - record_type: r.record_type, - content: r.content, - comment: comment.clone(), - proxied: r.proxied.unwrap_or(false), - })); - } + { + return Ok(Some(DNSRecord { + id: r.id, + name: r.name, + record_type: r.record_type, + content: r.content, + comment: comment.clone(), + proxied: r.proxied.unwrap_or(false), + })); } Ok(None) diff --git a/linkup-cli/src/commands/deploy/cf_deploy.rs b/linkup-cli/src/commands/deploy/cf_deploy.rs index 821920ce..a099532e 100644 --- a/linkup-cli/src/commands/deploy/cf_deploy.rs +++ b/linkup-cli/src/commands/deploy/cf_deploy.rs @@ -1,6 +1,6 @@ +use crate::Result; use crate::commands::deploy::auth; use crate::commands::deploy::resources::cf_resources; -use crate::Result; use super::api::{AccountCloudflareApi, CloudflareApi}; use super::console_notify::ConsoleNotifier; @@ -125,7 +125,7 @@ pub async fn deploy_to_cloudflare( #[cfg(test)] mod tests { use cloudflare::framework::{ - async_api::Client, auth, endpoint::spec::EndpointSpec, Environment, HttpApiClientConfig, + Environment, HttpApiClientConfig, async_api::Client, auth, endpoint::spec::EndpointSpec, }; use mockito::ServerGuard; use std::cell::RefCell; @@ -135,8 +135,9 @@ mod tests { api::Token, cf_destroy::destroy_from_cloudflare, resources::{ - rules_equal, DNSRecord, KvNamespace, Rule, TargectCfZoneResources, TargetCacheRules, + DNSRecord, KvNamespace, Rule, TargectCfZoneResources, TargetCacheRules, TargetDNSRecord, TargetWorkerRoute, WorkerMetadata, WorkerScriptInfo, WorkerScriptPart, + rules_equal, }, }; @@ -651,9 +652,11 @@ export default { let dns_records = api.dns_records.borrow(); assert_eq!(dns_records.len(), 1); assert_eq!(dns_records[0].name, "linkup-integration-test"); - assert!(dns_records[0] - .content - .contains("linkup-integration-test-script.workers.dev")); + assert!( + dns_records[0] + .content + .contains("linkup-integration-test-script.workers.dev") + ); // Check route created let routes = api.worker_routes.borrow(); diff --git a/linkup-cli/src/commands/deploy/cf_destroy.rs b/linkup-cli/src/commands/deploy/cf_destroy.rs index b0fb9d1d..46d012b0 100644 --- a/linkup-cli/src/commands/deploy/cf_destroy.rs +++ b/linkup-cli/src/commands/deploy/cf_destroy.rs @@ -1,7 +1,7 @@ +use crate::Result; use crate::commands::deploy::{ api::AccountCloudflareApi, auth, console_notify::ConsoleNotifier, resources::cf_resources, }; -use crate::Result; use super::{api::CloudflareApi, cf_deploy::DeployNotifier, resources::TargetCfResources}; diff --git a/linkup-cli/src/commands/deploy/mod.rs b/linkup-cli/src/commands/deploy/mod.rs index 5546aa0b..cb03ce3e 100644 --- a/linkup-cli/src/commands/deploy/mod.rs +++ b/linkup-cli/src/commands/deploy/mod.rs @@ -5,5 +5,5 @@ mod cf_destroy; mod console_notify; mod resources; -pub use cf_deploy::{deploy, DeployArgs, DeployError}; -pub use cf_destroy::{destroy, DestroyArgs}; +pub use cf_deploy::{DeployArgs, DeployError, deploy}; +pub use cf_destroy::{DestroyArgs, destroy}; diff --git a/linkup-cli/src/commands/deploy/resources.rs b/linkup-cli/src/commands/deploy/resources.rs index 0aa3c87c..00895d39 100644 --- a/linkup-cli/src/commands/deploy/resources.rs +++ b/linkup-cli/src/commands/deploy/resources.rs @@ -5,7 +5,7 @@ use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; -use super::{api::CloudflareApi, cf_deploy::DeployNotifier, DeployError}; +use super::{DeployError, api::CloudflareApi, cf_deploy::DeployNotifier}; const LINKUP_WORKER_SHIM: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/shim.mjs")); const LINKUP_WORKER_INDEX_WASM: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/index.wasm")); @@ -527,7 +527,7 @@ impl TargetCfResources { let bindings = match client.request(&req).await { Ok(response) => response.result, Err(cloudflare::framework::response::ApiFailure::Error(StatusCode::NOT_FOUND, _)) => { - return Ok(None) + return Ok(None); } Err(error) => return Err(DeployError::from(error)), }; @@ -536,10 +536,10 @@ impl TargetCfResources { use cloudflare::endpoints::workers::WorkersBinding; // NOTE(augustoccesar)[2025-02-26]: We are saving WORKER_TOKEN as plain text, so we don't need other binding types - if let WorkersBinding::PlainText { name, text } = binding { - if name == "WORKER_TOKEN" { - return Ok(Some(text)); - } + if let WorkersBinding::PlainText { name, text } = binding + && name == "WORKER_TOKEN" + { + return Ok(Some(text)); } } @@ -642,25 +642,21 @@ impl TargetCfResources { name, namespace_id, } = binding + && *name == kv_namespace.binding { - if *name == kv_namespace.binding { - *namespace_id = kv_ns_id.clone(); - break; - } + *namespace_id = kv_ns_id.clone(); + break; } } } if let Some(token) = token { for binding in final_metadata.bindings.iter_mut() { - if let cloudflare::endpoints::workers::WorkersBinding::SecretText { - name, - text, - } = binding + if let cloudflare::endpoints::workers::WorkersBinding::SecretText { name, text } = + binding + && *name == "CLOUDFLARE_API_TOKEN" { - if *name == "CLOUDFLARE_API_TOKEN" { - *text = Some(token.clone()); - } + *text = Some(token.clone()); } } } diff --git a/linkup-cli/src/commands/health.rs b/linkup-cli/src/commands/health.rs index 6a625d11..ae56827d 100644 --- a/linkup-cli/src/commands/health.rs +++ b/linkup-cli/src/commands/health.rs @@ -9,10 +9,9 @@ use std::{ }; use crate::{ - linkup_dir_path, - local_config::LocalState, - services::{self, find_service_pid, BackgroundService}, - Result, + Result, linkup_dir_path, + services::{self, BackgroundService}, + state::State, }; use super::local_dns; @@ -62,7 +61,7 @@ struct Session { } impl Session { - fn load(state: Option<&LocalState>) -> Self { + fn load(state: Option<&State>) -> Self { match state { Some(state) => Self { name: Some(state.linkup.session_name.clone()), @@ -90,23 +89,21 @@ struct OrphanProcess { pub struct BackgroundServices { pub linkup_server: BackgroundServiceHealth, cloudflared: BackgroundServiceHealth, - dns_server: BackgroundServiceHealth, possible_orphan_processes: Vec, } #[derive(Debug, Serialize)] pub enum BackgroundServiceHealth { - Unknown, NotInstalled, Stopped, Running(u32), } impl BackgroundServices { - pub fn load(state: Option<&LocalState>) -> Self { + pub fn load(_state: Option<&State>) -> Self { let mut managed_pids: Vec = Vec::with_capacity(4); - let linkup_server = match find_service_pid(services::LocalServer::ID) { + let linkup_server = match services::LocalServer::find_pid() { Some(pid) => { managed_pids.push(pid); @@ -116,7 +113,7 @@ impl BackgroundServices { }; let cloudflared = if services::is_cloudflared_installed() { - match find_service_pid(services::CloudflareTunnel::ID) { + match services::CloudflareTunnel::find_pid() { Some(pid) => { managed_pids.push(pid); @@ -128,33 +125,9 @@ impl BackgroundServices { BackgroundServiceHealth::NotInstalled }; - let dns_server = match find_service_pid(services::LocalDnsServer::ID) { - Some(pid) => { - managed_pids.push(pid); - - BackgroundServiceHealth::Running(pid.as_u32()) - } - None => match state { - // If there is no state, we cannot know if local-dns is installed since we depend on - // the domains listed on it. - Some(state) => { - if local_dns::is_installed(&crate::local_config::managed_domains( - Some(state), - &None, - )) { - BackgroundServiceHealth::Stopped - } else { - BackgroundServiceHealth::NotInstalled - } - } - None => BackgroundServiceHealth::Unknown, - }, - }; - Self { linkup_server, cloudflared, - dns_server, possible_orphan_processes: find_potential_orphan_processes(managed_pids), } } @@ -174,10 +147,10 @@ fn find_potential_orphan_processes(managed_pids: Vec) -> Vec) -> Result { + fn load(state: Option<&State>) -> Result { // If there is no state, we cannot know if local-dns is installed since we depend on // the domains listed on it. let is_installed = state.as_ref().map(|state| { - local_dns::is_installed(&crate::local_config::managed_domains(Some(state), &None)) + local_dns::is_installed(&crate::state::managed_domains(Some(state), &None)) }); Ok(Self { @@ -309,7 +282,7 @@ struct Health { impl Health { pub fn load() -> Result { - let state = LocalState::load().ok(); + let state = State::load().ok(); let session = Session::load(state.as_ref()); Ok(Self { @@ -360,15 +333,6 @@ impl Display for Health { BackgroundServiceHealth::NotInstalled => writeln!(f, "{}", "NOT INSTALLED".yellow())?, BackgroundServiceHealth::Stopped => writeln!(f, "{}", "NOT RUNNING".yellow())?, BackgroundServiceHealth::Running(pid) => writeln!(f, "{} ({})", "RUNNING".blue(), pid)?, - BackgroundServiceHealth::Unknown => writeln!(f, "{}", "UNKNOWN".yellow())?, - } - - write!(f, " - DNS Server ")?; - match &self.background_services.dns_server { - BackgroundServiceHealth::NotInstalled => writeln!(f, "{}", "NOT INSTALLED".yellow())?, - BackgroundServiceHealth::Stopped => writeln!(f, "{}", "NOT RUNNING".yellow())?, - BackgroundServiceHealth::Running(pid) => writeln!(f, "{} ({})", "RUNNING".blue(), pid)?, - BackgroundServiceHealth::Unknown => writeln!(f, "{}", "UNKNOWN".yellow())?, } write!(f, " - Cloudflared ")?; @@ -376,7 +340,6 @@ impl Display for Health { BackgroundServiceHealth::NotInstalled => writeln!(f, "{}", "NOT INSTALLED".yellow())?, BackgroundServiceHealth::Stopped => writeln!(f, "{}", "NOT RUNNING".yellow())?, BackgroundServiceHealth::Running(pid) => writeln!(f, "{} ({})", "RUNNING".blue(), pid)?, - BackgroundServiceHealth::Unknown => writeln!(f, "{}", "UNKNOWN".yellow())?, } writeln!(f, "{}", "Linkup:".bold().italic())?; diff --git a/linkup-cli/src/commands/local.rs b/linkup-cli/src/commands/local.rs index f994f1d6..5c6e3b8d 100644 --- a/linkup-cli/src/commands/local.rs +++ b/linkup-cli/src/commands/local.rs @@ -2,9 +2,9 @@ use anyhow::anyhow; use colored::Colorize; use crate::{ - local_config::{upload_state, LocalState, ServiceTarget}, - services::{self, find_service_pid, BackgroundService}, Result, + services::{self, BackgroundService}, + state::{ServiceTarget, State, upload_state}, }; #[derive(clap::Args)] @@ -25,7 +25,7 @@ pub async fn local(args: &Args) -> Result<()> { return Err(anyhow!("No service names provided")); } - if !LocalState::exists() { + if !State::exists() { println!( "{}", "Seems like you don't have any state yet to point to local.".yellow() @@ -35,7 +35,7 @@ pub async fn local(args: &Args) -> Result<()> { return Ok(()); } - if find_service_pid(services::LocalServer::ID).is_none() { + if services::LocalServer::find_pid().is_none() { println!( "{}", "Seems like your local Linkup server is not running. Please run 'linkup start' first." @@ -45,7 +45,7 @@ pub async fn local(args: &Args) -> Result<()> { return Ok(()); } - let mut state = LocalState::load()?; + let mut state = State::load()?; if args.all { for service in state.services.iter_mut() { @@ -56,7 +56,7 @@ pub async fn local(args: &Args) -> Result<()> { let service = state .services .iter_mut() - .find(|s| s.name.as_str() == service_name) + .find(|s| s.config.name.as_str() == service_name) .ok_or_else(|| anyhow!("Service with name '{}' does not exist", service_name))?; service.current = ServiceTarget::Local; diff --git a/linkup-cli/src/commands/local_dns.rs b/linkup-cli/src/commands/local_dns.rs index 0a1560aa..5308b30b 100644 --- a/linkup-cli/src/commands/local_dns.rs +++ b/linkup-cli/src/commands/local_dns.rs @@ -4,11 +4,11 @@ use std::{ }; use crate::{ - commands, is_sudo, linkup_certs_dir_path, - local_config::{self, managed_domains, top_level_domains, LocalState}, - sudo_su, Result, + Result, commands, is_sudo, linkup_certs_dir_path, + state::{self, State, managed_domains, top_level_domains}, + sudo_su, }; -use anyhow::{anyhow, Context}; +use anyhow::{Context, anyhow}; use clap::Subcommand; use linkup_local_server::certificates::{ setup_self_signed_certificates, uninstall_self_signed_certificates, @@ -50,7 +50,7 @@ pub async fn install(config_arg: &Option) -> Result<()> { ensure_resolver_dir()?; - let domains = managed_domains(LocalState::load().ok().as_ref(), config_arg); + let domains = managed_domains(State::load().ok().as_ref(), config_arg); install_resolvers(&top_level_domains(&domains))?; @@ -76,9 +76,10 @@ pub async fn uninstall(config_arg: &Option) -> Result<()> { commands::stop(&commands::StopArgs {}, false)?; - let managed_top_level_domains = local_config::top_level_domains( - &local_config::managed_domains(LocalState::load().ok().as_ref(), config_arg), - ); + let managed_top_level_domains = state::top_level_domains(&state::managed_domains( + State::load().ok().as_ref(), + config_arg, + )); uninstall_resolvers(&managed_top_level_domains)?; uninstall_self_signed_certificates(&linkup_certs_dir_path()) diff --git a/linkup-cli/src/commands/mod.rs b/linkup-cli/src/commands/mod.rs index 808ab18f..a27d1a46 100644 --- a/linkup-cli/src/commands/mod.rs +++ b/linkup-cli/src/commands/mod.rs @@ -13,18 +13,18 @@ pub mod stop; pub mod uninstall; pub mod update; -pub use {completion::completion, completion::Args as CompletionArgs}; -pub use {deploy::deploy, deploy::DeployArgs}; -pub use {deploy::destroy, deploy::DestroyArgs}; -pub use {health::health, health::Args as HealthArgs}; -pub use {local::local, local::Args as LocalArgs}; -pub use {local_dns::local_dns, local_dns::Args as LocalDnsArgs}; -pub use {preview::preview, preview::Args as PreviewArgs}; -pub use {remote::remote, remote::Args as RemoteArgs}; -pub use {reset::reset, reset::Args as ResetArgs}; -pub use {server::server, server::Args as ServerArgs}; -pub use {start::start, start::Args as StartArgs}; -pub use {status::status, status::Args as StatusArgs}; -pub use {stop::stop, stop::Args as StopArgs}; -pub use {uninstall::uninstall, uninstall::Args as UninstallArgs}; -pub use {update::update, update::Args as UpdateArgs}; +pub use {completion::Args as CompletionArgs, completion::completion}; +pub use {deploy::DeployArgs, deploy::deploy}; +pub use {deploy::DestroyArgs, deploy::destroy}; +pub use {health::Args as HealthArgs, health::health}; +pub use {local::Args as LocalArgs, local::local}; +pub use {local_dns::Args as LocalDnsArgs, local_dns::local_dns}; +pub use {preview::Args as PreviewArgs, preview::preview}; +pub use {remote::Args as RemoteArgs, remote::remote}; +pub use {reset::Args as ResetArgs, reset::reset}; +pub use {server::Args as ServerArgs, server::server}; +pub use {start::Args as StartArgs, start::start}; +pub use {status::Args as StatusArgs, status::status}; +pub use {stop::Args as StopArgs, stop::stop}; +pub use {uninstall::Args as UninstallArgs, uninstall::uninstall}; +pub use {update::Args as UpdateArgs, update::update}; diff --git a/linkup-cli/src/commands/preview.rs b/linkup-cli/src/commands/preview.rs index b26a9d83..4436b7a7 100644 --- a/linkup-cli/src/commands/preview.rs +++ b/linkup-cli/src/commands/preview.rs @@ -1,10 +1,11 @@ -use crate::commands::status::{format_state_domains, SessionStatus}; -use crate::local_config::{config_path, get_config}; -use crate::worker_client::WorkerClient; use crate::Result; +use crate::commands::status::{SessionStatus, format_state_domains}; +use crate::state::{config_path, get_config}; +use crate::worker_client::WorkerClient; use anyhow::Context; use clap::builder::ValueParser; use linkup::CreatePreviewRequest; +use url::Url; #[derive(clap::Args)] pub struct Args { @@ -14,7 +15,7 @@ pub struct Args { required = true, num_args = 1.., )] - services: Vec<(String, String)>, + services: Vec<(String, Url)>, #[arg(long, help = "Print the request body instead of sending it.")] print_request: bool, @@ -24,7 +25,7 @@ pub async fn preview(args: &Args, config: &Option) -> Result<()> { let config_path = config_path(config)?; let input_config = get_config(&config_path)?; let create_preview_request: CreatePreviewRequest = - input_config.create_preview_request(&args.services); + linkup::create_preview_req_from_config(&input_config, &args.services); let url = input_config.linkup.worker_url.clone(); if args.print_request { diff --git a/linkup-cli/src/commands/remote.rs b/linkup-cli/src/commands/remote.rs index 2fd638f6..196db287 100644 --- a/linkup-cli/src/commands/remote.rs +++ b/linkup-cli/src/commands/remote.rs @@ -1,7 +1,7 @@ use crate::{ - local_config::{upload_state, LocalState, ServiceTarget}, - services::{self, find_service_pid, BackgroundService}, Result, + services::{self, BackgroundService}, + state::{ServiceTarget, State, upload_state}, }; use anyhow::anyhow; @@ -25,7 +25,7 @@ pub async fn remote(args: &Args) -> Result<()> { return Err(anyhow!("No service names provided")); } - if !LocalState::exists() { + if !State::exists() { println!( "{}", "Seems like you don't have any state yet to point to remote.".yellow() @@ -35,9 +35,9 @@ pub async fn remote(args: &Args) -> Result<()> { return Ok(()); } - let mut state = LocalState::load()?; + let mut state = State::load()?; - if find_service_pid(services::LocalServer::ID).is_none() { + if services::LocalServer::find_pid().is_none() { println!( "{}", "Seems like your local Linkup server is not running. Please run 'linkup start' first." @@ -56,7 +56,7 @@ pub async fn remote(args: &Args) -> Result<()> { let service = state .services .iter_mut() - .find(|s| s.name.as_str() == service_name) + .find(|s| s.config.name.as_str() == service_name) .ok_or_else(|| anyhow!("Service with name '{}' does not exist", service_name))?; service.current = ServiceTarget::Remote; diff --git a/linkup-cli/src/commands/reset.rs b/linkup-cli/src/commands/reset.rs index b700d9cf..c1584321 100644 --- a/linkup-cli/src/commands/reset.rs +++ b/linkup-cli/src/commands/reset.rs @@ -1,10 +1,10 @@ -use crate::{commands, local_config::LocalState, Result}; +use crate::{Result, commands, state::State}; #[derive(clap::Args)] pub struct Args {} pub async fn reset(_args: &Args) -> Result<()> { - let _ = LocalState::load()?; + let _ = State::load()?; commands::stop(&commands::StopArgs {}, false)?; commands::start(&commands::StartArgs { no_tunnel: false }, false, &None).await?; diff --git a/linkup-cli/src/commands/server.rs b/linkup-cli/src/commands/server.rs index fa488143..4b1e260b 100644 --- a/linkup-cli/src/commands/server.rs +++ b/linkup-cli/src/commands/server.rs @@ -1,78 +1,19 @@ +use std::path::PathBuf; + use crate::Result; use linkup::MemoryStringStore; -use tokio::select; #[derive(clap::Args)] pub struct Args { - #[command(subcommand)] - server_kind: ServerKind, -} - -#[derive(clap::Subcommand)] -pub enum ServerKind { - LocalWorker { - #[arg(long)] - certs_dir: String, - }, - - Dns { - #[arg(long)] - session_name: String, - #[arg(long, value_parser, num_args = 1.., value_delimiter = ',')] - domains: Vec, - }, + #[arg(long)] + certs_dir: String, } pub async fn server(args: &Args) -> Result<()> { - match &args.server_kind { - ServerKind::LocalWorker { certs_dir } => { - let config_store = MemoryStringStore::default(); - - let http_config_store = config_store.clone(); - let handler_http = tokio::spawn(async move { - linkup_local_server::start_server_http(http_config_store) - .await - .unwrap(); - }); - - let handler_https = { - use std::path::PathBuf; - - let https_config_store = config_store.clone(); - let https_certs_dir = PathBuf::from(certs_dir); - - Some(tokio::spawn(async move { - linkup_local_server::start_server_https(https_config_store, &https_certs_dir) - .await; - })) - }; - - match handler_https { - Some(handler_https) => { - select! { - _ = handler_http => (), - _ = handler_https => (), - } - } - None => { - handler_http.await.unwrap(); - } - } - } - ServerKind::Dns { - session_name, - domains, - } => { - let session_name = session_name.clone(); - let domains = domains.clone(); - - let handler_dns = tokio::spawn(async move { - linkup_local_server::start_dns_server(session_name, domains).await; - }); + let config_store = MemoryStringStore::default(); + let https_certs_dir = PathBuf::from(&args.certs_dir); - handler_dns.await.unwrap(); - } - } + linkup_local_server::start(config_store, &https_certs_dir).await; Ok(()) } diff --git a/linkup-cli/src/commands/start.rs b/linkup-cli/src/commands/start.rs index 6ba8347e..1059462b 100644 --- a/linkup-cli/src/commands/start.rs +++ b/linkup-cli/src/commands/start.rs @@ -4,21 +4,21 @@ use std::{ io::stdout, path::{Path, PathBuf}, sync, - thread::{self, sleep, JoinHandle}, + thread::{self, JoinHandle, sleep}, time::Duration, }; -use anyhow::{anyhow, Context, Error}; +use anyhow::{Context, Error, anyhow}; use colored::Colorize; -use crossterm::{cursor, ExecutableCommand}; +use crossterm::{ExecutableCommand, cursor}; +use crate::{Result, state::State}; use crate::{ - commands::status::{format_state_domains, SessionStatus}, + commands::status::{SessionStatus, format_state_domains}, env_files::write_to_env_file, - local_config::{config_path, config_to_state, get_config}, services::{self, BackgroundService}, + state::{config_path, config_to_state, get_config}, }; -use crate::{local_config::LocalState, Result}; const LOADING_CHARS: [char; 10] = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']; @@ -39,14 +39,13 @@ pub async fn start(args: &Args, fresh_state: bool, config_arg: &Option) state } else { - LocalState::load()? + State::load()? }; let status_update_channel = sync::mpsc::channel::(); let local_server = services::LocalServer::new(); let cloudflare_tunnel = services::CloudflareTunnel::new(); - let local_dns_server = services::LocalDnsServer::new(); let mut display_thread: Option> = None; let display_channel = sync::mpsc::channel::(); @@ -59,7 +58,6 @@ pub async fn start(args: &Args, fresh_state: bool, config_arg: &Option) &[ services::LocalServer::NAME, services::CloudflareTunnel::NAME, - services::LocalDnsServer::NAME, ], status_update_channel.1, display_channel.1, @@ -89,16 +87,6 @@ pub async fn start(args: &Args, fresh_state: bool, config_arg: &Option) } } - if exit_error.is_none() { - match local_dns_server - .run_with_progress(&mut state, status_update_channel.0.clone()) - .await - { - Ok(_) => (), - Err(err) => exit_error = Some(err), - } - } - if let Some(display_thread) = display_thread { display_channel.0.send(true).unwrap(); display_thread.join().unwrap(); @@ -217,18 +205,18 @@ fn spawn_display_thread( }) } -fn set_linkup_env(state: LocalState) -> Result<()> { +fn set_linkup_env(state: State) -> Result<()> { // Set env vars to linkup for service in &state.services { - if let Some(d) = &service.directory { + if let Some(d) = &service.config.directory { set_service_env(d.clone(), state.linkup.config_path.clone())? } } Ok(()) } -fn load_and_save_state(config_arg: &Option, no_tunnel: bool) -> Result { - let previous_state = LocalState::load(); +fn load_and_save_state(config_arg: &Option, no_tunnel: bool) -> Result { + let previous_state = State::load(); let config_path = config_path(config_arg)?; let input_config = get_config(&config_path)?; diff --git a/linkup-cli/src/commands/status.rs b/linkup-cli/src/commands/status.rs index 7b22a5ac..decab535 100644 --- a/linkup-cli/src/commands/status.rs +++ b/linkup-cli/src/commands/status.rs @@ -1,7 +1,7 @@ use anyhow::Context; use colored::{ColoredString, Colorize}; use crossterm::{cursor, execute, style::Print, terminal}; -use linkup::{get_additional_headers, HeaderMap, StorableDomain, TargetService}; +use linkup::{Domain, HeaderMap, TargetService, config::HealthConfig, get_additional_headers}; use serde::{Deserialize, Serialize}; use std::{ io::stdout, @@ -12,9 +12,8 @@ use std::{ }; use crate::{ - commands, - local_config::{HealthConfig, LocalService, LocalState, ServiceTarget}, - services, + commands, services, + state::{LocalService, ServiceTarget, State}, }; const LOADING_CHARS: [char; 10] = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']; @@ -26,22 +25,10 @@ pub struct Args { // Output status in JSON format #[arg(long)] pub json: bool, - - #[arg(short, long)] - all: bool, } pub fn status(args: &Args) -> anyhow::Result<()> { - // TODO(augustocesar)[2024-10-28]: Remove --all/-a in a future release. - // Do not print the warning in case of JSON so it doesn't break any usage if the result of the command - // is passed on to somewhere else. - if args.all && !args.json { - let warning = "--all/-a is a noop now. All services statuses will always be shown. \ - This arg will be removed in a future release.\n"; - println!("{}", warning.yellow()); - } - - if !LocalState::exists() { + if !State::exists() { println!( "{}", "Seems like you don't have any state yet, so there is no status to report.".yellow() @@ -51,7 +38,7 @@ pub fn status(args: &Args) -> anyhow::Result<()> { return Ok(()); } - let state = LocalState::load().context("Failed to load local state")?; + let state = State::load().context("Failed to load local state")?; let linkup_services = linkup_services(&state); let all_services = state.clone().services.into_iter().chain(linkup_services); @@ -269,7 +256,7 @@ fn table_header(terminal_width: u16) -> String { output } -pub fn format_state_domains(session_name: &str, domains: &[StorableDomain]) -> Vec { +pub fn format_state_domains(session_name: &str, domains: &[Domain]) -> Vec { // Filter out domains that are subdomains of other domains let filtered_domains = domains .iter() @@ -287,45 +274,51 @@ pub fn format_state_domains(session_name: &str, domains: &[StorableDomain]) -> V .collect() } -fn linkup_services(state: &LocalState) -> Vec { +fn linkup_services(state: &State) -> Vec { let local_url = services::LocalServer::url(); vec![ LocalService { - name: "linkup_local_server".to_string(), - remote: local_url.clone(), - local: local_url.clone(), current: ServiceTarget::Local, - directory: None, - rewrites: vec![], - health: Some(HealthConfig { - path: Some("/linkup/check".to_string()), - ..Default::default() - }), + config: linkup::config::ServiceConfig { + name: "linkup_local_server".to_string(), + remote: local_url.clone(), + local: local_url.clone(), + directory: None, + rewrites: None, + health: Some(HealthConfig { + path: Some("/linkup/check".to_string()), + ..Default::default() + }), + }, }, LocalService { - name: "linkup_remote_server".to_string(), - remote: state.linkup.worker_url.clone(), - local: state.linkup.worker_url.clone(), current: ServiceTarget::Remote, - directory: None, - rewrites: vec![], - health: Some(HealthConfig { - path: Some("/linkup/check".to_string()), - ..Default::default() - }), + config: linkup::config::ServiceConfig { + name: "linkup_remote_server".to_string(), + remote: state.linkup.worker_url.clone(), + local: state.linkup.worker_url.clone(), + directory: None, + rewrites: None, + health: Some(HealthConfig { + path: Some("/linkup/check".to_string()), + ..Default::default() + }), + }, }, LocalService { - name: "tunnel".to_string(), - remote: state.get_tunnel_url(), - local: state.get_tunnel_url(), current: ServiceTarget::Remote, - directory: None, - rewrites: vec![], - health: Some(HealthConfig { - path: Some("/linkup/check".to_string()), - ..Default::default() - }), + config: linkup::config::ServiceConfig { + name: "tunnel".to_string(), + remote: state.get_tunnel_url(), + local: state.get_tunnel_url(), + directory: None, + rewrites: None, + health: Some(HealthConfig { + path: Some("/linkup/check".to_string()), + ..Default::default() + }), + }, }, ] } @@ -334,7 +327,7 @@ fn service_status(service: &LocalService, session_name: &str) -> ServerStatus { let mut acceptable_statuses_override: Option> = None; let mut url = service.current_url(); - if let Some(health_config) = &service.health { + if let Some(health_config) = &service.config.health { if let Some(path) = &health_config.path { url = url.join(path).unwrap(); } @@ -349,7 +342,7 @@ fn service_status(service: &LocalService, session_name: &str) -> ServerStatus { &HeaderMap::new(), session_name, &TargetService { - name: service.name.clone(), + name: service.config.name.clone(), url: url.to_string(), }, ); @@ -424,7 +417,7 @@ where let priority = service_priority(&service); ServiceStatus { - name: service.name.clone(), + name: service.config.name.clone(), component_kind: service.current.to_string(), status: ServerStatus::Loading, service, @@ -443,7 +436,7 @@ where thread::spawn(move || { let status = service_status(&service_clone, &session_name); - tx.send((service_clone.name.clone(), status)) + tx.send((service_clone.config.name.clone(), status)) .expect("Failed to send service status"); }); } @@ -454,15 +447,13 @@ where } fn is_internal_service(service: &LocalService) -> bool { - service.name == "linkup_local_server" - || service.name == "linkup_remote_server" - || service.name == "tunnel" + let service_name = &service.config.name; + + service_name == "linkup_local_server" + || service_name == "linkup_remote_server" + || service_name == "tunnel" } fn service_priority(service: &LocalService) -> i8 { - if is_internal_service(service) { - 1 - } else { - 2 - } + if is_internal_service(service) { 1 } else { 2 } } diff --git a/linkup-cli/src/commands/stop.rs b/linkup-cli/src/commands/stop.rs index 2bbbb06a..4d2a8e7f 100644 --- a/linkup-cli/src/commands/stop.rs +++ b/linkup-cli/src/commands/stop.rs @@ -4,25 +4,28 @@ use std::path::{Path, PathBuf}; use anyhow::Context; use crate::env_files::clear_env_file; -use crate::local_config::LocalState; -use crate::services::{stop_service, BackgroundService}; -use crate::{services, Result}; +use crate::services::BackgroundService; +use crate::state::State; +use crate::{Result, services}; #[derive(clap::Args)] pub struct Args {} pub fn stop(_args: &Args, clear_env: bool) -> Result<()> { - match (LocalState::load(), clear_env) { + match (State::load(), clear_env) { (Ok(state), true) => { // Reset env vars back to what they were before for service in &state.services { - let remove_res = match &service.directory { + let remove_res = match &service.config.directory { Some(d) => remove_service_env(d.clone(), state.linkup.config_path.clone()), None => Ok(()), }; if let Err(e) = remove_res { - println!("Could not remove env for service {}: {}", service.name, e); + println!( + "Could not remove env for service {}: {}", + service.config.name, e + ); } } } @@ -32,9 +35,8 @@ pub fn stop(_args: &Args, clear_env: bool) -> Result<()> { } } - stop_service(services::LocalServer::ID); - stop_service(services::CloudflareTunnel::ID); - stop_service(services::LocalDnsServer::ID); + services::LocalServer::stop(); + services::CloudflareTunnel::stop(); println!("Stopped linkup"); diff --git a/linkup-cli/src/commands/uninstall.rs b/linkup-cli/src/commands/uninstall.rs index d4f1b6f1..ed54d70d 100644 --- a/linkup-cli/src/commands/uninstall.rs +++ b/linkup-cli/src/commands/uninstall.rs @@ -1,8 +1,8 @@ use std::{fs, process}; use crate::{ - commands, commands::local_dns, linkup_dir_path, linkup_exe_path, local_config::managed_domains, - local_config::LocalState, prompt, InstallationMethod, Result, + InstallationMethod, Result, commands, commands::local_dns, linkup_dir_path, linkup_exe_path, + prompt, state::State, state::managed_domains, }; #[cfg(target_os = "linux")] @@ -24,10 +24,7 @@ pub async fn uninstall(_args: &Args, config_arg: &Option) -> Result<()> commands::stop(&commands::StopArgs {}, true)?; - if local_dns::is_installed(&managed_domains( - LocalState::load().ok().as_ref(), - config_arg, - )) { + if local_dns::is_installed(&managed_domains(State::load().ok().as_ref(), config_arg)) { local_dns::uninstall(config_arg).await?; } diff --git a/linkup-cli/src/commands/update.rs b/linkup-cli/src/commands/update.rs index d228c113..9b7b5a36 100644 --- a/linkup-cli/src/commands/update.rs +++ b/linkup-cli/src/commands/update.rs @@ -2,7 +2,7 @@ use anyhow::Context; #[cfg(not(target_os = "linux"))] use std::fs; -use crate::{commands, current_version, linkup_exe_path, release, InstallationMethod, Result}; +use crate::{InstallationMethod, Result, commands, current_version, linkup_exe_path, release}; #[cfg(target_os = "linux")] use crate::{is_sudo, sudo_su}; diff --git a/linkup-cli/src/env_files.rs b/linkup-cli/src/env_files.rs index 80f756ee..59925a4f 100644 --- a/linkup-cli/src/env_files.rs +++ b/linkup-cli/src/env_files.rs @@ -11,10 +11,10 @@ use crate::Result; const LINKUP_ENV_SEPARATOR: &str = "##### Linkup environment - DO NOT EDIT #####"; pub fn write_to_env_file(service: &str, dev_env_path: &PathBuf, env_path: &PathBuf) -> Result<()> { - if let Ok(env_content) = fs::read_to_string(env_path) { - if env_content.contains(LINKUP_ENV_SEPARATOR) { - return Ok(()); - } + if let Ok(env_content) = fs::read_to_string(env_path) + && env_content.contains(LINKUP_ENV_SEPARATOR) + { + return Ok(()); } let mut dev_env_content = fs::read_to_string(dev_env_path).with_context(|| { diff --git a/linkup-cli/src/main.rs b/linkup-cli/src/main.rs index 785d21d5..01cef11a 100644 --- a/linkup-cli/src/main.rs +++ b/linkup-cli/src/main.rs @@ -1,6 +1,6 @@ use std::{env, fs, io::ErrorKind, path::PathBuf}; -use anyhow::{anyhow, Context}; +use anyhow::{Context, anyhow}; use clap::{Parser, Subcommand}; use colored::Colorize; use thiserror::Error; @@ -10,9 +10,9 @@ pub use linkup::Version; mod commands; mod env_files; -mod local_config; mod release; mod services; +mod state; mod worker_client; const CURRENT_VERSION: &str = env!("CARGO_PKG_VERSION"); diff --git a/linkup-cli/src/release.rs b/linkup-cli/src/release.rs index b150837c..e2e23c03 100644 --- a/linkup-cli/src/release.rs +++ b/linkup-cli/src/release.rs @@ -4,7 +4,7 @@ mod github { use flate2::read::GzDecoder; use linkup::VersionError; use reqwest::header::HeaderValue; - use serde::{de::DeserializeOwned, Deserialize, Serialize}; + use serde::{Deserialize, Serialize, de::DeserializeOwned}; use tar::Archive; use url::Url; diff --git a/linkup-cli/src/services/cloudflare_tunnel.rs b/linkup-cli/src/services/cloudflare_tunnel.rs index 3d9ee091..707614c1 100644 --- a/linkup-cli/src/services/cloudflare_tunnel.rs +++ b/linkup-cli/src/services/cloudflare_tunnel.rs @@ -7,15 +7,15 @@ use std::{ time::Duration, }; -use hickory_resolver::{config::ResolverOpts, proto::rr::RecordType, TokioResolver}; +use hickory_resolver::{TokioResolver, config::ResolverOpts, proto::rr::RecordType}; use log::debug; use serde::{Deserialize, Serialize}; use tokio::time::sleep; use url::Url; -use crate::{linkup_file_path, local_config::LocalState, worker_client::WorkerClient, Result}; +use crate::{Result, linkup_file_path, state::State, worker_client::WorkerClient}; -use super::{find_service_pid, BackgroundService, PidError}; +use super::{BackgroundService, PidError}; #[derive(thiserror::Error, Debug)] #[allow(dead_code)] @@ -129,7 +129,7 @@ impl CloudflareTunnel { false } - fn update_state(&self, tunnel_url: &Url, state: &mut LocalState) -> Result<()> { + fn update_state(&self, tunnel_url: &Url, state: &mut State) -> Result<()> { debug!("Adding tunnel url {} to the state", tunnel_url.as_str()); state.linkup.tunnel = Some(tunnel_url.clone()); @@ -147,7 +147,7 @@ impl BackgroundService for CloudflareTunnel { async fn run_with_progress( &self, - state: &mut LocalState, + state: &mut State, status_sender: std::sync::mpsc::Sender, ) -> Result<()> { if !state.should_use_tunnel() { @@ -170,7 +170,7 @@ impl BackgroundService for CloudflareTunnel { return Err(Error::InvalidSessionName(state.linkup.session_name.clone()).into()); } - if find_service_pid(Self::ID).is_some() { + if Self::find_pid().is_some() { self.notify_update_with_details( &status_sender, super::RunStatus::Started, diff --git a/linkup-cli/src/services/local_dns_server.rs b/linkup-cli/src/services/local_dns_server.rs deleted file mode 100644 index 80638c08..00000000 --- a/linkup-cli/src/services/local_dns_server.rs +++ /dev/null @@ -1,97 +0,0 @@ -use std::{ - env, - fs::File, - os::unix::process::CommandExt, - path::PathBuf, - process::{self, Stdio}, -}; - -use anyhow::Context; - -use crate::{commands::local_dns, linkup_file_path, local_config::LocalState, Result}; - -use super::BackgroundService; - -pub struct LocalDnsServer { - stdout_file_path: PathBuf, - stderr_file_path: PathBuf, -} - -impl LocalDnsServer { - pub fn new() -> Self { - Self { - stdout_file_path: linkup_file_path("localdns-stdout"), - stderr_file_path: linkup_file_path("localdns-stderr"), - } - } - - fn start(&self, session_name: &str, domains: &[String]) -> Result<()> { - log::debug!("Starting {}", Self::NAME); - - let stdout_file = File::create(&self.stdout_file_path)?; - let stderr_file = File::create(&self.stderr_file_path)?; - - let mut command = process::Command::new( - env::current_exe().context("Failed to get the current executable")?, - ); - command.env("RUST_LOG", "debug"); - command.env("LINKUP_SERVICE_ID", Self::ID); - command.args([ - "server", - "dns", - "--session-name", - session_name, - "--domains", - &domains.join(","), - ]); - - command - .process_group(0) - .stdout(stdout_file) - .stderr(stderr_file) - .stdin(Stdio::null()) - .spawn()?; - - Ok(()) - } -} - -impl BackgroundService for LocalDnsServer { - const ID: &str = "linkup-local-dns-server"; - const NAME: &str = "Local DNS server"; - - async fn run_with_progress( - &self, - state: &mut LocalState, - status_sender: std::sync::mpsc::Sender, - ) -> Result<()> { - self.notify_update(&status_sender, super::RunStatus::Starting); - - let session_name = state.linkup.session_name.clone(); - let domains = state.domain_strings(); - - if !local_dns::is_installed(&domains) { - self.notify_update_with_details( - &status_sender, - super::RunStatus::Skipped, - "Not installed", - ); - - return Ok(()); - } - - if let Err(e) = self.start(&session_name, &domains) { - self.notify_update_with_details( - &status_sender, - super::RunStatus::Error, - "Failed to start", - ); - - return Err(e); - } - - self.notify_update(&status_sender, super::RunStatus::Started); - - Ok(()) - } -} diff --git a/linkup-cli/src/services/local_server.rs b/linkup-cli/src/services/local_server.rs index 7bbeb420..b6112040 100644 --- a/linkup-cli/src/services/local_server.rs +++ b/linkup-cli/src/services/local_server.rs @@ -13,9 +13,9 @@ use tokio::time::sleep; use url::Url; use crate::{ - linkup_certs_dir_path, linkup_file_path, - local_config::{upload_state, LocalState}, - worker_client, Result, + Result, linkup_certs_dir_path, linkup_file_path, + state::{State, upload_state}, + worker_client, }; use super::{BackgroundService, PidError}; @@ -66,7 +66,6 @@ impl LocalServer { command.env("LINKUP_SERVICE_ID", Self::ID); command.args([ "server", - "local-worker", "--certs-dir", linkup_certs_dir_path().to_str().unwrap(), ]); @@ -93,7 +92,7 @@ impl LocalServer { matches!(response, Ok(res) if res.status() == StatusCode::OK) } - async fn update_state(&self, state: &mut LocalState) -> Result<()> { + async fn update_state(&self, state: &mut State) -> Result<()> { let session_name = upload_state(state).await?; state.linkup.session_name = session_name; @@ -111,7 +110,7 @@ impl BackgroundService for LocalServer { async fn run_with_progress( &self, - state: &mut LocalState, + state: &mut State, status_sender: std::sync::mpsc::Sender, ) -> Result<()> { self.notify_update(&status_sender, super::RunStatus::Starting); diff --git a/linkup-cli/src/services/mod.rs b/linkup-cli/src/services/mod.rs index 3559ef4d..95f25a17 100644 --- a/linkup-cli/src/services/mod.rs +++ b/linkup-cli/src/services/mod.rs @@ -4,18 +4,16 @@ use sysinfo::{ProcessRefreshKind, RefreshKind, System}; use thiserror::Error; mod cloudflare_tunnel; -mod local_dns_server; mod local_server; -pub use local_dns_server::LocalDnsServer; pub use local_server::LocalServer; pub use sysinfo::{Pid, Signal}; pub use { - cloudflare_tunnel::is_installed as is_cloudflared_installed, cloudflare_tunnel::CloudflareTunnel, + cloudflare_tunnel::is_installed as is_cloudflared_installed, }; -use crate::local_config::LocalState; +use crate::state::State; #[derive(Clone)] pub enum RunStatus { @@ -26,16 +24,12 @@ pub enum RunStatus { Error, } -impl Display for RunStatus { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Pending => write!(f, "pending"), - Self::Starting => write!(f, "starting"), - Self::Started => write!(f, "started"), - Self::Skipped => write!(f, "skipped"), - Self::Error => write!(f, "error"), - } - } +#[derive(Error, Debug)] +pub enum PidError { + #[error("no pid file: {0}")] + NoPidFile(String), + #[error("bad pid file: {0}")] + BadPidFile(String), } #[derive(Clone)] @@ -51,10 +45,18 @@ pub trait BackgroundService { async fn run_with_progress( &self, - local_state: &mut LocalState, + local_state: &mut State, status_sender: sync::mpsc::Sender, ) -> anyhow::Result<()>; + fn stop() { + if let Some(pid) = Self::find_pid() { + system() + .process(pid) + .map(|process| process.kill_with(Signal::Interrupt)); + } + } + fn notify_update(&self, status_sender: &sync::mpsc::Sender, status: RunStatus) { status_sender .send(RunUpdate { @@ -79,35 +81,31 @@ pub trait BackgroundService { }) .unwrap(); } -} - -#[derive(Error, Debug)] -pub enum PidError { - #[error("no pid file: {0}")] - NoPidFile(String), - #[error("bad pid file: {0}")] - BadPidFile(String), -} -pub fn find_service_pid(service_id: &str) -> Option { - for (pid, process) in system().processes() { - if process - .environ() - .iter() - .any(|item| item.to_string_lossy() == format!("LINKUP_SERVICE_ID={service_id}")) - { - return Some(*pid); + fn find_pid() -> Option { + for (pid, process) in system().processes() { + if process + .environ() + .iter() + .any(|item| item.to_string_lossy() == format!("LINKUP_SERVICE_ID={}", Self::ID)) + { + return Some(*pid); + } } - } - None + None + } } -pub fn stop_service(service_id: &str) { - if let Some(pid) = find_service_pid(service_id) { - system() - .process(pid) - .map(|process| process.kill_with(Signal::Interrupt)); +impl Display for RunStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Pending => write!(f, "pending"), + Self::Starting => write!(f, "starting"), + Self::Started => write!(f, "started"), + Self::Skipped => write!(f, "skipped"), + Self::Error => write!(f, "error"), + } } } diff --git a/linkup-cli/src/local_config.rs b/linkup-cli/src/state.rs similarity index 61% rename from linkup-cli/src/local_config.rs rename to linkup-cli/src/state.rs index 4d37da45..e53eb0e8 100644 --- a/linkup-cli/src/local_config.rs +++ b/linkup-cli/src/state.rs @@ -6,28 +6,25 @@ use std::{ use anyhow::Context; use rand::distr::{Alphanumeric, SampleString}; +use regex::Regex; use serde::{Deserialize, Serialize}; use url::Url; -use linkup::{ - CreatePreviewRequest, StorableDomain, StorableRewrite, StorableService, StorableSession, - UpdateSessionRequest, -}; +use linkup::{Domain, Session, SessionService, UpdateSessionRequest}; use crate::{ - linkup_file_path, services, + LINKUP_CONFIG_ENV, LINKUP_STATE_FILE, Result, linkup_file_path, services, worker_client::{self, WorkerClient}, - Result, LINKUP_CONFIG_ENV, LINKUP_STATE_FILE, }; -#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] -pub struct LocalState { +#[derive(Deserialize, Serialize, Clone, Debug)] +pub struct State { pub linkup: LinkupState, - pub domains: Vec, + pub domains: Vec, pub services: Vec, } -impl LocalState { +impl State { pub fn load() -> anyhow::Result { let state_file_path = linkup_file_path(LINKUP_STATE_FILE); let content = fs::read_to_string(&state_file_path) @@ -70,7 +67,7 @@ impl LocalState { pub fn domain_strings(&self) -> Vec { self.domains .iter() - .map(|storable_domain| storable_domain.domain.clone()) + .map(|domain| domain.domain.clone()) .collect::>() } @@ -79,7 +76,7 @@ impl LocalState { } } -#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] +#[derive(Deserialize, Serialize, Clone, Debug)] pub struct LinkupState { pub session_name: String, pub session_token: String, @@ -87,31 +84,27 @@ pub struct LinkupState { pub worker_token: String, pub config_path: String, pub tunnel: Option, - pub cache_routes: Option>, -} - -#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Default)] -pub struct HealthConfig { - pub path: Option, - pub statuses: Option>, + #[serde( + default, + serialize_with = "linkup::serde_ext::serialize_opt_vec_regex", + deserialize_with = "linkup::serde_ext::deserialize_opt_vec_regex" + )] + pub cache_routes: Option>, } -#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] +#[derive(Deserialize, Serialize, Clone, Debug)] pub struct LocalService { - pub name: String, - pub remote: Url, - pub local: Url, pub current: ServiceTarget, - pub directory: Option, - pub rewrites: Vec, - pub health: Option, + + #[serde(flatten)] + pub config: linkup::config::ServiceConfig, } impl LocalService { pub fn current_url(&self) -> Url { match self.current { - ServiceTarget::Local => self.local.clone(), - ServiceTarget::Remote => self.remote.clone(), + ServiceTarget::Local => self.config.local.clone(), + ServiceTarget::Remote => self.config.remote.clone(), } } } @@ -131,72 +124,17 @@ impl Display for ServiceTarget { } } -#[derive(Deserialize, Clone)] -pub struct YamlLocalConfig { - pub linkup: LinkupConfig, - pub services: Vec, - pub domains: Vec, -} - -impl YamlLocalConfig { - pub fn create_preview_request(&self, services: &[(String, String)]) -> CreatePreviewRequest { - let services = self - .services - .iter() - .map(|yaml_local_service: &YamlLocalService| { - let name = yaml_local_service.name.clone(); - let mut location = yaml_local_service.remote.clone(); - - for (param_service_name, param_service_url) in services { - if param_service_name == &name { - location = Url::parse(param_service_url).unwrap(); - } - } - - StorableService { - name, - location, - rewrites: yaml_local_service.rewrites.clone(), - } - }) - .collect(); - - CreatePreviewRequest { - services, - domains: self.domains.clone(), - cache_routes: self.linkup.cache_routes.clone(), - } - } -} - -#[derive(Deserialize, Clone)] -pub struct LinkupConfig { - pub worker_url: Url, - pub worker_token: String, - cache_routes: Option>, -} - -#[derive(Deserialize, Clone)] -pub struct YamlLocalService { - name: String, - remote: Url, - local: Url, - directory: Option, - rewrites: Option>, - health: Option, -} - #[derive(Debug)] -pub struct ServerConfig { - pub local: StorableSession, - pub remote: StorableSession, +pub struct ServersSessions { + pub local: Session, + pub remote: Session, } pub fn config_to_state( - yaml_config: YamlLocalConfig, + config: linkup::config::Config, config_path: String, no_tunnel: bool, -) -> LocalState { +) -> State { let random_token = Alphanumeric.sample_string(&mut rand::rng(), 16); let tunnel = match no_tunnel { @@ -207,30 +145,25 @@ pub fn config_to_state( let linkup = LinkupState { session_name: String::new(), session_token: random_token, - worker_token: yaml_config.linkup.worker_token, + worker_token: config.linkup.worker_token, config_path, - worker_url: yaml_config.linkup.worker_url, + worker_url: config.linkup.worker_url, tunnel, - cache_routes: yaml_config.linkup.cache_routes, + cache_routes: config.linkup.cache_routes, }; - let services = yaml_config + let services = config .services .into_iter() - .map(|yaml_service| LocalService { - name: yaml_service.name, - remote: yaml_service.remote, - local: yaml_service.local, + .map(|service_config| LocalService { + config: service_config.clone(), current: ServiceTarget::Remote, - directory: yaml_service.directory, - rewrites: yaml_service.rewrites.unwrap_or_default(), - health: yaml_service.health, }) .collect::>(); - let domains = yaml_config.domains; + let domains = config.domains; - LocalState { + State { linkup, domains, services, @@ -258,7 +191,7 @@ pub fn config_path(config_arg: &Option) -> Result { } } -pub fn get_config(config_path: &str) -> Result { +pub fn get_config(config_path: &str) -> Result { let content = fs::read_to_string(config_path) .with_context(|| format!("Failed to read config file {config_path:?}"))?; @@ -268,25 +201,25 @@ pub fn get_config(config_path: &str) -> Result { // This method gets the local state and uploads it to both the local linkup server and // the remote linkup server (worker). -pub async fn upload_state(state: &LocalState) -> Result { +pub async fn upload_state(state: &State) -> Result { let local_url = services::LocalServer::url(); - let server_config = ServerConfig::from(state); + let servers_sessions = ServersSessions::from(state); let session_name = &state.linkup.session_name; - let server_session_name = upload_config_to_server( + let server_session_name = upload_session_to_server( &state.linkup.worker_url, &state.linkup.worker_token, session_name, - server_config.remote, + servers_sessions.remote, ) .await?; - let local_session_name = upload_config_to_server( + let local_session_name = upload_session_to_server( &local_url, &state.linkup.worker_token, &server_session_name, - server_config.local, + servers_sessions.local, ) .await?; @@ -303,18 +236,18 @@ pub async fn upload_state(state: &LocalState) -> Result { Ok(server_session_name) } -async fn upload_config_to_server( +async fn upload_session_to_server( linkup_url: &Url, worker_token: &str, desired_name: &str, - config: StorableSession, + session: Session, ) -> Result { let session_update_req = UpdateSessionRequest { - session_token: config.session_token, + session_token: session.session_token, desired_name: desired_name.to_string(), - services: config.services, - domains: config.domains, - cache_routes: config.cache_routes, + services: session.services, + domains: session.domains, + cache_routes: session.cache_routes, }; let session_name = WorkerClient::new(linkup_url, worker_token) @@ -324,65 +257,65 @@ async fn upload_config_to_server( Ok(session_name) } -impl From<&LocalState> for ServerConfig { - fn from(state: &LocalState) -> Self { +impl From<&State> for ServersSessions { + fn from(state: &State) -> Self { let local_server_services = state .services .iter() - .map(|service| StorableService { - name: service.name.clone(), + .map(|service| SessionService { + name: service.config.name.clone(), location: if service.current == ServiceTarget::Remote { - service.remote.clone() + service.config.remote.clone() } else { - service.local.clone() + service.config.local.clone() }, - rewrites: Some(service.rewrites.clone()), + rewrites: service.config.rewrites.clone(), }) - .collect::>(); + .collect::>(); let remote_server_services = state .services .iter() - .map(|service| StorableService { - name: service.name.clone(), + .map(|service| SessionService { + name: service.config.name.clone(), location: if service.current == ServiceTarget::Remote { - service.remote.clone() + service.config.remote.clone() } else { state.get_tunnel_url() }, - rewrites: Some(service.rewrites.clone()), + rewrites: service.config.rewrites.clone(), }) - .collect::>(); + .collect::>(); - let local_storable_session = StorableSession { + let local_session = Session { session_token: state.linkup.session_token.clone(), services: local_server_services, domains: state.domains.clone(), cache_routes: state.linkup.cache_routes.clone(), }; - let remote_storable_session = StorableSession { + let remote_session = Session { session_token: state.linkup.session_token.clone(), services: remote_server_services, domains: state.domains.clone(), cache_routes: state.linkup.cache_routes.clone(), }; - ServerConfig { - local: local_storable_session, - remote: remote_storable_session, + ServersSessions { + local: local_session, + remote: remote_session, } } } -pub fn managed_domains(state: Option<&LocalState>, cfg_path: &Option) -> Vec { +pub fn managed_domains(state: Option<&State>, cfg_path: &Option) -> Vec { let config_domains = match config_path(cfg_path).ok() { Some(cfg_path) => match get_config(&cfg_path) { Ok(config) => Some( config .domains .iter() - .map(|storable_domain| storable_domain.domain.clone()) + .map(|domain| domain.domain.clone()) .collect::>(), ), Err(_) => None, @@ -453,8 +386,8 @@ domains: #[test] fn test_config_to_state() { let input_str = String::from(CONF_STR); - let yaml_config = serde_yaml::from_str(&input_str).unwrap(); - let local_state = config_to_state(yaml_config, "./path/to/config.yaml".to_string(), false); + let config = serde_yaml::from_str(&input_str).unwrap(); + let local_state = config_to_state(config, "./path/to/config.yaml".to_string(), false); assert_eq!(local_state.linkup.config_path, "./path/to/config.yaml"); @@ -468,40 +401,45 @@ domains: ); assert_eq!(local_state.services.len(), 2); - assert_eq!(local_state.services[0].name, "frontend"); + assert_eq!(local_state.services[0].config.name, "frontend"); assert_eq!( - local_state.services[0].remote, + local_state.services[0].config.remote, Url::parse("http://remote-service1.example.com").unwrap() ); assert_eq!( - local_state.services[0].local, + local_state.services[0].config.local, Url::parse("http://localhost:8000").unwrap() ); assert_eq!(local_state.services[0].current, ServiceTarget::Remote); - assert_eq!(local_state.services[0].health, None); + assert!(local_state.services[0].config.health.is_none()); - assert_eq!(local_state.services[0].rewrites.len(), 1); - assert_eq!(local_state.services[1].name, "backend"); assert_eq!( - local_state.services[1].remote, + local_state.services[0] + .config + .rewrites + .as_ref() + .unwrap() + .len(), + 1 + ); + assert_eq!(local_state.services[1].config.name, "backend"); + assert_eq!( + local_state.services[1].config.remote, Url::parse("http://remote-service2.example.com").unwrap() ); assert_eq!( - local_state.services[1].local, + local_state.services[1].config.local, Url::parse("http://localhost:8001").unwrap() ); - assert_eq!(local_state.services[1].rewrites.len(), 0); + assert!(local_state.services[1].config.rewrites.is_none()); assert_eq!( - local_state.services[1].directory, + local_state.services[1].config.directory, Some("../backend".to_string()) ); - assert_eq!( - local_state.services[1].health, - Some(HealthConfig { - path: Some("/health".to_string()), - statuses: Some(vec![200, 304]), - }) - ); + assert!(local_state.services[1].config.health.is_some()); + let health = local_state.services[1].config.health.as_ref().unwrap(); + assert_eq!(health.path, Some("/health".to_string())); + assert_eq!(health.statuses, Some(vec![200, 304])); assert_eq!(local_state.domains.len(), 2); assert_eq!(local_state.domains[0].domain, "example.com"); diff --git a/linkup-cli/src/worker_client.rs b/linkup-cli/src/worker_client.rs index f69fcb2e..9f34d216 100644 --- a/linkup-cli/src/worker_client.rs +++ b/linkup-cli/src/worker_client.rs @@ -1,10 +1,8 @@ use linkup::{CreatePreviewRequest, UpdateSessionRequest}; -use reqwest::{header, StatusCode}; +use reqwest::{StatusCode, header}; use serde::{Deserialize, Serialize}; use url::Url; -use crate::local_config::YamlLocalConfig; - #[derive(thiserror::Error, Debug)] pub enum Error { #[error("{0}")] @@ -103,21 +101,20 @@ impl WorkerClient { .send() .await?; - match response.status() { - StatusCode::OK => { - let content = response.text().await?; - Ok(content) - } - _ => Err(Error::Response( + if response.status().is_success() { + let content = response.text().await?; + Ok(content) + } else { + Err(Error::Response( response.status(), response.text().await.unwrap_or_else(|_| "".to_string()), - )), + )) } } } -impl From<&YamlLocalConfig> for WorkerClient { - fn from(config: &YamlLocalConfig) -> Self { +impl From<&linkup::config::Config> for WorkerClient { + fn from(config: &linkup::config::Config) -> Self { Self::new(&config.linkup.worker_url, &config.linkup.worker_token) } } diff --git a/linkup/Cargo.toml b/linkup/Cargo.toml index 93240fb7..8ce03e9f 100644 --- a/linkup/Cargo.toml +++ b/linkup/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "linkup" version = "0.1.0" -edition = "2021" +edition = "2024" [dependencies] hex = "0.4.3" diff --git a/linkup/src/config.rs b/linkup/src/config.rs new file mode 100644 index 00000000..8a06a216 --- /dev/null +++ b/linkup/src/config.rs @@ -0,0 +1,40 @@ +use regex::Regex; +use serde::{Deserialize, Serialize}; +use url::Url; + +use crate::{Domain, Rewrite}; + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Config { + pub linkup: LinkupConfig, + pub services: Vec, + pub domains: Vec, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct LinkupConfig { + pub worker_url: Url, + pub worker_token: String, + #[serde( + default, + deserialize_with = "crate::serde_ext::deserialize_opt_vec_regex", + serialize_with = "crate::serde_ext::serialize_opt_vec_regex" + )] + pub cache_routes: Option>, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ServiceConfig { + pub name: String, + pub remote: Url, + pub local: Url, + pub directory: Option, + pub rewrites: Option>, + pub health: Option, +} + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct HealthConfig { + pub path: Option, + pub statuses: Option>, +} diff --git a/linkup/src/headers.rs b/linkup/src/headers.rs index 5e6efe69..64df5c12 100644 --- a/linkup/src/headers.rs +++ b/linkup/src/headers.rs @@ -164,13 +164,15 @@ impl From for HeaderMap { impl From for HttpHeaderMap { fn from(linkup_headers: HeaderMap) -> Self { let mut http_headers = HttpHeaderMap::new(); + for (key, value) in linkup_headers.into_iter() { - if let Ok(http_value) = HttpHeaderValue::from_str(&value) { - if let Ok(http_key) = http::header::HeaderName::from_bytes(key.as_bytes()) { - http_headers.insert(http_key, http_value); - } + if let Ok(http_value) = HttpHeaderValue::from_str(&value) + && let Ok(http_key) = http::header::HeaderName::from_bytes(key.as_bytes()) + { + http_headers.insert(http_key, http_value); } } + http_headers } } @@ -178,7 +180,7 @@ impl From for HttpHeaderMap { #[cfg(test)] mod tests { use super::normalize_cookie_header; - use http::{header::COOKIE, HeaderMap, HeaderValue}; + use http::{HeaderMap, HeaderValue, header::COOKIE}; #[test] fn normalizes_multiple_cookie_headers_with_semicolon() { diff --git a/linkup/src/lib.rs b/linkup/src/lib.rs index cc3f3a35..608c4afb 100644 --- a/linkup/src/lib.rs +++ b/linkup/src/lib.rs @@ -1,3 +1,6 @@ +pub mod config; +pub mod serde_ext; + mod headers; mod memory_session_store; mod name_gen; @@ -162,32 +165,32 @@ pub fn get_target_service( // If there was a destination created in a previous linkup, we don't want to // re-do path rewrites, so we use the destination service. - if let Some(destination_service) = headers.get(HeaderName::LinkupDestination) { - if let Some(service) = config.services.get(destination_service) { - let target = redirect(target.clone(), &service.origin, Some(path.to_string())); - return Some(TargetService { - name: destination_service.to_string(), - url: target.to_string(), - }); - } + if let Some(destination_service) = headers.get(HeaderName::LinkupDestination) + && let Some(service) = config.get_service(destination_service) + { + let target = redirect(target.clone(), &service.location, Some(path.to_string())); + return Some(TargetService { + name: destination_service.to_string(), + url: target.to_string(), + }); } - let url_target = config.domains.get(&get_target_domain(url, session_name)); + let url_target = config.get_domain(&get_target_domain(url, session_name)); // Forwarded hosts persist over the tunnel - let forwarded_host_target = config.domains.get(&get_target_domain( + let forwarded_host_target = config.get_domain(&get_target_domain( headers.get_or_default(HeaderName::ForwardedHost, "does-not-exist"), session_name, )); // This is more for e2e tests to work - let referer_target = config.domains.get(&get_target_domain( + let referer_target = config.get_domain(&get_target_domain( headers.get_or_default(HeaderName::Referer, "does-not-exist"), session_name, )); // This one is for redirects, where the referer doesn't exist - let origin_target = config.domains.get(&get_target_domain( + let origin_target = config.get_domain(&get_target_domain( headers.get_or_default(HeaderName::Origin, "does-not-exist"), session_name, )); @@ -203,30 +206,34 @@ pub fn get_target_service( }; if let Some(domain) = target_domain { - let service_name = domain - .routes - .iter() - .find_map(|route| { - if route.path.is_match(path) { - Some(route.service.clone()) - } else { - None - } - }) - .unwrap_or_else(|| domain.default_service.clone()); + let service_name = match &domain.routes { + Some(routes) => routes + .iter() + .find_map(|route| { + if route.path.is_match(path) { + Some(route.service.clone()) + } else { + None + } + }) + .unwrap_or_else(|| domain.default_service.clone()), + None => domain.default_service.clone(), + }; - if let Some(service) = config.services.get(&service_name) { + if let Some(service) = config.get_service(&service_name) { let mut new_path = path.to_string(); - for modifier in &service.rewrites { - if modifier.source.is_match(&new_path) { - new_path = modifier - .source - .replace_all(&new_path, &modifier.target) - .to_string(); + if let Some(rewrites) = &service.rewrites { + for modifier in rewrites { + if modifier.source.is_match(&new_path) { + new_path = modifier + .source + .replace_all(&new_path, &modifier.target) + .to_string(); + } } } - let target = redirect(target, &service.origin, Some(new_path)); + let target = redirect(target, &service.location, Some(new_path)); return Some(TargetService { name: service_name, url: target.to_string(), diff --git a/linkup/src/serde_ext.rs b/linkup/src/serde_ext.rs new file mode 100644 index 00000000..b9b46514 --- /dev/null +++ b/linkup/src/serde_ext.rs @@ -0,0 +1,136 @@ +use std::str::FromStr; + +use regex::Regex; +use serde::{Deserialize, Deserializer, Serializer, ser::SerializeSeq}; + +pub fn serialize_regex(regex: &Regex, serializer: S) -> Result +where + S: Serializer, +{ + serializer.serialize_str(regex.as_str()) +} + +pub fn deserialize_regex<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + Regex::from_str(&s).map_err(serde::de::Error::custom) +} + +pub fn serialize_opt_vec_regex( + regexes: &Option>, + serializer: S, +) -> Result +where + S: Serializer, +{ + match regexes { + Some(regexes) => { + let mut seq = serializer.serialize_seq(Some(regexes.len()))?; + + for regex in regexes { + seq.serialize_element(regex.as_str())?; + } + + seq.end() + } + None => serializer.serialize_none(), + } +} + +pub fn deserialize_opt_vec_regex<'de, D>(deserializer: D) -> Result>, D::Error> +where + D: Deserializer<'de>, +{ + let regexes_str: Option> = Option::deserialize(deserializer)?; + let Some(regexes_str) = regexes_str else { + return Ok(None); + }; + + let mut regexes: Vec = Vec::with_capacity(regexes_str.len()); + + for regex_str in regexes_str { + let regex = Regex::from_str(®ex_str).map_err(serde::de::Error::custom)?; + regexes.push(regex); + } + + Ok(Some(regexes)) +} + +#[cfg(test)] +mod tests { + use regex::Regex; + use serde::{Deserialize, Serialize}; + + #[test] + fn test_serialize_deserialize_regex() { + #[derive(Serialize, Deserialize)] + struct A { + #[serde( + deserialize_with = "crate::serde_ext::deserialize_regex", + serialize_with = "crate::serde_ext::serialize_regex" + )] + reg_field: Regex, + } + + let record = A { + reg_field: Regex::new("abc: (.+)").unwrap(), + }; + + let serialized_record = serde_json::to_string(&record).unwrap(); + assert_eq!(r#"{"reg_field":"abc: (.+)"}"#, &serialized_record); + + let des_record: A = serde_json::from_str(&serialized_record).unwrap(); + assert!(des_record.reg_field.is_match("abc: foo")); + + let captures = des_record.reg_field.captures("abc: foo").unwrap(); + assert_eq!("foo", captures.get(1).unwrap().as_str()); + } + + #[test] + fn test_serialize_deserialize_opt_vec_regex() { + #[derive(Serialize, Deserialize)] + struct A { + #[serde( + deserialize_with = "crate::serde_ext::deserialize_opt_vec_regex", + serialize_with = "crate::serde_ext::serialize_opt_vec_regex" + )] + reg_field: Option>, + + #[serde( + deserialize_with = "crate::serde_ext::deserialize_opt_vec_regex", + serialize_with = "crate::serde_ext::serialize_opt_vec_regex" + )] + reg_field2: Option>, + + #[serde( + deserialize_with = "crate::serde_ext::deserialize_opt_vec_regex", + serialize_with = "crate::serde_ext::serialize_opt_vec_regex" + )] + reg_field3: Option>, + } + + let record = A { + reg_field: None, + reg_field2: Some(vec![]), + reg_field3: Some(vec![Regex::new("abc: (.+)").unwrap()]), + }; + + let serialized_record = serde_json::to_string(&record).unwrap(); + assert_eq!( + r#"{"reg_field":null,"reg_field2":[],"reg_field3":["abc: (.+)"]}"#, + &serialized_record + ); + + let des_record: A = serde_json::from_str(&serialized_record).unwrap(); + + assert!(des_record.reg_field.is_none()); + + assert!(des_record.reg_field2.is_some()); + assert!(des_record.reg_field2.unwrap().is_empty()); + + assert!(des_record.reg_field3.is_some()); + assert!(des_record.reg_field3.unwrap()[0].is_match("abc: foo")); + } +} diff --git a/linkup/src/session.rs b/linkup/src/session.rs index 68122278..8b2aaefb 100644 --- a/linkup/src/session.rs +++ b/linkup/src/session.rs @@ -1,98 +1,87 @@ -use std::{ - cmp::Ordering, - collections::{HashMap, HashSet}, -}; +use std::collections::HashSet; use thiserror::Error; use regex::Regex; use serde::{Deserialize, Serialize}; use url::Url; -pub const PREVIEW_SESSION_TOKEN: &str = "preview_session"; +use crate::config::Config; -#[derive(Clone, Debug)] -pub struct Session { - pub session_token: String, - pub services: HashMap, - pub domains: HashMap, - pub domain_selection_order: Vec, - pub cache_routes: Option>, -} - -#[derive(Clone, Debug)] -pub struct Service { - pub origin: Url, - pub rewrites: Vec, -} - -#[derive(Clone, Debug)] -pub struct Rewrite { - pub source: Regex, - pub target: String, -} +pub const PREVIEW_SESSION_TOKEN: &str = "preview_session"; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct Domain { + pub domain: String, pub default_service: String, - pub routes: Vec, + pub routes: Option>, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct Route { + #[serde( + serialize_with = "crate::serde_ext::serialize_regex", + deserialize_with = "crate::serde_ext::deserialize_regex" + )] pub path: Regex, pub service: String, } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct UpdateSessionRequest { pub desired_name: String, pub session_token: String, - pub services: Vec, - pub domains: Vec, - pub cache_routes: Option>, + pub services: Vec, + pub domains: Vec, + #[serde( + default, + serialize_with = "crate::serde_ext::serialize_opt_vec_regex", + deserialize_with = "crate::serde_ext::deserialize_opt_vec_regex" + )] + pub cache_routes: Option>, } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct CreatePreviewRequest { - pub services: Vec, - pub domains: Vec, - pub cache_routes: Option>, + pub services: Vec, + pub domains: Vec, + #[serde( + default, + serialize_with = "crate::serde_ext::serialize_opt_vec_regex", + deserialize_with = "crate::serde_ext::deserialize_opt_vec_regex" + )] + pub cache_routes: Option>, } -#[derive(Debug, Deserialize, Serialize)] -pub struct StorableSession { +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Session { pub session_token: String, - pub services: Vec, - pub domains: Vec, - pub cache_routes: Option>, + pub services: Vec, + pub domains: Vec, + #[serde( + default, + serialize_with = "crate::serde_ext::serialize_opt_vec_regex", + deserialize_with = "crate::serde_ext::deserialize_opt_vec_regex" + )] + pub cache_routes: Option>, } #[derive(Clone, Debug, Deserialize, Serialize)] -pub struct StorableService { +pub struct SessionService { pub name: String, pub location: Url, - pub rewrites: Option>, + pub rewrites: Option>, } -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] -pub struct StorableRewrite { - pub source: String, +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Rewrite { + #[serde( + serialize_with = "crate::serde_ext::serialize_regex", + deserialize_with = "crate::serde_ext::deserialize_regex" + )] + pub source: Regex, pub target: String, } -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] -pub struct StorableDomain { - pub domain: String, - pub default_service: String, - pub routes: Option>, -} - -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] -pub struct StorableRoute { - pub path: String, - pub service: String, -} - #[derive(Error, Debug)] pub enum ConfigError { #[error("linkup session json format error: {0}")] @@ -109,135 +98,53 @@ pub enum ConfigError { Empty, } -impl From for StorableSession { - fn from(req: UpdateSessionRequest) -> Self { - StorableSession { - session_token: req.session_token, - services: req.services, - domains: req.domains, - cache_routes: req.cache_routes, - } +impl Session { + pub fn get_service(&self, service_name: &str) -> Option<&SessionService> { + self.services + .iter() + .find(|service| service.name == service_name) + } + + pub fn get_domain(&self, domain: &str) -> Option<&Domain> { + self.domains + .iter() + .find(|domain_record| domain_record.domain == domain) } } impl TryFrom for Session { type Error = ConfigError; - fn try_from(value: UpdateSessionRequest) -> Result { - let storable: StorableSession = value.into(); - storable.try_into() - } -} - -impl From for StorableSession { - fn from(req: CreatePreviewRequest) -> Self { - StorableSession { - session_token: PREVIEW_SESSION_TOKEN.to_string(), + fn try_from(req: UpdateSessionRequest) -> Result { + let session = Self { + session_token: req.session_token, services: req.services, domains: req.domains, cache_routes: req.cache_routes, - } - } -} - -impl TryFrom for Session { - type Error = ConfigError; - - fn try_from(value: CreatePreviewRequest) -> Result { - let storable: StorableSession = value.into(); - storable.try_into() - } -} - -impl TryFrom for Rewrite { - type Error = ConfigError; - - fn try_from(value: StorableRewrite) -> Result { - let source: Result = Regex::new(&value.source); - match source { - Err(e) => Err(ConfigError::InvalidRegex(value.source, e)), - Ok(s) => Ok(Rewrite { - source: s, - target: value.target, - }), - } - } -} + }; -impl TryFrom for Route { - type Error = ConfigError; + validate_not_empty(&session)?; + validate_services(&session)?; - fn try_from(value: StorableRoute) -> Result { - let path = Regex::new(&value.path); - match path { - Err(e) => Err(ConfigError::InvalidRegex(value.path, e)), - Ok(p) => Ok(Route { - path: p, - service: value.service, - }), - } + Ok(session) } } -impl TryFrom for Session { +impl TryFrom for Session { type Error = ConfigError; - fn try_from(value: StorableSession) -> Result { - validate_not_empty(&value)?; - validate_service_references(&value)?; - - let mut services: HashMap = HashMap::new(); - let mut domains: HashMap = HashMap::new(); - - for stored_service in value.services { - validate_url_origin(&stored_service.location)?; - - let rewrites = match stored_service.rewrites { - Some(pm) => pm.into_iter().map(|r| r.try_into()).collect(), - None => Ok(Vec::new()), - }?; - - let service = Service { - origin: stored_service.location, - rewrites, - }; - - services.insert(stored_service.name, service); - } - - for stored_domain in value.domains { - let routes = match stored_domain.routes { - Some(dr) => dr.into_iter().map(|r| r.try_into()).collect(), - None => Ok(Vec::new()), - }?; - - let domain = Domain { - default_service: stored_domain.default_service, - routes, - }; - - domains.insert(stored_domain.domain, domain); - } - - let domain_names = domains.keys().cloned().collect(); - - let cache_routes = match value.cache_routes { - Some(cr) => Some( - cr.into_iter() - .map(|r| Regex::new(&r)) - .collect::, regex::Error>>() - .map_err(|e| ConfigError::InvalidRegex("cache route".to_string(), e))?, - ), - None => None, + fn try_from(req: CreatePreviewRequest) -> Result { + let session = Self { + session_token: PREVIEW_SESSION_TOKEN.to_string(), + services: req.services, + domains: req.domains, + cache_routes: req.cache_routes, }; - Ok(Session { - session_token: value.session_token, - services, - domains, - domain_selection_order: choose_domain_ordering(domain_names), - cache_routes, - }) + validate_not_empty(&session)?; + validate_services(&session)?; + + Ok(session) } } @@ -245,153 +152,66 @@ impl TryFrom for Session { type Error = ConfigError; fn try_from(value: serde_json::Value) -> Result { - let session_yml_res: Result = - serde_json::from_value(value); + let session = serde_json::from_value(value)?; - match session_yml_res { - Err(e) => Err(ConfigError::JsonFormat(e)), - Ok(c) => c.try_into(), - } + validate_not_empty(&session)?; + validate_services(&session)?; + + Ok(session) } } -impl From for StorableSession { - fn from(value: Session) -> Self { - let services: Vec = value - .services - .into_iter() - .map(|(name, service)| { - let rewrites = if service.rewrites.is_empty() { - None - } else { - Some( - service - .rewrites - .into_iter() - .map(|path_modifier| StorableRewrite { - source: path_modifier.source.to_string(), - target: path_modifier.target, - }) - .collect(), - ) - }; - - StorableService { - name, - location: service.origin, - rewrites, - } - }) - .collect(); - - let domains: Vec = value - .domains - .into_iter() - .map(|(domain, domain_data)| { - let default_service = domain_data.default_service; - let routes = if domain_data.routes.is_empty() { - None - } else { - Some( - domain_data - .routes - .into_iter() - .map(|route| StorableRoute { - path: route.path.to_string(), - service: route.service, - }) - .collect(), - ) - }; - - StorableDomain { - domain, - default_service, - routes, - } - }) - .collect(); - - let cache_routes = value.cache_routes.map(|cr| { - cr.into_iter() - .map(|r| r.to_string()) - .collect::>() - }); +pub fn create_preview_req_from_config( + config: &Config, + services_overwrite: &[(String, Url)], +) -> CreatePreviewRequest { + let mut session_services: Vec = Vec::with_capacity(config.services.len()); - StorableSession { - session_token: value.session_token, - services, - domains, - cache_routes, - } - } -} + for service in &config.services { + let service_overwrite = services_overwrite + .iter() + .find(|overwrite| overwrite.0 == service.name); -pub fn update_session_req_from_json(input_json: String) -> Result<(String, Session), ConfigError> { - let update_session_req_res: Result = - serde_json::from_str(&input_json); - - match update_session_req_res { - Err(e) => Err(ConfigError::JsonFormat(e)), - Ok(c) => { - let server_conf = StorableSession { - session_token: c.session_token, - services: c.services, - domains: c.domains, - cache_routes: c.cache_routes, - } - .try_into(); + let location = match service_overwrite { + Some((_, location_overwrite)) => location_overwrite.clone(), + None => service.remote.clone(), + }; - match server_conf { - Err(e) => Err(e), - Ok(sc) => Ok((c.desired_name, sc)), - } - } + session_services.push(SessionService { + name: service.name.clone(), + location, + rewrites: service.rewrites.clone(), + }); } -} - -pub fn create_preview_req_from_json(input_json: String) -> Result { - let update_session_req_res: Result = - serde_json::from_str(&input_json); - - match update_session_req_res { - Err(e) => Err(ConfigError::JsonFormat(e)), - Ok(c) => { - let server_conf = StorableSession { - session_token: String::from(PREVIEW_SESSION_TOKEN), - services: c.services, - domains: c.domains, - cache_routes: None, - } - .try_into(); - match server_conf { - Err(e) => Err(e), - Ok(sc) => Ok(sc), - } - } + CreatePreviewRequest { + services: session_services, + domains: config.domains.clone(), + cache_routes: config.linkup.cache_routes.clone(), } } -fn validate_not_empty(server_config: &StorableSession) -> Result<(), ConfigError> { - if server_config.services.is_empty() { +fn validate_not_empty(session: &Session) -> Result<(), ConfigError> { + if session.services.is_empty() { return Err(ConfigError::Empty); } - if server_config.domains.is_empty() { + if session.domains.is_empty() { return Err(ConfigError::Empty); } Ok(()) } -fn validate_service_references(server_config: &StorableSession) -> Result<(), ConfigError> { - let service_names: HashSet<&str> = server_config - .services - .iter() - .map(|s| s.name.as_str()) - .collect(); +fn validate_services(session: &Session) -> Result<(), ConfigError> { + let mut service_names: HashSet<&str> = HashSet::new(); + + for service in &session.services { + validate_url_origin(&service.location)?; - for domain in &server_config.domains { + service_names.insert(&service.name); + } + + for domain in &session.domains { if !service_names.contains(&domain.default_service.as_str()) { return Err(ConfigError::NoSuchService( domain.default_service.to_string(), @@ -423,37 +243,6 @@ fn validate_url_origin(url: &Url) -> Result<(), ConfigError> { Ok(()) } -fn choose_domain_ordering(domains: Vec) -> Vec { - let mut sorted_domains = domains; - sorted_domains.sort_by(|a, b| { - let a_subdomains: Vec<&str> = a.split('.').collect(); - let b_subdomains: Vec<&str> = b.split('.').collect(); - - let a_len = a_subdomains.len(); - let b_len = b_subdomains.len(); - - if a_len != b_len { - b_len.cmp(&a_len) - } else { - a_subdomains - .iter() - .zip(b_subdomains.iter()) - .map(|(a_sub, b_sub)| b_sub.len().cmp(&a_sub.len())) - .find(|&ord| ord != Ordering::Equal) - .unwrap_or(Ordering::Equal) - } - }); - - sorted_domains -} - -pub fn session_to_json(session: Session) -> String { - let storable_session: StorableSession = session.into(); - - // This should never fail, due to previous validation - serde_json::to_string(&storable_session).unwrap() -} - #[cfg(test)] mod tests { use super::*; @@ -500,127 +289,75 @@ mod tests { "#; #[test] - fn test_convert_server_config() { + fn test_convert_session() { let input_str = String::from(CONF_STR); - let server_config_value = serde_json::from_str::(&input_str).unwrap(); - let server_config: Session = server_config_value.try_into().unwrap(); - check_means_same_as_input_conf(&server_config); + let session_value = serde_json::from_str::(&input_str).unwrap(); + let session: Session = session_value.try_into().unwrap(); + check_means_same_as_input_conf(&session); // Inverse should mean the same thing - let output_conf = session_to_json(server_config); - let output_conf_value = serde_json::from_str::(&output_conf).unwrap(); - let second_server_conf: Session = output_conf_value.try_into().unwrap(); - check_means_same_as_input_conf(&second_server_conf); + let output_session = serde_json::to_string(&session).unwrap(); + let output_session_value = + serde_json::from_str::(&output_session).unwrap(); + let second_session: Session = output_session_value.try_into().unwrap(); + check_means_same_as_input_conf(&second_session); } - fn check_means_same_as_input_conf(server_config: &Session) { + fn check_means_same_as_input_conf(session: &Session) { // Test services - assert_eq!(server_config.services.len(), 2); - assert!(server_config.services.contains_key("frontend")); - assert!(server_config.services.contains_key("backend")); + assert_eq!(session.services.len(), 2); + + let frontend_service = session.get_service("frontend").unwrap(); assert_eq!( - server_config.services.get("frontend").unwrap().origin, + frontend_service.location, Url::parse("http://localhost:8000").unwrap() ); + assert_eq!( - server_config.services.get("frontend").unwrap().rewrites[0] - .source - .as_str(), - "/foo/(.*)" - ); - assert_eq!( - server_config.services.get("frontend").unwrap().rewrites[0].target, - "/bar/$1" + Some(1), + frontend_service + .rewrites + .as_ref() + .map(|rewrites| rewrites.len()) ); + + let frontend_service_rewrite = &frontend_service.rewrites.as_ref().unwrap()[0]; + assert_eq!(frontend_service_rewrite.source.as_str(), "/foo/(.*)"); + assert_eq!(frontend_service_rewrite.target, "/bar/$1"); + + let backend_service = session.get_service("backend").unwrap(); assert_eq!( - server_config.services.get("backend").unwrap().origin, + backend_service.location, Url::parse("http://localhost:8001").unwrap() ); - assert!(server_config - .services - .get("backend") - .unwrap() - .rewrites - .is_empty()); + assert!(backend_service.rewrites.is_none()); // Test domains - assert_eq!(server_config.domains.len(), 2); - assert!(server_config.domains.contains_key("example.com")); - assert!(server_config.domains.contains_key("api.example.com")); - assert_eq!( - server_config - .domains - .get("example.com") - .unwrap() - .default_service, - "frontend" - ); - assert_eq!( - server_config.domains.get("example.com").unwrap().routes[0] - .path - .as_str(), - "/api/v1/.*" - ); - assert_eq!( - server_config.domains.get("example.com").unwrap().routes[0].service, - "backend" - ); + assert_eq!(2, session.domains.len()); + + let example_domain = session.get_domain("example.com").unwrap(); + assert_eq!(example_domain.default_service, "frontend"); + assert_eq!( - server_config - .domains - .get("api.example.com") - .unwrap() - .default_service, - "backend" + Some(1), + example_domain.routes.as_ref().map(|routes| routes.len()) ); - assert!(server_config - .domains - .get("api.example.com") - .unwrap() - .routes - .is_empty()); - - assert_eq!(server_config.cache_routes.as_ref().unwrap().len(), 1); + + let example_domain_route = &example_domain.routes.as_ref().unwrap()[0]; + assert_eq!(example_domain_route.path.as_str(), "/api/v1/.*"); + assert_eq!(example_domain_route.service, "backend"); + + let api_domain = session.get_domain("api.example.com").unwrap(); + assert_eq!(api_domain.default_service, "backend"); + assert!(api_domain.routes.is_none()); + + // Test cache routes + + assert_eq!(session.cache_routes.as_ref().unwrap().len(), 1); assert_eq!( - server_config.cache_routes.as_ref().unwrap()[0].as_str(), + session.cache_routes.as_ref().unwrap()[0].as_str(), "/static/.*" ); } - - #[test] - fn test_choose_domain_ordering() { - let input = vec![ - "example.com".to_string(), - "api.example.com".to_string(), - "render-api.example.com".to_string(), - "another-example.com".to_string(), - ]; - - let expected_output = vec![ - "render-api.example.com".to_string(), - "api.example.com".to_string(), - "another-example.com".to_string(), - "example.com".to_string(), - ]; - - assert_eq!(choose_domain_ordering(input), expected_output); - } - - #[test] - fn test_choose_domain_ordering_with_same_length() { - let input = vec![ - "a.domain.com".to_string(), - "b.domain.com".to_string(), - "c.domain.com".to_string(), - ]; - - let expected_output = vec![ - "a.domain.com".to_string(), - "b.domain.com".to_string(), - "c.domain.com".to_string(), - ]; - - assert_eq!(choose_domain_ordering(input), expected_output); - } } diff --git a/linkup/src/session_allocator.rs b/linkup/src/session_allocator.rs index 0507554d..cbe4010a 100644 --- a/linkup/src/session_allocator.rs +++ b/linkup/src/session_allocator.rs @@ -1,7 +1,7 @@ use crate::{ - extract_tracestate_session, first_subdomain, headers::HeaderName, - name_gen::deterministic_six_char_hash, random_animal, random_six_char, session_to_json, ConfigError, HeaderMap, NameKind, Session, SessionError, StringStore, + extract_tracestate_session, first_subdomain, headers::HeaderName, + name_gen::deterministic_six_char_hash, random_animal, random_six_char, }; pub struct SessionAllocator<'a, S: StringStore> { @@ -63,7 +63,8 @@ impl<'a, S: StringStore> SessionAllocator<'a, S> { name_kind: NameKind, desired_name: String, ) -> Result { - let config_str = session_to_json(config.clone()); + let config_str = serde_json::to_string(&config) + .map_err(|error| SessionError::ConfigErr(error.to_string()))?; let name = self .choose_name(desired_name, config.session_token, name_kind, &config_str) @@ -87,10 +88,10 @@ impl<'a, S: StringStore> SessionAllocator<'a, S> { .await; } - if let Some(session) = self.get_session_config(desired_name.clone()).await? { - if session.session_token == session_token { - return Ok(desired_name); - } + if let Some(session) = self.get_session_config(desired_name.clone()).await? + && session.session_token == session_token + { + return Ok(desired_name); } self.new_session_name(name_kind, desired_name, config_json) diff --git a/local-server/Cargo.toml b/local-server/Cargo.toml index a476930d..4d123c78 100644 --- a/local-server/Cargo.toml +++ b/local-server/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "linkup-local-server" version = "0.1.0" -edition = "2021" +edition = "2024" [lib] name = "linkup_local_server" @@ -10,6 +10,7 @@ path = "src/lib.rs" [dependencies] axum = { version = "0.8.1", features = ["http2", "json", "ws"] } axum-server = { version = "0.8.0", features = ["tls-rustls"] } +async-trait = "0.1.43" http = "1.2.0" hickory-server = { version = "0.25.1", features = ["resolver"] } hyper = { version = "1.5.2", features = ["server"] } @@ -22,6 +23,8 @@ futures = "0.3.31" linkup = { path = "../linkup" } rustls = { version = "0.23.37", default-features = false, features = ["ring"] } rustls-native-certs = "0.8.1" +serde = "1.0.217" +serde_json = "1.0.137" thiserror = "2.0.11" tokio = { version = "1.49.0", features = [ "macros", diff --git a/local-server/src/certificates/mod.rs b/local-server/src/certificates/mod.rs index 66594a83..4a3c1f6f 100644 --- a/local-server/src/certificates/mod.rs +++ b/local-server/src/certificates/mod.rs @@ -90,8 +90,8 @@ pub fn setup_self_signed_certificates( if !is_nss_installed() { println!("It seems like you have Firefox installed."); println!( - "For self-signed certificates to work with Firefox, you need to have nss installed." - ); + "For self-signed certificates to work with Firefox, you need to have nss installed." + ); let nss_url = if cfg!(target_os = "macos") { "`brew install nss`" } else { diff --git a/local-server/src/lib.rs b/local-server/src/lib.rs index 0a726d88..8fd22ff3 100644 --- a/local-server/src/lib.rs +++ b/local-server/src/lib.rs @@ -1,13 +1,14 @@ use axum::{ + Extension, Router, body::Body, extract::{DefaultBodyLimit, Json, Request}, http::StatusCode, response::{IntoResponse, Response}, routing::{any, get, post}, - Extension, Router, }; use axum_server::tls_rustls::RustlsConfig; use hickory_server::{ + ServerFuture, authority::{Catalog, ZoneType}, proto::{ rr::{Name, RData, Record}, @@ -17,29 +18,31 @@ use hickory_server::{ config::{NameServerConfig, NameServerConfigGroup, ResolverOpts}, name_server::TokioConnectionProvider, }, + server::{RequestHandler, ResponseHandler, ResponseInfo}, store::{ forwarder::{ForwardAuthority, ForwardConfig}, in_memory::InMemoryAuthority, }, - ServerFuture, }; -use http::{header::HeaderMap, HeaderName, HeaderValue, Uri}; +use http::{HeaderName, HeaderValue, Uri, header::HeaderMap}; use hyper_rustls::HttpsConnector; use hyper_util::{ - client::legacy::{connect::HttpConnector, Client}, + client::legacy::{Client, connect::HttpConnector}, rt::TokioExecutor, }; use linkup::{ - allow_all_cors, get_additional_headers, get_target_service, MemoryStringStore, NameKind, - Session, SessionAllocator, TargetService, UpdateSessionRequest, + MemoryStringStore, NameKind, Session, SessionAllocator, TargetService, UpdateSessionRequest, + allow_all_cors, get_additional_headers, get_target_service, }; use rustls::ServerConfig; use std::{ net::{Ipv4Addr, SocketAddr}, + ops::Deref, + path::PathBuf, str::FromStr, }; use std::{path::Path, sync::Arc}; -use tokio::{net::UdpSocket, signal}; +use tokio::{net::UdpSocket, select, signal, sync::RwLock}; use tokio_tungstenite::tungstenite::client::IntoClientRequest; use tower::ServiceBuilder; use tower_http::trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer}; @@ -79,7 +82,43 @@ impl IntoResponse for ApiError { } } -pub fn linkup_router(config_store: MemoryStringStore) -> Router { +#[derive(Clone)] +pub struct DnsCatalog(Arc>); + +impl DnsCatalog { + pub fn new() -> Self { + Self(Arc::new(RwLock::new(Catalog::new()))) + } +} + +impl Default for DnsCatalog { + fn default() -> Self { + Self::new() + } +} + +impl Deref for DnsCatalog { + type Target = Arc>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[async_trait::async_trait] +impl RequestHandler for DnsCatalog { + async fn handle_request( + &self, + request: &hickory_server::server::Request, + response_handle: R, + ) -> ResponseInfo { + let catalog = self.read().await; + + catalog.handle_request(request, response_handle).await + } +} + +pub fn linkup_router(config_store: MemoryStringStore, dns_catalog: DnsCatalog) -> Router { let client = https_client(); Router::new() @@ -87,6 +126,7 @@ pub fn linkup_router(config_store: MemoryStringStore) -> Router { .route("/linkup/check", get(always_ok)) .fallback(any(linkup_request_handler)) .layer(Extension(config_store)) + .layer(Extension(dns_catalog)) .layer(Extension(client)) .layer( ServiceBuilder::new() @@ -99,7 +139,34 @@ pub fn linkup_router(config_store: MemoryStringStore) -> Router { ) } -pub async fn start_server_https(config_store: MemoryStringStore, certs_dir: &Path) { +pub async fn start(config_store: MemoryStringStore, certs_dir: &Path) { + let dns_catalog = DnsCatalog::new(); + + let http_config_store = config_store.clone(); + let https_config_store = config_store.clone(); + let https_certs_dir = PathBuf::from(certs_dir); + + select! { + () = start_server_http(http_config_store, dns_catalog.clone()) => { + println!("HTTP server shut down"); + }, + () = start_server_https(https_config_store, &https_certs_dir, dns_catalog.clone()) => { + println!("HTTPS server shut down"); + }, + () = start_dns_server(dns_catalog.clone()) => { + println!("DNS server shut down"); + }, + () = shutdown_signal() => { + println!("Shutdown signal received, stopping all servers"); + } + } +} + +async fn start_server_https( + config_store: MemoryStringStore, + certs_dir: &Path, + dns_catalog: DnsCatalog, +) { let _ = rustls::crypto::ring::default_provider().install_default(); let sni = match certificates::WildcardSniResolver::load_dir(certs_dir) { @@ -118,10 +185,10 @@ pub async fn start_server_https(config_store: MemoryStringStore, certs_dir: &Pat .with_cert_resolver(Arc::new(sni)); server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - let app = linkup_router(config_store); + let app = linkup_router(config_store, dns_catalog); let addr = SocketAddr::from(([0, 0, 0, 0], 443)); - println!("listening on {}", &addr); + println!("HTTPS listening on {}", &addr); axum_server::bind_rustls(addr, RustlsConfig::from_config(Arc::new(server_config))) .serve(app.into_make_service()) @@ -129,39 +196,22 @@ pub async fn start_server_https(config_store: MemoryStringStore, certs_dir: &Pat .expect("failed to start HTTPS server"); } -pub async fn start_server_http(config_store: MemoryStringStore) -> std::io::Result<()> { - let app = linkup_router(config_store); +async fn start_server_http(config_store: MemoryStringStore, dns_catalog: DnsCatalog) { + let app = linkup_router(config_store, dns_catalog); let addr = SocketAddr::from(([0, 0, 0, 0], 80)); - println!("listening on {}", &addr); + println!("HTTP listening on {}", &addr); - let listener = tokio::net::TcpListener::bind(addr).await?; - axum::serve(listener, app) - .with_graceful_shutdown(shutdown_signal()) - .await?; + let listener = tokio::net::TcpListener::bind(addr) + .await + .expect("failed to bind to address"); - Ok(()) + axum::serve(listener, app) + .await + .expect("failed to start HTTP server"); } -pub async fn start_dns_server(linkup_session_name: String, domains: Vec) { - let mut catalog = Catalog::new(); - - for domain in &domains { - let record_name = Name::from_str(&format!("{linkup_session_name}.{domain}.")).unwrap(); - - let authority = InMemoryAuthority::empty(record_name.clone(), ZoneType::Primary, false); - - let record = Record::from_rdata( - record_name.clone(), - 3600, - RData::A(Ipv4Addr::new(127, 0, 0, 1).into()), - ); - - authority.upsert(record, 0).await; - - catalog.upsert(record_name.clone().into(), vec![Arc::new(authority)]); - } - +async fn start_dns_server(dns_catalog: DnsCatalog) { let cf_name_server = NameServerConfig::new("1.1.1.1:53".parse().unwrap(), Protocol::Udp); let forward_config = ForwardConfig { name_servers: NameServerConfigGroup::from(vec![cf_name_server]), @@ -174,12 +224,15 @@ pub async fn start_dns_server(linkup_session_name: String, domains: Vec) .build() .unwrap(); - catalog.upsert(Name::root().into(), vec![Arc::new(forwarder)]); + { + let mut catalog = dns_catalog.write().await; + catalog.upsert(Name::root().into(), vec![Arc::new(forwarder)]); + } let addr = SocketAddr::from(([0, 0, 0, 0], 8053)); let sock = UdpSocket::bind(&addr).await.unwrap(); - let mut server = ServerFuture::new(catalog); + let mut server = ServerFuture::new(dns_catalog); server.register_socket(sock); println!("listening on {addr}"); @@ -249,11 +302,12 @@ async fn linkup_request_handler( let mut cookie_values: Vec = Vec::new(); for (key, value) in req.headers() { if key == http::header::COOKIE { - if let Ok(cookie_value) = value.to_str().map(str::trim) { - if !cookie_value.is_empty() { - cookie_values.push(cookie_value.to_string()); - } + if let Ok(cookie_value) = value.to_str().map(str::trim) + && !cookie_value.is_empty() + { + cookie_values.push(cookie_value.to_string()); } + continue; } @@ -351,7 +405,7 @@ async fn handle_http_req( ), StatusCode::BAD_GATEWAY, ) - .into_response() + .into_response(); } }; @@ -362,9 +416,16 @@ async fn handle_http_req( async fn linkup_config_handler( Extension(store): Extension, + Extension(dns_catalog): Extension, Json(update_req): Json, ) -> impl IntoResponse { let desired_name = update_req.desired_name.clone(); + let domains = update_req + .domains + .iter() + .map(|domain| domain.domain.clone()) + .collect::>(); + let server_conf: Session = match update_req.try_into() { Ok(conf) => conf, Err(e) => { @@ -372,36 +433,56 @@ async fn linkup_config_handler( format!("Failed to parse server config: {} - local server", e), StatusCode::BAD_REQUEST, ) - .into_response() + .into_response(); } }; let sessions = SessionAllocator::new(&store); - let session_name = sessions + let session_name_result = sessions .store_session(server_conf, NameKind::Animal, desired_name) .await; - let name = match session_name { + let session_name = match session_name_result { Ok(session_name) => session_name, Err(e) => { return ApiError::new( format!("Failed to store server config: {}", e), StatusCode::INTERNAL_SERVER_ERROR, ) - .into_response() + .into_response(); } }; - (StatusCode::OK, name).into_response() + for domain in &domains { + let full_domain = format!("{session_name}.{domain}"); + + register_dns_record(&dns_catalog, &full_domain).await; + } + + (StatusCode::OK, session_name).into_response() } async fn always_ok() -> &'static str { "OK" } -async fn shutdown_signal() { - let _ = signal::ctrl_c().await; - println!("signal received, starting graceful shutdown"); +async fn register_dns_record(dns_catalog: &DnsCatalog, domain: &str) { + let mut catalog = dns_catalog.write().await; + + let record_name = Name::from_str(&format!("{}.", domain)) + .expect("dns record from domain should always succeed"); + + let authority = InMemoryAuthority::empty(record_name.clone(), ZoneType::Primary, false); + + let record = Record::from_rdata( + record_name.clone(), + 3600, + RData::A(Ipv4Addr::new(127, 0, 0, 1).into()), + ); + + authority.upsert(record, 0).await; + + catalog.upsert(record_name.clone().into(), vec![Arc::new(authority)]); } fn https_client() -> HttpsClient { @@ -425,3 +506,27 @@ fn https_client() -> HttpsClient { Client::builder(TokioExecutor::new()).build(https) } + +async fn shutdown_signal() { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to start SIGINT handler"); + }; + + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to start SIGTERM handler") + .recv() + .await; + }; + + tokio::select! { + () = ctrl_c => { + println!("Received SIGINT signal"); + }, + () = terminate => { + println!("Received SIGTERM signal"); + }, + } +} diff --git a/local-server/src/ws.rs b/local-server/src/ws.rs index da66b6b9..bb730f2c 100644 --- a/local-server/src/ws.rs +++ b/local-server/src/ws.rs @@ -1,12 +1,12 @@ use std::{future::Future, pin::Pin}; -use axum::extract::{ws::WebSocket, FromRequestParts, WebSocketUpgrade}; +use axum::extract::{FromRequestParts, WebSocketUpgrade, ws::WebSocket}; use futures::{SinkExt, StreamExt}; -use http::{request::Parts, StatusCode}; +use http::{StatusCode, request::Parts}; use tokio::net::TcpStream; use tokio_tungstenite::{ - tungstenite::{self, Message}, MaybeTlsStream, WebSocketStream, + tungstenite::{self, Message}, }; pub struct ExtractOptionalWebSocketUpgrade(pub Option); @@ -76,7 +76,7 @@ pub fn context_handle_socket( ) -> WrappedSocketHandler { Box::new(move |downstream: WebSocket| { Box::pin(async move { - use futures::future::{select, Either}; + use futures::future::{Either, select}; let (mut upstream_write, mut upstream_read) = upstream_ws.split(); let (mut downstream_write, mut downstream_read) = downstream.split(); diff --git a/server-tests/Cargo.toml b/server-tests/Cargo.toml index d1282d0e..0f381639 100644 --- a/server-tests/Cargo.toml +++ b/server-tests/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "linkup-server-tests" version = "0.1.0" -edition = "2021" +edition = "2024" [dev-dependencies] linkup = { path = "../linkup" } diff --git a/server-tests/tests/helpers.rs b/server-tests/tests/helpers.rs index b44d40a0..b37605ad 100644 --- a/server-tests/tests/helpers.rs +++ b/server-tests/tests/helpers.rs @@ -1,7 +1,7 @@ use std::process::Command; -use linkup::{MemoryStringStore, StorableDomain, StorableService, UpdateSessionRequest}; -use linkup_local_server::linkup_router; +use linkup::{Domain, MemoryStringStore, SessionService, UpdateSessionRequest}; +use linkup_local_server::{DnsCatalog, linkup_router}; use reqwest::Url; use tokio::net::TcpListener; @@ -14,7 +14,7 @@ pub enum ServerKind { pub async fn setup_server(kind: ServerKind) -> String { match kind { ServerKind::Local => { - let app = linkup_router(MemoryStringStore::default()); + let app = linkup_router(MemoryStringStore::default(), DnsCatalog::new()); // Bind to a random port assigned by the OS let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -57,12 +57,12 @@ pub fn create_session_request(name: String, fe_location: Option) -> Stri let req = UpdateSessionRequest { desired_name: name, session_token: "token".to_string(), - domains: vec![StorableDomain { + domains: vec![Domain { domain: "example.com".to_string(), default_service: "frontend".to_string(), routes: None, }], - services: vec![StorableService { + services: vec![SessionService { name: "frontend".to_string(), location: Url::parse(&location).unwrap(), rewrites: None, diff --git a/server-tests/tests/http_test.rs b/server-tests/tests/http_test.rs index f2cc662b..e3c8dd08 100644 --- a/server-tests/tests/http_test.rs +++ b/server-tests/tests/http_test.rs @@ -1,10 +1,10 @@ use axum::{ + Router, response::{AppendHeaders, Redirect}, routing::{any, get}, - Router, }; use helpers::ServerKind; -use http::{header::SET_COOKIE, StatusCode}; +use http::{StatusCode, header::SET_COOKIE}; use rstest::rstest; use tokio::net::TcpListener; diff --git a/server-tests/tests/server_test.rs b/server-tests/tests/server_test.rs index 3a560aab..2ca38cdc 100644 --- a/server-tests/tests/server_test.rs +++ b/server-tests/tests/server_test.rs @@ -1,5 +1,5 @@ use helpers::ServerKind; -use linkup::{CreatePreviewRequest, StorableDomain, StorableService}; +use linkup::{CreatePreviewRequest, Domain, SessionService}; use reqwest::Url; use rstest::rstest; @@ -85,12 +85,12 @@ pub fn create_preview_request(fe_location: Option) -> String { None => "http://example.com".to_string(), }; let req = CreatePreviewRequest { - domains: vec![StorableDomain { + domains: vec![Domain { domain: "example.com".to_string(), default_service: "frontend".to_string(), routes: None, }], - services: vec![StorableService { + services: vec![SessionService { name: "frontend".to_string(), location: Url::parse(&location).unwrap(), rewrites: None, diff --git a/server-tests/tests/ws_test.rs b/server-tests/tests/ws_test.rs index 51469ed6..14cb0deb 100644 --- a/server-tests/tests/ws_test.rs +++ b/server-tests/tests/ws_test.rs @@ -1,8 +1,8 @@ use std::str::FromStr; +use axum::Router; use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade}; use axum::response::IntoResponse; -use axum::Router; use futures::{SinkExt, StreamExt}; use helpers::ServerKind; use http::{HeaderName, HeaderValue}; diff --git a/worker/Cargo.toml b/worker/Cargo.toml index 7b8cce64..6906a01c 100644 --- a/worker/Cargo.toml +++ b/worker/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "linkup-worker" version = "0.1.0" -edition = "2021" +edition = "2024" [lib] crate-type = ["cdylib"] diff --git a/worker/src/lib.rs b/worker/src/lib.rs index 8d12e559..8c2d3b54 100644 --- a/worker/src/lib.rs +++ b/worker/src/lib.rs @@ -1,24 +1,24 @@ use axum::{ + Router, body::to_bytes, extract::{Json, Query, Request, State}, http::StatusCode, - middleware::{from_fn_with_state, Next}, + middleware::{Next, from_fn_with_state}, response::IntoResponse, routing::{any, get, post}, - Router, }; use http::{HeaderMap, Uri}; use http_error::HttpError; use kv_store::CfWorkerStringStore; use linkup::{ - allow_all_cors, get_additional_headers, get_target_service, CreatePreviewRequest, NameKind, - Session, SessionAllocator, UpdateSessionRequest, Version, VersionChannel, + CreatePreviewRequest, NameKind, Session, SessionAllocator, UpdateSessionRequest, Version, + VersionChannel, allow_all_cors, get_additional_headers, get_target_service, }; use serde::{Deserialize, Serialize}; use tower_service::Service; use worker::{ - console_error, console_log, console_warn, event, kv::KvStore, Env, Fetch, HttpRequest, - HttpResponse, RequestRedirect, + Env, Fetch, HttpRequest, HttpResponse, RequestRedirect, console_error, console_log, + console_warn, event, kv::KvStore, }; use ws::handle_ws_resp; @@ -208,10 +208,10 @@ async fn linkup_session_handler( Ok(conf) => conf, Err(e) => { return HttpError::new( - format!("Failed to parse server config: {} - local server", e), + format!("Failed to parse server config: {} - Worker", e), StatusCode::BAD_REQUEST, ) - .into_response() + .into_response(); } }; @@ -226,7 +226,7 @@ async fn linkup_session_handler( format!("Failed to store server config: {}", e), StatusCode::INTERNAL_SERVER_ERROR, ) - .into_response() + .into_response(); } }; @@ -245,10 +245,10 @@ async fn linkup_preview_handler( Ok(conf) => conf, Err(e) => { return HttpError::new( - format!("Failed to parse server config: {} - local server", e), + format!("Failed to parse server config: {} - Worker", e), StatusCode::BAD_REQUEST, ) - .into_response() + .into_response(); } }; @@ -263,7 +263,7 @@ async fn linkup_preview_handler( format!("Failed to store server config: {}", e), StatusCode::INTERNAL_SERVER_ERROR, ) - .into_response() + .into_response(); } }; @@ -296,8 +296,7 @@ async fn linkup_request_handler( Ok(session) => session, Err(_) => { return HttpError::new( - "Linkup was unable to determine the session origin of the request. - Make sure your request includes a valid session ID in the referer or tracestate headers. - Local Server".to_string(), + "Linkup was unable to determine the session origin of the request.\nMake sure your request includes a valid session ID in the referer or tracestate headers. - Worker".to_string(), StatusCode::UNPROCESSABLE_ENTITY, ) .into_response() @@ -308,9 +307,7 @@ async fn linkup_request_handler( Some(result) => result, None => { return HttpError::new( - "The request belonged to a session, but there was no target for the request. - Check your routing rules in the linkup config for a match. - Local Server" - .to_string(), + "The request belonged to a session, but there was no target for the request.\nCheck your routing rules in the linkup config for a match. - Worker".to_string(), StatusCode::NOT_FOUND, ) .into_response() @@ -338,20 +335,20 @@ async fn linkup_request_handler( let cacheable_req = is_cacheable_request(&upstream_request, &config); let cache_key = get_cache_key(&upstream_request, &session_name).unwrap_or_default(); - if cacheable_req { - if let Some(upstream_response) = get_cached_req(cache_key.clone()).await { - let resp: HttpResponse = match upstream_response.try_into() { - Ok(resp) => resp, - Err(e) => { - return HttpError::new( - format!("Failed to parse cached response: {}", e), - StatusCode::BAD_GATEWAY, - ) - .into_response() - } - }; - return resp.into_response(); - } + + if cacheable_req && let Some(upstream_response) = get_cached_req(cache_key.clone()).await { + let resp: HttpResponse = match upstream_response.try_into() { + Ok(resp) => resp, + Err(e) => { + return HttpError::new( + format!("Failed to parse cached response: {}", e), + StatusCode::BAD_GATEWAY, + ) + .into_response(); + } + }; + + return resp.into_response(); } let mut upstream_response = match Fetch::Request(upstream_request).send().await { @@ -376,7 +373,7 @@ async fn linkup_request_handler( format!("Failed to clone response: {}", e), StatusCode::BAD_GATEWAY, ) - .into_response() + .into_response(); } }; if let Err(e) = set_cached_req(cache_key, cache_clone).await { @@ -513,7 +510,7 @@ async fn handle_http_resp(worker_resp: worker::Response) -> impl IntoResponse { format!("Failed to parse response: {}", e), StatusCode::BAD_GATEWAY, ) - .into_response() + .into_response(); } }; resp.headers_mut().extend(allow_all_cors()); diff --git a/worker/src/tunnel.rs b/worker/src/tunnel.rs index cb0ae5c0..14979239 100644 --- a/worker/src/tunnel.rs +++ b/worker/src/tunnel.rs @@ -154,13 +154,13 @@ pub async fn delete_tunnel( 0 => { return Err(DeleteTunnelError::GetDNSRecord( "Fetching DNS for tunnel returned empty".to_string(), - )) + )); } 1 => &records[0], 2.. => { return Err(DeleteTunnelError::GetDNSRecord( "Fetching DNS for tunnel returned more than one record".to_string(), - )) + )); } }; diff --git a/worker/src/ws.rs b/worker/src/ws.rs index cfec1f1a..c1d7d58e 100644 --- a/worker/src/ws.rs +++ b/worker/src/ws.rs @@ -3,7 +3,7 @@ use std::str::FromStr; use axum::{http::StatusCode, response::IntoResponse}; use http::{HeaderName, HeaderValue}; use linkup::allow_all_cors; -use worker::{console_log, Error, HttpResponse, WebSocket, WebSocketPair, WebsocketEvent}; +use worker::{Error, HttpResponse, WebSocket, WebSocketPair, WebsocketEvent, console_log}; use futures::{ future::{self, Either}, @@ -25,7 +25,7 @@ pub async fn handle_ws_resp(upstream_response: worker::Response) -> impl IntoRes format!("Failed to connect to destination: {}", e), StatusCode::BAD_GATEWAY, ) - .into_response() + .into_response(); } }; @@ -36,7 +36,7 @@ pub async fn handle_ws_resp(upstream_response: worker::Response) -> impl IntoRes format!("Failed to create source websocket: {}", e), StatusCode::INTERNAL_SERVER_ERROR, ) - .into_response() + .into_response(); } }; let downstream_ws_server = downstream_ws.server; @@ -102,7 +102,7 @@ pub async fn handle_ws_resp(upstream_response: worker::Response) -> impl IntoRes format!("Failed to create response from websocket: {}", e), StatusCode::INTERNAL_SERVER_ERROR, ) - .into_response() + .into_response(); } }; @@ -113,7 +113,7 @@ pub async fn handle_ws_resp(upstream_response: worker::Response) -> impl IntoRes format!("Failed to parse response: {}", e), StatusCode::BAD_GATEWAY, ) - .into_response() + .into_response(); } };