diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b53b1d8..e2fc866 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -52,7 +52,8 @@ jobs: run: cargo +nightly fmt -- --check - name: Clippy - run: cargo clippy --all-targets --all-features -- -D clippy::correctness -W clippy::style + # Note: --all-features is not used because `wasm` is mutually exclusive with `server` + run: cargo clippy --all-targets -- -D clippy::correctness -W clippy::style - name: Install cargo-nextest uses: taiki-e/install-action@nextest @@ -115,6 +116,59 @@ jobs: - name: Tests (integration) run: cargo test --no-default-features --features ${{ matrix.features }} -- --ignored + # WASM build check + wasm-build: + name: WASM Build + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + targets: wasm32-unknown-unknown + + - name: Install wasm-pack + run: cargo install wasm-pack@0.13.1 --locked + + - name: Cache cargo + uses: Swatinem/rust-cache@v2 + with: + shared-key: wasm + + - name: Create placeholder directories for rust-embed + run: | + mkdir -p ui/dist docs/out + echo 'Placeholder' > ui/dist/index.html + echo 'Placeholder' > docs/out/index.html + + - name: Build WASM module + run: ./scripts/build-wasm.sh --release + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: "20" + + - name: Setup pnpm + uses: pnpm/action-setup@v4 + with: + version: 9 + + - name: Install UI dependencies + working-directory: ui + run: pnpm install --frozen-lockfile + + - name: Generate API client + working-directory: ui + run: pnpm run generate-api + + - name: Build frontend (WASM mode) + working-directory: ui + run: pnpm build + env: + VITE_WASM_MODE: "true" + # Cross-platform builds cross-build: name: Cross Build (${{ matrix.target }}) diff --git a/.github/workflows/deploy-wasm.yml b/.github/workflows/deploy-wasm.yml new file mode 100644 index 0000000..8144270 
--- /dev/null +++ b/.github/workflows/deploy-wasm.yml @@ -0,0 +1,94 @@ +name: Deploy WASM App + +on: + push: + branches: [main] + paths: + - "src/**" + - "ui/**" + - "Cargo.toml" + - "Cargo.lock" + - "scripts/build-wasm.sh" + - ".github/workflows/deploy-wasm.yml" + workflow_dispatch: + +concurrency: + group: deploy-wasm + cancel-in-progress: true + +jobs: + deploy: + name: Build & Deploy + runs-on: ubuntu-latest + permissions: + contents: read + deployments: write + steps: + - uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + targets: wasm32-unknown-unknown + + - name: Install wasm-pack + run: cargo install wasm-pack@0.13.1 --locked + + - name: Cache cargo + uses: Swatinem/rust-cache@v2 + with: + shared-key: wasm + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: "20" + + - name: Setup pnpm + uses: pnpm/action-setup@v4 + with: + version: 9 + + - name: Get pnpm store directory + shell: bash + run: echo "STORE_PATH=$(pnpm store path --silent)" >> $GITHUB_ENV + + - name: Cache pnpm + uses: actions/cache@v4 + with: + path: ${{ env.STORE_PATH }} + key: ${{ runner.os }}-pnpm-wasm-${{ hashFiles('ui/pnpm-lock.yaml') }} + restore-keys: | + ${{ runner.os }}-pnpm-wasm- + + - name: Install UI dependencies + working-directory: ui + run: pnpm install --frozen-lockfile + + - name: Generate API client + working-directory: ui + run: pnpm run generate-api + + - name: Build WASM module + run: ./scripts/build-wasm.sh --release + + - name: Build frontend (WASM mode) + working-directory: ui + run: pnpm build + env: + VITE_WASM_MODE: "true" + + - name: Add headers for service worker + run: | + cat > ui/dist/_headers <<'EOF' + /sw.js + Service-Worker-Allowed: / + Cache-Control: no-cache, no-store, must-revalidate + EOF + + - name: Deploy to Cloudflare Pages + uses: cloudflare/wrangler-action@v3 + with: + apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }} + accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }} + 
command: pages deploy ui/dist --project-name=hadrian diff --git a/.greptile/config.json b/.greptile/config.json new file mode 100644 index 0000000..b0448dd --- /dev/null +++ b/.greptile/config.json @@ -0,0 +1,7 @@ +{ + "strictness": 1, + "commentTypes": ["syntax", "logic", "style", "info"], + "fileChangeLimit": 300, + "triggerOnUpdates": true, + "fixWithAI": true +} diff --git a/CLAUDE.md b/CLAUDE.md index f54bd09..3968af7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -26,6 +26,7 @@ Features: - Image generation, audio (TTS, transcription, translation) - Knowledge Bases / RAG: file upload, text extraction, chunking, vector search, re-ranking - Integrations: SQLite/Postgres, Redis, OpenTelemetry, Vault, S3 +- WASM build: runs entirely in the browser via service workers and sql.js (app.hadriangateway.com) The backend is written in Rust and uses Axum for routing and middleware. The frontend is written in React and TypeScript, with TailwindCSS for styling. @@ -84,6 +85,7 @@ Hierarchical feature profiles (default: `full`): - **`standard`** — minimal + Postgres, Redis, OTLP, Prometheus, SSO, CEL, doc extraction, OpenAPI docs, S3, secrets managers (AWS/Azure/GCP/Vault) - **`full`** — standard + SAML, Kreuzberg, ClamAV - **`headless`** — all `full` features except embedded assets (UI, docs, catalog). Used by `cargo install` and for deployments that serve the frontend separately. +- **`wasm`** — Browser-only build targeting `wasm32-unknown-unknown`. OpenAI + Anthropic + Test providers, wasm-sqlite (sql.js FFI), no server/concurrency/CLI/JWT/SSO features. Built with `wasm-pack`. 
```bash cargo build --no-default-features --features tiny # Smallest binary @@ -91,6 +93,8 @@ cargo build --no-default-features --features minimal # Fast compile cargo build --no-default-features --features standard # Typical deployment cargo build # Full (default) cargo build --no-default-features --features headless # Full features, no embedded assets +./scripts/build-wasm.sh # WASM build (dev) +./scripts/build-wasm.sh --release # WASM build (release) ``` Run `hadrian features` to list enabled/disabled features at runtime. CI tests `minimal`, `standard`, and `headless` profiles; Windows uses `minimal` to avoid OpenSSL. @@ -116,6 +120,7 @@ GitHub Actions workflow (`.github/workflows/ci.yml`) runs: - E2E tests (TypeScript/Playwright with testcontainers, needs Docker build) - OpenAPI conformance check - Documentation build +- WASM build (compile to `wasm32-unknown-unknown` via `wasm-pack`, build frontend with `VITE_WASM_MODE=true`) ### Release Pipeline @@ -130,6 +135,12 @@ GitHub Actions workflow (`.github/workflows/release.yml`) triggers on version ta - Creates GitHub Release with archives and SHA256 checksums (tag push only) - Dry-run mode builds artifacts and prints a summary without creating a release +WASM deploy workflow (`.github/workflows/deploy-wasm.yml`): +- Triggers on pushes to `main` touching `src/**`, `ui/**`, `Cargo.toml`, `Cargo.lock`, or `scripts/build-wasm.sh` +- Builds WASM module + frontend with `VITE_WASM_MODE=true` +- Deploys to Cloudflare Pages (app.hadriangateway.com) +- Sets `Service-Worker-Allowed: /` and `Cache-Control: no-cache` headers on `sw.js` + Helm chart workflow (`.github/workflows/helm.yml`) runs: - `helm lint` (standard and strict mode) - `helm template` with matrix of configurations (PostgreSQL, Redis, Ingress, etc.) @@ -215,6 +226,36 @@ Per-org SSO allows each organization to configure its own identity provider (OID 4. **LLM Provider** forwards request, streams response 5. 
**Usage Tracking** records tokens/cost asynchronously with full principal attribution (user, org, project, team, service account) +### WASM Build Architecture + +The WASM build runs the full Hadrian Axum router inside a browser service worker, enabling a zero-backend deployment at app.hadriangateway.com. + +**Request flow:** +1. Service worker intercepts `fetch` events matching `/v1/`, `/admin/v1/`, `/health`, `/auth/`, `/api/` +2. `web_sys::Request` is converted to `http::Request` (with `/api/v1/` → `/v1/` path rewriting) +3. Request is dispatched through the same Axum `Router` used by the native server +4. `http::Response` is converted back to `web_sys::Response` +5. LLM API calls use `reqwest` which delegates to the browser's `fetch()` API + +**Three-layer gating strategy:** +1. **Cargo features** (`wasm` vs `server`) — Controls what modules/dependencies are included +2. **`#[cfg(target_arch = "wasm32")]`** — Handles Send/Sync differences (`AssertSend`, `async_trait(?Send)`, `spawn_local` vs `tokio::spawn`) +3. **`#[cfg(feature = "server")]`** / `#[cfg(feature = "concurrency")]`** — Gates server-only functionality (middleware layers, `TaskTracker`, `UsageLogBuffer`) + +**Database:** `WasmSqlitePool` is a zero-size type; actual SQLite runs in JavaScript via sql.js. Queries cross the FFI boundary via `wasm_bindgen` extern functions. The `backend.rs` abstraction provides cfg-switched type aliases (`Pool`, `Row`, `BackendError`) and traits (`ColDecode`, `RowExt`) so SQLite repo code compiles against either `sqlx::SqlitePool` or `WasmSqlitePool` without changes. + +**Persistence:** Database is persisted to IndexedDB with a debounced save (500ms) after write operations. + +**Auth:** WASM mode uses `AuthMode::None` with a bootstrapped anonymous user and org. Permissive `AuthzContext` and `AdminAuth` extensions are injected as layers. 
+ +**Setup flow:** `WasmSetupGuard` detects if providers are configured; if not, shows a setup wizard (`WasmSetup`) supporting OpenRouter OAuth (PKCE), Ollama auto-detection, and manual API key entry for OpenAI/Anthropic/etc. + +**Known limitations:** +- Streaming responses are fully buffered (no real-time SSE token streaming for LLM calls) +- No usage tracking (no `TaskTracker`/`UsageLogBuffer` in WASM) +- No caching layer, rate limiting, or budget enforcement +- Module service workers require Chrome 91+ / Edge 91+ (Firefox support may be limited) + ### Document Processing Flow (RAG) 1. **File Upload** (`POST /v1/files`) — Store raw file in database @@ -469,6 +510,17 @@ See `agent_instructions/adding_admin_endpoint.md` for implementation patterns (r - `src/validation/` — Response validation against OpenAI schema - `src/observability/siem/` — SIEM formatters +### Backend — WASM + +- `src/wasm.rs` — WASM entry point: `HadrianGateway` struct, request/response conversion, router construction, default config +- `src/compat.rs` — WASM compatibility: `AssertSend`, `WasmHandler`, `wasm_routing` module (drop-in replacements for `axum::routing`), `spawn_detached`, `impl_wasm_handler!` macro +- `src/lib.rs` — Library exports (crate type `cdylib` + `rlib` for wasm-pack) +- `src/db/wasm_sqlite/bridge.rs` — `wasm_bindgen` FFI to `globalThis.__hadrian_sqlite` (sql.js bridge) +- `src/db/wasm_sqlite/types.rs` — `WasmParam`, `WasmValue`, `WasmRow`, `WasmDecode` trait with type conversions +- `src/db/sqlite/backend.rs` — SQLite backend abstraction: cfg-switched `Pool`/`Row`/`BackendError` type aliases, `RowExt`/`ColDecode` traits for unified repo code +- `src/middleware/types.rs` — Shared middleware types (`AuthzContext`, `AdminAuth`, `ClientInfo`) extracted from layers for WASM compatibility +- `scripts/build-wasm.sh` — Build script (invokes `wasm-pack`, copies sql-wasm.wasm) + ### Backend — Other - `src/catalog/` — Model catalog registry @@ -508,6 +560,17 @@ See 
`agent_instructions/adding_admin_endpoint.md` for implementation patterns (r - `ui/src/components/ToolExecution/` — Tool execution timeline UI - `ui/src/components/Artifact/` — Artifact rendering (charts, tables, images, code) +### Frontend — WASM / Service Worker + +- `ui/src/service-worker/sw.ts` — Service worker: intercepts API calls, lazily initializes `HadrianGateway` WASM module, routes requests through Axum router +- `ui/src/service-worker/sqlite-bridge.ts` — sql.js bridge: `globalThis.__hadrian_sqlite` with `init_database()`, `query()`, `execute()`, `execute_script()`; persists to IndexedDB with debounced save +- `ui/src/service-worker/register.ts` — Service worker registration with `CLAIM` message handling for hard refreshes +- `ui/src/service-worker/wasm.d.ts` — Type declarations for the WASM module exports +- `ui/src/components/WasmSetup/WasmSetup.tsx` — Three-step setup wizard (welcome → providers → done) with OpenRouter OAuth, Ollama detection, manual API key entry +- `ui/src/components/WasmSetup/WasmSetupGuard.tsx` — Guard component: auto-shows wizard when no providers configured, handles OAuth callback +- `ui/src/components/WasmSetup/openrouter-oauth.ts` — OpenRouter OAuth PKCE flow (code verifier in sessionStorage) +- `ui/src/routes/AppRoutes.tsx` — Routes extracted from App.tsx + ### Frontend — Pages & Layout - `ui/src/pages/studio/` — Studio feature (image gen, TTS, transcription) @@ -561,6 +624,30 @@ pnpm test-storybook # Run Storybook tests with vitest pnpm openapi-ts # Regenerate from /api/openapi.json ``` +### WASM Frontend Development + +The WASM mode is controlled by the `VITE_WASM_MODE=true` environment variable. 
When set: +- The Vite dev server uses a custom service worker plugin instead of `VitePWA` +- The proxy configuration is disabled (service worker handles API routing) +- `main.tsx` registers the service worker before rendering React +- `App.tsx` wraps the app in `WasmSetupGuard` + +```bash +# Build WASM module first (from repo root) +./scripts/build-wasm.sh + +# Then run frontend in WASM mode +cd ui && VITE_WASM_MODE=true pnpm dev +``` + +The service worker (`sw.ts`) is built separately from the Vite bundle using esbuild (via the custom `wasmServiceWorkerPlugin` in `vite.config.ts`). In dev mode it's compiled on each request; in production it's written to `dist/sw.js` during the `writeBundle` hook. + +When modifying WASM-related code: +- The `wasm_routing` module (`src/compat.rs`) provides drop-in replacements for `axum::routing::{get, post, put, patch, delete}` — route modules use cfg-switched imports +- All async trait definitions use `#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]` / `#[cfg_attr(not(target_arch = "wasm32"), async_trait)]` +- The `backend.rs` abstraction means SQLite repo code is written once — modify repos normally and both native/WASM will compile +- Server-only routes (multipart file upload, audio transcription/translation) are excluded with `#[cfg(feature = "server")]` + ### Frontend Conventions - Run the `./scripts/generate-openapi.sh` script to generate the OpenAPI client diff --git a/Cargo.lock b/Cargo.lock index 9f7c288..0222250 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3541,6 +3541,7 @@ dependencies = [ "flate2", "futures", "futures-util", + "getrandom 0.2.17", "google-cloud-auth 0.17.2", "google-cloud-secretmanager-v1", "google-cloud-token", @@ -3550,6 +3551,7 @@ dependencies = [ "http 1.4.0", "http-body-util", "ipnet", + "js-sys", "jsonschema", "jsonwebtoken", "kreuzberg", @@ -3573,6 +3575,7 @@ dependencies = [ "samael", "schemars 0.8.22", "serde", + "serde-wasm-bindgen", "serde_json", "serial_test", "sha2", @@ -3595,6 
+3598,7 @@ dependencies = [ "tracing", "tracing-opentelemetry", "tracing-subscriber", + "tracing-wasm", "unicode-normalization", "url", "utoipa", @@ -3602,6 +3606,10 @@ dependencies = [ "uuid", "validator", "vaultrs", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", + "web-sys", "wiremock", ] @@ -6784,9 +6792,9 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.13" +version = "0.11.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" dependencies = [ "aws-lc-rs", "bytes", @@ -7908,6 +7916,17 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-wasm-bindgen" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8302e169f0eddcc139c70f139d19d6467353af16f9fce27e8c30158036a1e16b" +dependencies = [ + "js-sys", + "serde", + "wasm-bindgen", +] + [[package]] name = "serde_core" version = "1.0.228" @@ -9480,6 +9499,17 @@ dependencies = [ "tracing-serde", ] +[[package]] +name = "tracing-wasm" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4575c663a174420fa2d78f4108ff68f65bf2fbb7dd89f33749b6e826b3626e07" +dependencies = [ + "tracing", + "tracing-subscriber", + "wasm-bindgen", +] + [[package]] name = "try-lock" version = "0.2.5" diff --git a/Cargo.toml b/Cargo.toml index 4225448..28dc294 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,11 +20,62 @@ include = [ "README.md", ] +[lib] +name = "hadrian" +path = "src/lib.rs" +crate-type = ["cdylib", "rlib"] + [features] default = ["full"] +# ───────────────────────────────────────────────────────────────────────────── +# Execution environment features +# ───────────────────────────────────────────────────────────────────────────── + +# CLI argument parsing (clap) +cli = ["dep:clap"] + +# Native server: socket binding, 
filesystem serving, config file loading +server = [ + "cli", + "native-http", + "native-async", + "concurrency", + "jwt", + "dep:toml", + "dep:tracing-subscriber", + "axum/tokio", + "axum/http1", + "axum/http2", + "axum/ws", + "axum/multipart", + "tower-http/fs", + "tokio/full", + "tokio-util/rt", +] + +# Native HTTP client features (TLS, HTTP/2) +native-http = [ + "reqwest/rustls-tls", + "reqwest/http2", + "reqwest/charset", + "reqwest/macos-system-configuration", +] + +# Native async runtime features (filesystem, networking, signals) +native-async = ["tokio/net", "tokio/fs", "tokio/signal", "tokio/process"] + +# High-performance concurrent data structures (parking_lot, crossbeam) +concurrency = ["dep:parking_lot", "dep:crossbeam-channel"] + +# JWT authentication (jsonwebtoken requires native crypto) +jwt = ["dep:jsonwebtoken"] + +# ───────────────────────────────────────────────────────────────────────────── # Meta profiles for different deployment scenarios -tiny = ["provider-openai", "provider-test"] +# ───────────────────────────────────────────────────────────────────────────── + +tiny = ["server", "provider-openai", "provider-test"] minimal = [ "tiny", "database-sqlite", @@ -68,6 +119,7 @@ full = [ # Suitable for `cargo install` users who don't have build artifacts, # and for deployments that serve the frontend separately. 
headless = [ + "server", "cel", "csv-export", "database-postgres", @@ -98,6 +150,28 @@ headless = [ "wizard", ] +# ───────────────────────────────────────────────────────────────────────────── +# WASM browser build — runs entirely in the browser via service worker +# ───────────────────────────────────────────────────────────────────────────── + +wasm = [ + "provider-openai", + "provider-anthropic", + "provider-test", + "database-wasm-sqlite", + "dep:wasm-bindgen", + "dep:wasm-bindgen-futures", + "dep:js-sys", + "dep:web-sys", + "dep:serde-wasm-bindgen", + "dep:wasm-streams", + "dep:tracing-wasm", +] + +# ───────────────────────────────────────────────────────────────────────────── +# Component features +# ───────────────────────────────────────────────────────────────────────────── + # Providers (always-available: openai, anthropic, test) provider-openai = [] provider-anthropic = [] @@ -139,6 +213,11 @@ embed-catalog = ["dep:rust-embed"] # Databases database-sqlite = ["dep:sqlx", "sqlx/sqlite"] database-postgres = ["dep:sqlx", "sqlx/postgres"] +database-wasm-sqlite = [ + "dep:wasm-bindgen", + "dep:js-sys", + "dep:serde-wasm-bindgen", +] # Authorization cel = ["dep:cel-interpreter"] @@ -175,15 +254,17 @@ otlp = [ virus-scan = ["dep:clamav-client"] [dependencies] -# Mandatory dependencies +# ───────────────────────────────────────────────────────────────────────────── +# Always-required dependencies (work on both native and wasm32) +# ───────────────────────────────────────────────────────────────────────────── async-trait = "0.1.89" -axum = { version = "0.8.7", features = ["ws", "multipart"] } +axum = { version = "0.8.7", default-features = false, features = [ + "json", "matched-path", "original-uri", "query", "form", "tracing", +] } axum-valid = "0.24.0" base64 = "0.22" bytes = "1.11.0" chrono = { version = "0.4.39", features = ["serde"] } -clap = { version = "4.5.53", features = ["derive"] } -crossbeam-channel = "0.5" dashmap = "6.0" futures = "0.3.31" 
futures-util = "0.3.31" @@ -191,32 +272,61 @@ hex = "0.4" http = "1.3.1" http-body-util = "0.1.3" ipnet = "2" -jsonwebtoken = { version = "9", features = ["use_pem"] } once_cell = "1.21" -parking_lot = "0.12.5" rand = "0.8" regex = "1.12.2" -reqwest = { version = "0.12.24", default-features = false, features = ["json", "stream", "rustls-tls", "http2", "charset", "macos-system-configuration", "multipart"] } +reqwest = { version = "0.12.24", default-features = false, features = [ + "json", "stream", "multipart", +] } rust_decimal = { version = "1.40.0", features = ["macros"] } serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.145" sha2 = "0.10" subtle = "2.6.1" thiserror = "2.0.17" -tokio = { version = "1.48.0", features = ["full"] } -tokio-util = { version = "0.7.17", features = ["rt"] } -toml = "0.9.8" +tokio = { version = "1.48.0", features = [ + "rt", "macros", "sync", "time", "io-util", +] } +tokio-util = { version = "0.7.17" } tower = "0.5.2" tower-cookies = "0.11" -tower-http = { version = "0.6", features = ["cors", "trace", "request-id", "propagate-header", "fs", "set-header", "limit"] } +tower-http = { version = "0.6", features = [ + "cors", "trace", "request-id", "propagate-header", "set-header", "limit", +] } tracing = "0.1.41" -tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] } unicode-normalization = "0.1" url = "2" uuid = { version = "1.18.1", features = ["v4", "v5", "serde"] } validator = { version = "0.20.0", features = ["derive"] } -# Optional dependencies +# ───────────────────────────────────────────────────────────────────────────── +# Optional: native-only dependencies (made optional for WASM compatibility) +# ───────────────────────────────────────────────────────────────────────────── +clap = { version = "4.5.53", features = ["derive"], optional = true } +crossbeam-channel = { version = "0.5", optional = true } +jsonwebtoken = { version = "9", features = ["use_pem"], optional = true } 
+parking_lot = { version = "0.12.5", optional = true } +toml = { version = "0.9.8", optional = true } +tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"], optional = true } + +# ───────────────────────────────────────────────────────────────────────────── +# Optional: WASM-specific dependencies +# ───────────────────────────────────────────────────────────────────────────── +js-sys = { version = "0.3", optional = true } +serde-wasm-bindgen = { version = "0.6", optional = true } +tracing-wasm = { version = "0.2", optional = true } +wasm-bindgen = { version = "0.2", optional = true } +wasm-bindgen-futures = { version = "0.4", optional = true } +wasm-streams = { version = "0.4", optional = true } +web-sys = { version = "0.3", features = [ + "Headers", "Request", "RequestInit", "Response", "ResponseInit", + "ReadableStream", "ServiceWorkerGlobalScope", "FetchEvent", + "ExtendableEvent", "Url", +], optional = true } + +# ───────────────────────────────────────────────────────────────────────────── +# Optional: feature-gated dependencies +# ───────────────────────────────────────────────────────────────────────────── augurs = { version = "0.10.1", features = ["ets", "mstl", "forecaster"], optional = true } aws-config = { version = "1", features = ["behavior-version-latest"], optional = true } aws-credential-types = { version = "1", features = ["hardcoded-credentials"], optional = true } @@ -262,6 +372,13 @@ utoipa = { version = "5", features = ["chrono", "uuid", "axum_extras"], optional utoipa-scalar = { version = "0.3", features = ["axum"], optional = true } vaultrs = { version = "0.7.4", features = ["rustls"], optional = true } +# ───────────────────────────────────────────────────────────────────────────── +# Target-specific: WASM needs JS-backed getrandom for uuid/rand +# ───────────────────────────────────────────────────────────────────────────── +[target.'cfg(target_arch = "wasm32")'.dependencies] +getrandom = { version = "0.2", features 
= ["js"] } +uuid = { version = "1.18.1", features = ["v4", "v5", "serde", "js"] } + [dev-dependencies] rstest = "0.24" serial_test = "3.2" diff --git a/Dockerfile b/Dockerfile index f53f6c7..27d4d68 100644 --- a/Dockerfile +++ b/Dockerfile @@ -71,7 +71,8 @@ COPY Cargo.toml Cargo.lock ./ # Create dummy src to build dependencies RUN mkdir -p src/bin \ && echo "fn main() {}" > src/main.rs \ - && echo "fn main() {}" > src/bin/record_fixtures.rs + && echo "fn main() {}" > src/bin/record_fixtures.rs \ + && echo "" > src/lib.rs # Build dependencies only (cached layer) RUN --mount=type=cache,target=/usr/local/cargo/registry \ diff --git a/docs/app/(home)/page.tsx b/docs/app/(home)/page.tsx index 2f5a5bd..609826c 100644 --- a/docs/app/(home)/page.tsx +++ b/docs/app/(home)/page.tsx @@ -243,11 +243,19 @@ export default function HomePage() { MIT and Apache-2.0 licensed. No proprietary code, no upgrade tiers, no restrictions.

- + Try in Browser + + Get Started Building Hadrian WASM module (profile: $PROFILE)" + +# Ensure wasm32 target is installed +if ! rustup target list --installed | grep -q wasm32-unknown-unknown; then + echo "==> Installing wasm32-unknown-unknown target" + rustup target add wasm32-unknown-unknown +fi + +# Build with wasm-pack +# --dev skips wasm-opt (avoids bulk-memory feature mismatch) +# --release runs wasm-opt for size optimization +cd "$ROOT_DIR" +wasm-pack build \ + --target web \ + --out-dir "$OUT_DIR" \ + $WASM_PACK_FLAGS \ + -- \ + --no-default-features \ + --features wasm + +# Copy sql.js WASM binary alongside the Hadrian WASM output. +# The sqlite-bridge.ts service worker code loads it from /wasm/sql-wasm.wasm. +SQLJS_WASM="$ROOT_DIR/ui/node_modules/sql.js/dist/sql-wasm.wasm" +if [ -f "$SQLJS_WASM" ]; then + cp "$SQLJS_WASM" "$OUT_DIR/sql-wasm.wasm" + echo "==> Copied sql-wasm.wasm to $OUT_DIR" +else + echo "WARNING: sql-wasm.wasm not found at $SQLJS_WASM — run 'pnpm install' in ui/ first" +fi + +echo "==> WASM build complete: $OUT_DIR" +echo " Files:" +ls -lh "$OUT_DIR"/*.wasm "$OUT_DIR"/*.js 2>/dev/null || true diff --git a/src/app.rs b/src/app.rs index 1a76833..dae0bca 100644 --- a/src/app.rs +++ b/src/app.rs @@ -6,22 +6,25 @@ use axum::Json; use axum::response::Response; #[cfg(any(feature = "sso", feature = "saml"))] use axum::routing::post; +#[cfg(feature = "server")] use axum::{Router, routing::get}; #[cfg(any(feature = "embed-ui", feature = "embed-docs"))] use axum::{body::Body, response::IntoResponse}; #[cfg(any(feature = "embed-ui", feature = "embed-docs"))] use http::StatusCode; +#[cfg(any(feature = "server", feature = "embed-ui", feature = "embed-docs"))] use http::header; use reqwest::Client; #[cfg(any(feature = "embed-ui", feature = "embed-docs"))] use rust_embed::Embed; +#[cfg(feature = "server")] use tokio_util::task::TaskTracker; -use tower_http::{ - limit::RequestBodyLimitLayer, - services::{ServeDir, ServeFile}, - 
set_header::SetResponseHeaderLayer, - trace::TraceLayer, -}; +#[cfg(feature = "server")] +use tower_http::services::{ServeDir, ServeFile}; +#[cfg(feature = "server")] +use tower_http::set_header::SetResponseHeaderLayer; +#[cfg(feature = "server")] +use tower_http::{limit::RequestBodyLimitLayer, trace::TraceLayer}; #[cfg(feature = "utoipa")] use utoipa_scalar::{Scalar, Servable}; @@ -31,9 +34,11 @@ use crate::observability; use crate::openapi; use crate::{ auth, authz, cache, catalog, config, db, dlq, events, guardrails, - init::create_provider_instance, jobs, middleware, models, pricing, providers, routes, secrets, - services, usage_buffer, + init::create_provider_instance, jobs, models, pricing, providers, secrets, services, + usage_buffer, }; +#[cfg(feature = "server")] +use crate::{middleware, routes}; /// Embedded UI assets from ui/dist directory. /// These are compiled into the binary at build time. @@ -89,12 +94,14 @@ fn serve_embedded_file(path: &str) -> Response { } /// Add routes for serving static UI files +#[cfg(feature = "server")] fn add_ui_routes(app: Router, config: &config::GatewayConfig) -> Router { use config::AssetSource; let ui_path = config.ui.path.trim_end_matches('/'); match &config.ui.assets.source { + #[cfg(feature = "server")] AssetSource::Filesystem { path } => { let assets_path = std::path::Path::new(path); let index_file = assets_path.join("index.html"); @@ -128,6 +135,13 @@ fn add_ui_routes(app: Router, config: &config::GatewayConfig) -> Route app.nest_service(ui_path, serve_dir_with_headers) } } + #[cfg(not(feature = "server"))] + AssetSource::Filesystem { .. 
} => { + tracing::warn!( + "Filesystem UI assets requested but 'server' feature is not enabled, skipping" + ); + app + } #[cfg(feature = "embed-ui")] AssetSource::Embedded => { tracing::info!(ui_path = %ui_path, "Serving UI from embedded assets"); @@ -218,12 +232,14 @@ fn build_docs_response(content: rust_embed::EmbeddedFile) -> Response { } /// Add routes for serving static documentation files +#[cfg(feature = "server")] fn add_docs_routes(app: Router, config: &config::GatewayConfig) -> Router { use config::AssetSource; let docs_path = config.docs.path.trim_end_matches('/'); match &config.docs.assets.source { + #[cfg(feature = "server")] AssetSource::Filesystem { path } => { let assets_path = std::path::Path::new(path); @@ -251,6 +267,13 @@ fn add_docs_routes(app: Router, config: &config::GatewayConfig) -> Rou // Docs are always at a specific path (never root) app.nest_service(docs_path, serve_dir_with_headers) } + #[cfg(not(feature = "server"))] + AssetSource::Filesystem { .. } => { + tracing::warn!( + "Filesystem docs assets requested but 'server' feature is not enabled, skipping" + ); + app + } #[cfg(feature = "embed-docs")] AssetSource::Embedded => { tracing::info!(docs_path = %docs_path, "Serving documentation from embedded assets"); @@ -295,6 +318,7 @@ pub struct AppState { pub provider_health: jobs::ProviderHealthStateRegistry, /// Task tracker for background tasks (usage logging, etc.) /// Ensures all spawned tasks complete during graceful shutdown. + #[cfg(feature = "server")] pub task_tracker: TaskTracker, /// Registry of per-organization OIDC authenticators. /// Loaded from org_sso_configs table at startup for multi-tenant SSO. @@ -306,12 +330,14 @@ pub struct AppState { pub saml_registry: Option>, /// Registry of per-org gateway JWT validators. /// Routes incoming JWTs to the correct org-scoped validator by issuer. + #[cfg(feature = "jwt")] pub gateway_jwt_registry: Option>, /// Registry of per-organization RBAC policies. 
/// Loaded from org_rbac_policies table at startup for per-org authorization. pub policy_registry: Option>, /// Async buffer for usage log entries. /// Batches writes to reduce database pressure. + #[cfg(feature = "concurrency")] pub usage_buffer: Option>, /// Response cache for chat completions. /// Caches deterministic responses to reduce latency and costs. @@ -753,6 +779,7 @@ impl AppState { // Initialize per-org gateway JWT registry for multi-tenant JWT auth on /v1/*. // Validators are pre-loaded in a background task so server startup isn't // blocked by N sequential OIDC discovery HTTP requests. + #[cfg(feature = "jwt")] let gateway_jwt_registry = if db.is_some() { Some(Arc::new(auth::GatewayJwtRegistry::new())) } else { @@ -831,6 +858,7 @@ impl AppState { }; // Initialize usage log buffer with configured buffer settings and EventBus + #[cfg(feature = "concurrency")] let usage_buffer = { let buffer_config = usage_buffer::UsageBufferConfig::from(&config.observability.usage.buffer); @@ -866,9 +894,11 @@ impl AppState { }; // Create the task tracker for background tasks + #[cfg(feature = "server")] let task_tracker = TaskTracker::new(); // Initialize semantic cache if configured + #[cfg(feature = "server")] let semantic_cache = Self::init_semantic_cache( &config, cache.as_ref(), @@ -878,6 +908,8 @@ impl AppState { &task_tracker, ) .await; + #[cfg(not(feature = "server"))] + let semantic_cache: Option> = None; // Initialize input guardrails if configured let input_guardrails = match &config.features.guardrails { @@ -1038,13 +1070,16 @@ impl AppState { pricing, circuit_breakers, provider_health: jobs::ProviderHealthStateRegistry::new(), + #[cfg(feature = "server")] task_tracker, #[cfg(feature = "sso")] oidc_registry, #[cfg(feature = "saml")] saml_registry, + #[cfg(feature = "jwt")] gateway_jwt_registry, policy_registry, + #[cfg(feature = "concurrency")] usage_buffer, response_cache, semantic_cache, @@ -1067,7 +1102,7 @@ impl AppState { /// Ensure a default user 
exists for anonymous access when auth is disabled. /// Uses a well-known external_id so the same user is used across restarts. /// Race-safe: tries to create first, falls back to lookup on conflict. - async fn ensure_default_user( + pub(crate) async fn ensure_default_user( services: &services::Services, ) -> Result> { use crate::db::DbError; @@ -1099,7 +1134,7 @@ impl AppState { /// Ensure a default organization exists for anonymous access when auth is disabled. /// Uses a well-known slug so the same organization is used across restarts. /// Race-safe: tries to create first, falls back to lookup on conflict. - async fn ensure_default_org( + pub(crate) async fn ensure_default_org( services: &services::Services, ) -> Result> { use crate::db::DbError; @@ -1128,7 +1163,7 @@ impl AppState { } /// Ensure the default user is a member of the default organization. - async fn ensure_default_org_membership( + pub(crate) async fn ensure_default_org_membership( services: &services::Services, user_id: uuid::Uuid, org_id: uuid::Uuid, @@ -1161,6 +1196,7 @@ impl AppState { /// Initialize semantic cache if configured. /// /// Spawns the background embedding worker on the provided task tracker. 
+ #[cfg(feature = "server")] async fn init_semantic_cache( config: &config::GatewayConfig, cache: Option<&Arc>, @@ -1782,6 +1818,7 @@ impl AppState { } } +#[cfg(feature = "server")] pub fn build_app(config: &config::GatewayConfig, state: AppState) -> Router { let mut app = Router::new() // Health check endpoint @@ -1960,6 +1997,7 @@ pub fn build_app(config: &config::GatewayConfig, state: AppState) -> Router { } // Add WebSocket route for real-time event subscriptions if enabled + #[cfg(feature = "server")] if config.features.websocket.enabled { app = app.route("/ws/events", get(routes::ws_handler)); tracing::info!("WebSocket event subscriptions enabled at /ws/events"); diff --git a/src/auth/gateway_jwt.rs b/src/auth/gateway_jwt.rs index fd659e6..8170cf3 100644 --- a/src/auth/gateway_jwt.rs +++ b/src/auth/gateway_jwt.rs @@ -44,6 +44,12 @@ pub struct GatewayJwtRegistry { load_mutex: Mutex<()>, } +impl Default for GatewayJwtRegistry { + fn default() -> Self { + Self::new() + } +} + impl GatewayJwtRegistry { /// Create an empty registry. pub fn new() -> Self { @@ -237,6 +243,11 @@ impl GatewayJwtRegistry { pub async fn len(&self) -> usize { self.inner.read().await.validators.len() } + + /// Whether the registry has no validators. + pub async fn is_empty(&self) -> bool { + self.inner.read().await.validators.is_empty() + } } /// Clean up issuer index entries for a given org_id. 
Operates on `&mut RegistryInner` diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 41fd7d1..2d48d05 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -1,8 +1,10 @@ #[cfg(feature = "sso")] mod discovery; mod error; +#[cfg(feature = "jwt")] pub mod gateway_jwt; mod identity; +#[cfg(feature = "jwt")] pub mod jwt; #[cfg(feature = "sso")] pub mod oidc; @@ -19,6 +21,7 @@ pub mod session_store; #[cfg(feature = "sso")] pub use discovery::fetch_jwks_uri; pub use error::AuthError; +#[cfg(feature = "jwt")] pub use gateway_jwt::GatewayJwtRegistry; pub use identity::{ApiKeyAuth, AuthenticatedRequest, Identity, IdentityKind}; #[cfg(feature = "sso")] diff --git a/src/auth/session_store.rs b/src/auth/session_store.rs index 9c5b5c4..c432e71 100644 --- a/src/auth/session_store.rs +++ b/src/auth/session_store.rs @@ -242,7 +242,8 @@ impl AuthorizationState { /// Trait for OIDC session storage. /// /// Implementations must be thread-safe and handle concurrent access. -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait SessionStore: Send + Sync { /// Store a new session. 
async fn create_session(&self, session: OidcSession) -> SessionResult; @@ -319,7 +320,8 @@ impl Default for MemorySessionStore { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl SessionStore for MemorySessionStore { async fn create_session(&self, session: OidcSession) -> SessionResult { let id = session.id; @@ -462,7 +464,8 @@ impl CacheSessionStore { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl SessionStore for CacheSessionStore { async fn create_session(&self, session: OidcSession) -> SessionResult { let id = session.id; diff --git a/src/authz/engine.rs b/src/authz/engine.rs index 797b6f5..92baf98 100644 --- a/src/authz/engine.rs +++ b/src/authz/engine.rs @@ -311,12 +311,7 @@ pub struct TimeContext { impl TimeContext { /// Create a new TimeContext with the current time. pub fn now() -> Self { - use std::time::{SystemTime, UNIX_EPOCH}; - - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_default(); - let timestamp = now.as_secs() as i64; + let timestamp = Self::current_timestamp(); // Calculate hour and day_of_week from timestamp // This is a simplified calculation - in production you might want chrono @@ -345,6 +340,24 @@ impl TimeContext { timestamp, } } + + /// Current Unix timestamp in seconds. + /// + /// On native uses `SystemTime`; on wasm32 uses `js_sys::Date` since + /// `SystemTime::now()` panics in the browser. 
+ #[cfg(not(target_arch = "wasm32"))] + fn current_timestamp() -> i64 { + use std::time::{SystemTime, UNIX_EPOCH}; + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as i64 + } + + #[cfg(target_arch = "wasm32")] + fn current_timestamp() -> i64 { + (js_sys::Date::now() / 1000.0) as i64 + } } impl Default for TimeContext { diff --git a/src/cache/memory.rs b/src/cache/memory.rs index c4e8478..cadf4b1 100644 --- a/src/cache/memory.rs +++ b/src/cache/memory.rs @@ -144,7 +144,8 @@ impl MemoryCache { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl Cache for MemoryCache { async fn get_bytes(&self, key: &str) -> CacheResult>> { if let Some(mut entry) = self.data.get_mut(key) { diff --git a/src/cache/redis.rs b/src/cache/redis.rs index 4103703..40d57de 100644 --- a/src/cache/redis.rs +++ b/src/cache/redis.rs @@ -558,7 +558,8 @@ impl RedisCache { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl Cache for RedisCache { async fn get_bytes(&self, key: &str) -> CacheResult>> { let mut conn = self.get_connection().await?; diff --git a/src/cache/semantic_cache.rs b/src/cache/semantic_cache.rs index 37446fb..7838583 100644 --- a/src/cache/semantic_cache.rs +++ b/src/cache/semantic_cache.rs @@ -128,6 +128,7 @@ impl SemanticCache { /// /// # Returns /// A tuple of (SemanticCache, background task handle) + #[cfg(not(target_arch = "wasm32"))] pub fn new( cache: Arc, vector_store: Arc, diff --git a/src/cache/traits.rs b/src/cache/traits.rs index adf0f07..86dfc44 100644 --- a/src/cache/traits.rs +++ b/src/cache/traits.rs @@ -61,7 +61,8 @@ pub struct BatchLimitResult { pub rate_limit_results: Vec, } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait Cache: Send + Sync { /// Get raw bytes from 
cache async fn get_bytes(&self, key: &str) -> CacheResult>>; @@ -165,6 +166,7 @@ pub trait Cache: Send + Sync { } // Helper extension trait for working with JSON +#[allow(async_fn_in_trait)] pub trait CacheExt: Cache { async fn get_json(&self, key: &str) -> CacheResult> { use super::error::CacheError; diff --git a/src/cache/vector_store/mod.rs b/src/cache/vector_store/mod.rs index af3d32d..3da3846 100644 --- a/src/cache/vector_store/mod.rs +++ b/src/cache/vector_store/mod.rs @@ -252,7 +252,8 @@ pub struct ChunkFilter { /// store.delete("unique-id").await?; /// } /// ``` -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait VectorBackend: Send + Sync { /// Store an embedding with associated metadata. /// diff --git a/src/cache/vector_store/pgvector.rs b/src/cache/vector_store/pgvector.rs index 9b54bcf..80e3574 100644 --- a/src/cache/vector_store/pgvector.rs +++ b/src/cache/vector_store/pgvector.rs @@ -568,7 +568,8 @@ fn count_filter_binds(filter: &AttributeFilter) -> usize { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl VectorBackend for PgvectorStore { #[instrument( skip(self, embedding, metadata), diff --git a/src/cache/vector_store/qdrant.rs b/src/cache/vector_store/qdrant.rs index 49db8e5..ca00b90 100644 --- a/src/cache/vector_store/qdrant.rs +++ b/src/cache/vector_store/qdrant.rs @@ -59,10 +59,10 @@ impl QdrantStore { dimensions: usize, distance_metric: DistanceMetric, ) -> Self { - let client = Client::builder() - .timeout(Duration::from_secs(30)) - .build() - .expect("Failed to create HTTP client"); + let builder = Client::builder(); + #[cfg(not(target_arch = "wasm32"))] + let builder = builder.timeout(Duration::from_secs(30)); + let client = builder.build().expect("Failed to create HTTP client"); // Remove trailing slash from base_url let base_url = 
base_url.trim_end_matches('/').to_string(); @@ -668,7 +668,8 @@ struct ChunkSearchResultData { payload: Option>, } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl VectorBackend for QdrantStore { #[instrument( skip(self, embedding, metadata), diff --git a/src/cache/vector_store/test.rs b/src/cache/vector_store/test.rs index 0207956..6361947 100644 --- a/src/cache/vector_store/test.rs +++ b/src/cache/vector_store/test.rs @@ -35,7 +35,8 @@ impl TestVectorStore { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl VectorBackend for TestVectorStore { async fn store( &self, @@ -228,7 +229,8 @@ impl MockableTestVectorStore { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl VectorBackend for MockableTestVectorStore { async fn store( &self, diff --git a/src/catalog/registry.rs b/src/catalog/registry.rs index d49fd60..5adeb16 100644 --- a/src/catalog/registry.rs +++ b/src/catalog/registry.rs @@ -5,11 +5,10 @@ use std::{collections::HashMap, sync::Arc}; -use parking_lot::RwLock; use serde::{Deserialize, Serialize}; use super::types::{CatalogCost, CatalogModel, ModelCatalog}; -use crate::pricing::ModelPricing; +use crate::{compat::RwLock, pricing::ModelPricing}; /// Model capabilities extracted from the catalog. 
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)] diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 944010f..dc2b4d4 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -17,7 +17,7 @@ use clap::Parser; /// CLI arguments for Hadrian Gateway #[derive(Parser, Debug)] #[command(version, about = "Hadrian AI Gateway", long_about = None)] -pub(crate) struct Args { +pub struct Args { #[command(subcommand)] command: Option, @@ -102,7 +102,7 @@ enum Command { } /// Dispatch to the appropriate subcommand handler. -pub(crate) async fn dispatch(args: Args) { +pub async fn dispatch(args: Args) { match args.command { Some(Command::Openapi { output }) => { #[cfg(feature = "utoipa")] @@ -220,19 +220,19 @@ pub(crate) fn default_config_dir() -> Option { } /// Get the default config file path. -pub(crate) fn default_config_path() -> Option { +pub fn default_config_path() -> Option { default_config_dir().map(|p| p.join("hadrian.toml")) } /// Get the default data directory path. #[cfg(feature = "wizard")] -pub(crate) fn default_data_dir() -> Option { +pub fn default_data_dir() -> Option { dirs::data_dir().map(|p| p.join("hadrian")) } /// Get the default data directory path. #[cfg(not(feature = "wizard"))] -pub(crate) fn default_data_dir() -> Option { +pub fn default_data_dir() -> Option { None } diff --git a/src/compat.rs b/src/compat.rs new file mode 100644 index 0000000..057d378 --- /dev/null +++ b/src/compat.rs @@ -0,0 +1,321 @@ +//! Compatibility layer for concurrency primitives and WASM routing. +//! +//! On native builds with the `concurrency` feature, this re-exports high-performance +//! types from `parking_lot` and `dashmap`. On WASM or builds without `concurrency`, +//! it provides std-based fallbacks that are safe on single-threaded runtimes. +//! +//! ## Async trait Send bounds +//! +//! WASM is single-threaded, so `Send` bounds on async trait futures are unnecessary +//! and impossible to satisfy (reqwest/wasm-bindgen futures are `!Send`). +//! +//! 
All `#[async_trait]` usages are replaced with conditional `cfg_attr`: +//! ```ignore +//! #[cfg_attr(not(target_arch = "wasm32"), async_trait)] +//! #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +//! ``` +//! This is behavior-identical on native, and uses `?Send` on wasm32. +//! +//! ## WASM routing (`AssertSend` / `WasmHandler`) +//! +//! Axum requires handler futures to be `Send`, but on wasm32 `reqwest`/`wasm-bindgen` +//! futures are `!Send`. Since WASM is single-threaded, `Send` is vacuously satisfied. +//! +//! [`AssertSend`] wraps any future with `unsafe impl Send`, and [`WasmHandler`] +//! wraps handler functions so they produce `AssertSend` futures. The [`wasm_routing`] +//! module provides drop-in replacements for `axum::routing::{get, post, ...}` that +//! automatically wrap handlers in `WasmHandler`. + +// ───────────────────────────────────────────────────────────────────────────── +// Spawn (fire-and-forget) +// ───────────────────────────────────────────────────────────────────────────── + +/// Spawn a fire-and-forget async task. +/// +/// On native builds this delegates to `tokio::spawn` (requires `Send`). +/// On WASM builds this uses `wasm_bindgen_futures::spawn_local` (no `Send` required). +#[cfg(not(target_arch = "wasm32"))] +pub fn spawn_detached(future: F) +where + F: std::future::Future + Send + 'static, +{ + tokio::spawn(future); +} + +/// Spawn a fire-and-forget async task (WASM version, no `Send` bound). +#[cfg(target_arch = "wasm32")] +pub fn spawn_detached(future: F) +where + F: std::future::Future + 'static, +{ + wasm_bindgen_futures::spawn_local(future); +} + +// ───────────────────────────────────────────────────────────────────────────── +// AssertSend / WasmHandler (wasm32 only) +// ───────────────────────────────────────────────────────────────────────────── + +/// A future wrapper that asserts `Send` for `!Send` futures on wasm32. 
+/// +/// # Safety +/// +/// WASM is single-threaded, so `Send` is vacuously satisfied — there is no other +/// thread that could observe the wrapped future. +#[cfg(target_arch = "wasm32")] +pub struct AssertSend(pub F); + +#[cfg(target_arch = "wasm32")] +// SAFETY: wasm32 is single-threaded; Send is vacuously satisfied. +unsafe impl Send for AssertSend {} + +#[cfg(target_arch = "wasm32")] +impl std::future::Future for AssertSend { + type Output = F::Output; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + // SAFETY: We only project through the newtype; pinning is preserved. + unsafe { self.map_unchecked_mut(|s| &mut s.0) }.poll(cx) + } +} + +/// Wraps any handler function so its return future is `Send` (via [`AssertSend`]). +/// +/// This is a newtype used by the [`wasm_routing`] module's drop-in routing functions. +#[cfg(target_arch = "wasm32")] +pub struct WasmHandler(pub H); + +#[cfg(target_arch = "wasm32")] +impl Clone for WasmHandler { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +// SAFETY: wasm32 is single-threaded; Sync is vacuously satisfied. +#[cfg(target_arch = "wasm32")] +unsafe impl Sync for WasmHandler {} + +/// Implement `Handler<(T1..Tn, M), S>` for `WasmHandler`. +/// +/// Mirrors axum's own `Handler` impl but removes the `Send` bound on `Fut` +/// and wraps the output in [`AssertSend`]. +#[cfg(target_arch = "wasm32")] +macro_rules! 
impl_wasm_handler { + ( [$($ty:ident),*], $last:ident ) => { + #[allow(non_snake_case)] + impl axum::handler::Handler<(M, $($ty,)* $last,), S> + for WasmHandler + where + F: FnOnce($($ty,)* $last,) -> Fut + Clone + Send + 'static, + Fut: std::future::Future + 'static, // no Send bound + Res: axum::response::IntoResponse, + S: Send + Sync + Clone + 'static, + $( $ty: axum::extract::FromRequestParts + Send, )* + $last: axum::extract::FromRequest + Send, + { + type Future = AssertSend + 'static>>>; + + fn call( + self, + req: axum::http::Request, + state: S, + ) -> Self::Future { + let (mut parts, body) = req.into_parts(); + AssertSend(Box::pin(async move { + use axum::response::IntoResponse as _; + $( + let $ty = match $ty::from_request_parts(&mut parts, &state).await { + Ok(value) => value, + Err(rejection) => return rejection.into_response(), + }; + )* + let req = axum::http::Request::from_parts(parts, body); + let $last = match $last::from_request(req, &state).await { + Ok(value) => value, + Err(rejection) => return rejection.into_response(), + }; + (self.0)($($ty,)* $last,).await.into_response() + })) + } + } + }; +} + +// Zero-argument handler impl (mirrors axum's `FnOnce() -> Fut` impl). +#[cfg(target_arch = "wasm32")] +impl axum::handler::Handler<((),), S> for WasmHandler +where + F: FnOnce() -> Fut + Clone + Send + 'static, + Fut: std::future::Future + 'static, + Res: axum::response::IntoResponse, +{ + type Future = AssertSend< + std::pin::Pin + 'static>>, + >; + + fn call(self, _req: axum::http::Request, _state: S) -> Self::Future { + AssertSend(Box::pin(async move { + axum::response::IntoResponse::into_response(self.0().await) + })) + } +} + +// Implement for arities 1..16 (matching axum's supported handler arities). 
+#[cfg(target_arch = "wasm32")] +impl_wasm_handler!([], T1); +#[cfg(target_arch = "wasm32")] +impl_wasm_handler!([T1], T2); +#[cfg(target_arch = "wasm32")] +impl_wasm_handler!([T1, T2], T3); +#[cfg(target_arch = "wasm32")] +impl_wasm_handler!([T1, T2, T3], T4); +#[cfg(target_arch = "wasm32")] +impl_wasm_handler!([T1, T2, T3, T4], T5); +#[cfg(target_arch = "wasm32")] +impl_wasm_handler!([T1, T2, T3, T4, T5], T6); +#[cfg(target_arch = "wasm32")] +impl_wasm_handler!([T1, T2, T3, T4, T5, T6], T7); +#[cfg(target_arch = "wasm32")] +impl_wasm_handler!([T1, T2, T3, T4, T5, T6, T7], T8); +#[cfg(target_arch = "wasm32")] +impl_wasm_handler!([T1, T2, T3, T4, T5, T6, T7, T8], T9); +#[cfg(target_arch = "wasm32")] +impl_wasm_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9], T10); +#[cfg(target_arch = "wasm32")] +impl_wasm_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10], T11); +#[cfg(target_arch = "wasm32")] +impl_wasm_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11], T12); +#[cfg(target_arch = "wasm32")] +impl_wasm_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12], T13); +#[cfg(target_arch = "wasm32")] +impl_wasm_handler!( + [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13], + T14 +); +#[cfg(target_arch = "wasm32")] +impl_wasm_handler!( + [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14], + T15 +); +#[cfg(target_arch = "wasm32")] +impl_wasm_handler!( + [ + T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15 + ], + T16 +); + +/// Drop-in replacements for `axum::routing::{get, post, put, patch, delete}` that +/// automatically wrap handlers in [`WasmHandler`] so `!Send` futures compile on wasm32. 
+/// +/// Usage in route modules: +/// ```ignore +/// #[cfg(feature = "server")] +/// use axum::routing::{get, post}; +/// #[cfg(feature = "wasm")] +/// use crate::compat::wasm_routing::{get, post}; +/// ``` +#[cfg(target_arch = "wasm32")] +pub mod wasm_routing { + use axum::{handler::Handler, routing::MethodRouter}; + + use super::WasmHandler; + + pub fn get(handler: H) -> MethodRouter + where + WasmHandler: Handler, + T: 'static, + S: Clone + Send + Sync + 'static, + { + axum::routing::get(WasmHandler(handler)) + } + + pub fn post(handler: H) -> MethodRouter + where + WasmHandler: Handler, + T: 'static, + S: Clone + Send + Sync + 'static, + { + axum::routing::post(WasmHandler(handler)) + } + + pub fn put(handler: H) -> MethodRouter + where + WasmHandler: Handler, + T: 'static, + S: Clone + Send + Sync + 'static, + { + axum::routing::put(WasmHandler(handler)) + } + + pub fn patch(handler: H) -> MethodRouter + where + WasmHandler: Handler, + T: 'static, + S: Clone + Send + Sync + 'static, + { + axum::routing::patch(WasmHandler(handler)) + } + + pub fn delete(handler: H) -> MethodRouter + where + WasmHandler: Handler, + T: 'static, + S: Clone + Send + Sync + 'static, + { + axum::routing::delete(WasmHandler(handler)) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Mutex +// ───────────────────────────────────────────────────────────────────────────── +#[cfg(feature = "concurrency")] +pub use parking_lot::Mutex; + +/// A Mutex wrapper around `std::sync::Mutex` that panics-on-poison (matching +/// `parking_lot::Mutex` semantics). Safe on single-threaded WASM. 
+#[cfg(not(feature = "concurrency"))] +#[derive(Debug)] +pub struct Mutex(std::sync::Mutex); + +#[cfg(not(feature = "concurrency"))] +impl Mutex { + pub fn new(value: T) -> Self { + Self(std::sync::Mutex::new(value)) + } + + pub fn lock(&self) -> std::sync::MutexGuard<'_, T> { + self.0.lock().expect("mutex poisoned") + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// RwLock +// ───────────────────────────────────────────────────────────────────────────── + +#[cfg(feature = "concurrency")] +pub use parking_lot::RwLock; + +/// An RwLock wrapper around `std::sync::RwLock` that panics-on-poison. +#[cfg(not(feature = "concurrency"))] +#[derive(Debug, Default)] +pub struct RwLock(std::sync::RwLock); + +#[cfg(not(feature = "concurrency"))] +impl RwLock { + pub fn new(value: T) -> Self { + Self(std::sync::RwLock::new(value)) + } + + pub fn read(&self) -> std::sync::RwLockReadGuard<'_, T> { + self.0.read().expect("rwlock poisoned") + } + + pub fn write(&self) -> std::sync::RwLockWriteGuard<'_, T> { + self.0.write().expect("rwlock poisoned") + } +} diff --git a/src/config/auth.rs b/src/config/auth.rs index d0fb65f..7bad9d1 100644 --- a/src/config/auth.rs +++ b/src/config/auth.rs @@ -742,6 +742,7 @@ pub enum JwtAlgorithm { EdDSA, } +#[cfg(feature = "jwt")] impl JwtAlgorithm { /// Convert to jsonwebtoken Algorithm. pub fn to_jwt_algorithm(self) -> jsonwebtoken::Algorithm { diff --git a/src/config/mod.rs b/src/config/mod.rs index eb82e65..8f3471b 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -29,6 +29,7 @@ mod server; mod storage; mod ui; +#[cfg(feature = "server")] use std::path::Path; pub use auth::*; @@ -121,15 +122,17 @@ impl GatewayConfig { /// /// Environment variables in the format `${VAR_NAME}` are expanded. /// Missing required variables will cause an error. 
+ #[cfg(feature = "server")] pub fn from_file(path: impl AsRef) -> Result { let contents = std::fs::read_to_string(path.as_ref()) .map_err(|e| ConfigError::Io(e, path.as_ref().to_path_buf()))?; - Self::from_str(&contents) + Self::parse(&contents) } /// Parse configuration from a TOML string. - pub fn from_str(contents: &str) -> Result { + #[cfg(feature = "server")] + pub fn parse(contents: &str) -> Result { // Expand environment variables let expanded = expand_env_vars(contents)?; @@ -148,6 +151,7 @@ impl GatewayConfig { } /// Validate the configuration for consistency and completeness. + #[cfg(feature = "server")] fn validate(&mut self) -> Result<(), ConfigError> { // If auth is enabled, we need a database if self.auth.is_auth_enabled() && self.database.is_none() { @@ -217,6 +221,7 @@ pub enum ConfigError { #[error("Failed to read config file {1}: {0}")] Io(std::io::Error, std::path::PathBuf), + #[cfg(feature = "server")] #[error("Failed to parse config: {0}")] Parse(#[from] toml::de::Error), @@ -233,6 +238,7 @@ pub enum ConfigError { /// not compiled into this binary, serde produces cryptic "unknown variant" errors. /// This function inspects the raw TOML to detect such cases and produce actionable /// error messages telling the user exactly which features to enable. 
+#[cfg(feature = "server")] fn check_disabled_features(raw: &toml::Value) -> Result<(), ConfigError> { let mut issues: Vec<(String, &str)> = Vec::new(); @@ -334,6 +340,7 @@ fn check_disabled_features(raw: &toml::Value) -> Result<(), ConfigError> { ))) } +#[cfg(feature = "server")] fn check_provider_feature(_name: &str, type_val: &str, _issues: &mut Vec<(String, &str)>) { match type_val { #[cfg(not(feature = "provider-bedrock"))] @@ -361,6 +368,7 @@ fn check_provider_feature(_name: &str, type_val: &str, _issues: &mut Vec<(String } } +#[cfg(feature = "server")] fn check_database_feature(type_val: &str, _issues: &mut Vec<(String, &str)>) { match type_val { #[cfg(not(feature = "database-sqlite"))] @@ -377,6 +385,7 @@ fn check_database_feature(type_val: &str, _issues: &mut Vec<(String, &str)>) { } } +#[cfg(feature = "server")] fn check_secrets_feature(type_val: &str, _issues: &mut Vec<(String, &str)>) { match type_val { #[cfg(not(feature = "vault"))] @@ -403,6 +412,7 @@ fn check_secrets_feature(type_val: &str, _issues: &mut Vec<(String, &str)>) { } } +#[cfg(feature = "server")] fn check_cache_feature(type_val: &str, _issues: &mut Vec<(String, &str)>) { match type_val { #[cfg(not(feature = "redis"))] @@ -414,6 +424,7 @@ fn check_cache_feature(type_val: &str, _issues: &mut Vec<(String, &str)>) { } } +#[cfg(feature = "server")] fn check_rbac_feature(_issues: &mut Vec<(String, &str)>) { #[cfg(not(feature = "cel"))] _issues.push(( @@ -422,6 +433,7 @@ fn check_rbac_feature(_issues: &mut Vec<(String, &str)>) { )); } +#[cfg(feature = "server")] fn check_metrics_feature(_issues: &mut Vec<(String, &str)>) { #[cfg(not(feature = "prometheus"))] _issues.push(( @@ -430,6 +442,7 @@ fn check_metrics_feature(_issues: &mut Vec<(String, &str)>) { )); } +#[cfg(feature = "server")] fn check_otlp_feature(_issues: &mut Vec<(String, &str)>) { #[cfg(not(feature = "otlp"))] _issues.push(( @@ -438,6 +451,7 @@ fn check_otlp_feature(_issues: &mut Vec<(String, &str)>) { )); } +#[cfg(feature = 
"server")] fn check_auth_mode_feature(_raw: &toml::Value, _issues: &mut Vec<(String, &str)>) { #[cfg(not(feature = "sso"))] if _raw @@ -456,6 +470,7 @@ fn check_auth_mode_feature(_raw: &toml::Value, _issues: &mut Vec<(String, &str)> /// Expand environment variables in the format `${VAR_NAME}`. /// Skips commented lines (lines where content before the variable is a comment). +#[cfg(feature = "server")] fn expand_env_vars(input: &str) -> Result { let re = regex::Regex::new(r"\$\{([^}]+)\}").unwrap(); let mut result = String::with_capacity(input.len()); @@ -510,7 +525,7 @@ mod tests { #[test] fn test_minimal_config() { - let config = GatewayConfig::from_str( + let config = GatewayConfig::parse( r#" [providers.my-openai] type = "open_ai" @@ -525,7 +540,7 @@ mod tests { #[test] fn test_multiple_providers_config() { - let config = GatewayConfig::from_str( + let config = GatewayConfig::parse( r#" [providers] default_provider = "openrouter" @@ -599,7 +614,7 @@ key3 = "literal""# #[test] #[cfg(not(feature = "provider-bedrock"))] fn test_disabled_provider_bedrock_error() { - let err = GatewayConfig::from_str( + let err = GatewayConfig::parse( r#" [providers.my-bedrock] type = "bedrock" @@ -626,7 +641,7 @@ key3 = "literal""# #[test] #[cfg(not(feature = "vault"))] fn test_disabled_secrets_vault_error() { - let err = GatewayConfig::from_str( + let err = GatewayConfig::parse( r#" [secrets] type = "vault" @@ -651,7 +666,7 @@ key3 = "literal""# #[test] #[cfg(not(feature = "provider-bedrock"))] fn test_disabled_multiple_features_error() { - let err = GatewayConfig::from_str( + let err = GatewayConfig::parse( r#" [providers.my-bedrock] type = "bedrock" @@ -679,7 +694,7 @@ key3 = "literal""# #[test] #[cfg(not(feature = "database-sqlite"))] fn test_disabled_database_sqlite_error() { - let err = GatewayConfig::from_str( + let err = GatewayConfig::parse( r#" [database] type = "sqlite" @@ -706,7 +721,7 @@ key3 = "literal""# #[test] #[cfg(not(feature = "database-postgres"))] fn 
test_disabled_database_postgres_error() { - let err = GatewayConfig::from_str( + let err = GatewayConfig::parse( r#" [database] type = "postgres" @@ -755,7 +770,7 @@ key3 = "literal""# #[cfg(feature = "database-sqlite")] fn test_iap_without_trusted_proxies_non_localhost_errors() { // IAP on 0.0.0.0 without trusted_proxies should fail - let err = GatewayConfig::from_str( + let err = GatewayConfig::parse( r#" [server] host = "0.0.0.0" @@ -790,7 +805,7 @@ key3 = "literal""# #[cfg(feature = "database-sqlite")] fn test_iap_without_trusted_proxies_localhost_warns_but_ok() { // IAP on localhost without trusted_proxies should succeed (just warn) - let result = GatewayConfig::from_str( + let result = GatewayConfig::parse( r#" [server] host = "127.0.0.1" @@ -820,7 +835,7 @@ key3 = "literal""# #[cfg(feature = "database-sqlite")] fn test_iap_with_trusted_proxies_non_localhost_ok() { // IAP on 0.0.0.0 with trusted_proxies configured should succeed - let result = GatewayConfig::from_str( + let result = GatewayConfig::parse( r#" [server] host = "0.0.0.0" @@ -853,7 +868,7 @@ key3 = "literal""# #[cfg(feature = "database-sqlite")] fn test_iap_with_dangerously_trust_all_non_localhost_ok() { // IAP with dangerously_trust_all should also pass validation - let result = GatewayConfig::from_str( + let result = GatewayConfig::parse( r#" [server] host = "0.0.0.0" diff --git a/src/config/server.rs b/src/config/server.rs index fc8be77..7548f2d 100644 --- a/src/config/server.rs +++ b/src/config/server.rs @@ -635,29 +635,36 @@ impl Default for HttpClientConfig { impl HttpClientConfig { /// Build a reqwest Client from this configuration. 
pub fn build_client(&self) -> Result { - let mut builder = reqwest::Client::builder() - .timeout(Duration::from_secs(self.timeout_secs)) - .connection_verbose(self.verbose) - .connect_timeout(Duration::from_secs(self.connect_timeout_secs)) - .pool_max_idle_per_host(self.pool_max_idle_per_host) - .pool_idle_timeout(Duration::from_secs(self.pool_idle_timeout_secs)) - .tcp_nodelay(self.tcp_nodelay) - .user_agent(&self.user_agent); - - // HTTP/2 configuration - if self.http2_prior_knowledge { - builder = builder.http2_prior_knowledge(); - } - if self.http2_adaptive_window { - builder = builder.http2_adaptive_window(true); - } + #[cfg(not(target_arch = "wasm32"))] + { + let mut builder = reqwest::Client::builder() + .timeout(Duration::from_secs(self.timeout_secs)) + .connection_verbose(self.verbose) + .connect_timeout(Duration::from_secs(self.connect_timeout_secs)) + .pool_max_idle_per_host(self.pool_max_idle_per_host) + .pool_idle_timeout(Duration::from_secs(self.pool_idle_timeout_secs)) + .tcp_nodelay(self.tcp_nodelay) + .user_agent(&self.user_agent); + + // HTTP/2 configuration + if self.http2_prior_knowledge { + builder = builder.http2_prior_knowledge(); + } + if self.http2_adaptive_window { + builder = builder.http2_adaptive_window(true); + } - // TCP keepalive (0 means disabled) - if self.tcp_keepalive_secs > 0 { - builder = builder.tcp_keepalive(Duration::from_secs(self.tcp_keepalive_secs)); - } + // TCP keepalive (0 means disabled) + if self.tcp_keepalive_secs > 0 { + builder = builder.tcp_keepalive(Duration::from_secs(self.tcp_keepalive_secs)); + } - builder.build() + builder.build() + } + #[cfg(target_arch = "wasm32")] + { + reqwest::Client::builder().build() + } } } diff --git a/src/db/error.rs b/src/db/error.rs index 40fd8f9..58ef5c6 100644 --- a/src/db/error.rs +++ b/src/db/error.rs @@ -25,6 +25,10 @@ pub enum DbError { #[error("JSON serialization error: {0}")] Json(#[from] serde_json::Error), + #[cfg(feature = "database-wasm-sqlite")] + #[error("WASM 
SQLite error: {0}")] + WasmSqlite(#[from] crate::db::wasm_sqlite::WasmDbError), + #[error("Internal error: {0}")] Internal(String), } diff --git a/src/db/mod.rs b/src/db/mod.rs index 266443c..1e3c84a 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -2,8 +2,10 @@ mod error; #[cfg(feature = "database-postgres")] pub mod postgres; pub mod repos; -#[cfg(feature = "database-sqlite")] +#[cfg(any(feature = "database-sqlite", feature = "database-wasm-sqlite"))] pub mod sqlite; +#[cfg(feature = "database-wasm-sqlite")] +pub mod wasm_sqlite; #[cfg(all(test, any(feature = "database-sqlite", feature = "database-postgres")))] pub mod tests; @@ -76,7 +78,13 @@ enum PoolStorage { Sqlite(sqlx::SqlitePool), #[cfg(feature = "database-postgres")] Postgres(PgPoolPair), - #[cfg(not(any(feature = "database-sqlite", feature = "database-postgres")))] + #[cfg(feature = "database-wasm-sqlite")] + WasmSqlite(wasm_sqlite::WasmSqlitePool), + #[cfg(not(any( + feature = "database-sqlite", + feature = "database-postgres", + feature = "database-wasm-sqlite" + )))] _None(std::convert::Infallible), } @@ -87,7 +95,13 @@ pub enum DbPoolRef<'a> { Sqlite(&'a sqlx::SqlitePool), #[cfg(feature = "database-postgres")] Postgres(&'a PgPoolPair), - #[cfg(not(any(feature = "database-sqlite", feature = "database-postgres")))] + #[cfg(feature = "database-wasm-sqlite")] + WasmSqlite(&'a wasm_sqlite::WasmSqlitePool), + #[cfg(not(any( + feature = "database-sqlite", + feature = "database-postgres", + feature = "database-wasm-sqlite" + )))] _None(std::convert::Infallible, std::marker::PhantomData<&'a ()>), } @@ -139,6 +153,44 @@ impl DbPool { } } + /// Create a DbPool from a WASM SQLite pool (wa-sqlite in the browser). 
+ #[cfg(feature = "database-wasm-sqlite")] + pub fn from_wasm_sqlite(pool: wasm_sqlite::WasmSqlitePool) -> Self { + let repos = CachedRepos { + organizations: Arc::new(sqlite::SqliteOrganizationRepo::new(pool.clone())), + projects: Arc::new(sqlite::SqliteProjectRepo::new(pool.clone())), + users: Arc::new(sqlite::SqliteUserRepo::new(pool.clone())), + api_keys: Arc::new(sqlite::SqliteApiKeyRepo::new(pool.clone())), + providers: Arc::new(sqlite::SqliteDynamicProviderRepo::new(pool.clone())), + usage: Arc::new(sqlite::SqliteUsageRepo::new(pool.clone())), + model_pricing: Arc::new(sqlite::SqliteModelPricingRepo::new(pool.clone())), + conversations: Arc::new(sqlite::SqliteConversationRepo::new(pool.clone())), + audit_logs: Arc::new(sqlite::SqliteAuditLogRepo::new(pool.clone())), + vector_stores: Arc::new(sqlite::SqliteVectorStoresRepo::new(pool.clone())), + files: Arc::new(sqlite::SqliteFilesRepo::new(pool.clone())), + teams: Arc::new(sqlite::SqliteTeamRepo::new(pool.clone())), + prompts: Arc::new(sqlite::SqlitePromptRepo::new(pool.clone())), + #[cfg(feature = "sso")] + sso_group_mappings: unreachable!("SSO not supported in WASM builds"), + #[cfg(feature = "sso")] + org_sso_configs: unreachable!("SSO not supported in WASM builds"), + #[cfg(feature = "sso")] + domain_verifications: unreachable!("SSO not supported in WASM builds"), + #[cfg(feature = "sso")] + scim_configs: unreachable!("SSO not supported in WASM builds"), + #[cfg(feature = "sso")] + scim_user_mappings: unreachable!("SSO not supported in WASM builds"), + #[cfg(feature = "sso")] + scim_group_mappings: unreachable!("SSO not supported in WASM builds"), + org_rbac_policies: Arc::new(sqlite::SqliteOrgRbacPolicyRepo::new(pool.clone())), + service_accounts: Arc::new(sqlite::SqliteServiceAccountRepo::new(pool.clone())), + }; + DbPool { + inner: PoolStorage::WasmSqlite(pool), + repos, + } + } + /// Create a DbPool from existing PostgreSQL pools. /// Primarily useful for testing. 
#[cfg(feature = "database-postgres")] @@ -454,7 +506,17 @@ impl DbPool { tracing::info!("PostgreSQL migrations completed successfully"); Ok(()) } - #[cfg(not(any(feature = "database-sqlite", feature = "database-postgres")))] + #[cfg(feature = "database-wasm-sqlite")] + PoolStorage::WasmSqlite(pool) => { + tracing::info!("Running WASM SQLite migrations"); + pool.run_migrations().await?; + Ok(()) + } + #[cfg(not(any( + feature = "database-sqlite", + feature = "database-postgres", + feature = "database-wasm-sqlite" + )))] PoolStorage::_None(infallible) => match *infallible {}, } } @@ -578,7 +640,13 @@ impl DbPool { PoolStorage::Sqlite(pool) => DbPoolRef::Sqlite(pool), #[cfg(feature = "database-postgres")] PoolStorage::Postgres(pools) => DbPoolRef::Postgres(pools), - #[cfg(not(any(feature = "database-sqlite", feature = "database-postgres")))] + #[cfg(feature = "database-wasm-sqlite")] + PoolStorage::WasmSqlite(pool) => DbPoolRef::WasmSqlite(pool), + #[cfg(not(any( + feature = "database-sqlite", + feature = "database-postgres", + feature = "database-wasm-sqlite" + )))] PoolStorage::_None(infallible) => match *infallible {}, } } @@ -611,7 +679,16 @@ impl DbPool { } Ok(()) } - #[cfg(not(any(feature = "database-sqlite", feature = "database-postgres")))] + #[cfg(feature = "database-wasm-sqlite")] + PoolStorage::WasmSqlite(pool) => { + pool.execute_query("SELECT 1", &[]).await?; + Ok(()) + } + #[cfg(not(any( + feature = "database-sqlite", + feature = "database-postgres", + feature = "database-wasm-sqlite" + )))] PoolStorage::_None(infallible) => match *infallible {}, } } diff --git a/src/db/postgres/api_keys.rs b/src/db/postgres/api_keys.rs index 69dae80..60cc213 100644 --- a/src/db/postgres/api_keys.rs +++ b/src/db/postgres/api_keys.rs @@ -373,7 +373,8 @@ impl PostgresApiKeyRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl ApiKeyRepo for PostgresApiKeyRepo { async fn create(&self, 
input: CreateApiKey, key_hash: &str) -> DbResult { let id = Uuid::new_v4(); diff --git a/src/db/postgres/audit_logs.rs b/src/db/postgres/audit_logs.rs index 9d2117e..ba02764 100644 --- a/src/db/postgres/audit_logs.rs +++ b/src/db/postgres/audit_logs.rs @@ -34,7 +34,8 @@ impl PostgresAuditLogRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl AuditLogRepo for PostgresAuditLogRepo { async fn create(&self, input: CreateAuditLog) -> DbResult { let id = Uuid::new_v4(); diff --git a/src/db/postgres/conversations.rs b/src/db/postgres/conversations.rs index aa2e699..5fe1ab4 100644 --- a/src/db/postgres/conversations.rs +++ b/src/db/postgres/conversations.rs @@ -127,7 +127,8 @@ impl PostgresConversationRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl ConversationRepo for PostgresConversationRepo { async fn create(&self, input: CreateConversation) -> DbResult { let owner_type = input.owner.owner_type(); diff --git a/src/db/postgres/domain_verifications.rs b/src/db/postgres/domain_verifications.rs index 450447c..bb6fe46 100644 --- a/src/db/postgres/domain_verifications.rs +++ b/src/db/postgres/domain_verifications.rs @@ -51,7 +51,8 @@ impl PostgresDomainVerificationRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl DomainVerificationRepo for PostgresDomainVerificationRepo { async fn create( &self, diff --git a/src/db/postgres/files.rs b/src/db/postgres/files.rs index 9cc7d03..5dfb022 100644 --- a/src/db/postgres/files.rs +++ b/src/db/postgres/files.rs @@ -28,7 +28,8 @@ impl PostgresFilesRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl FilesRepo for PostgresFilesRepo { async fn create_file(&self, input: 
CreateFile) -> DbResult { let id = Uuid::new_v4(); diff --git a/src/db/postgres/model_pricing.rs b/src/db/postgres/model_pricing.rs index f1eb981..f62b79f 100644 --- a/src/db/postgres/model_pricing.rs +++ b/src/db/postgres/model_pricing.rs @@ -67,7 +67,7 @@ impl PostgresModelPricingRepo { reasoning_per_1m_tokens: row.get("reasoning_per_1m_tokens"), per_second: row.get("per_second"), per_1m_characters: row.get("per_1m_characters"), - source: PricingSource::from_str(&source_str), + source: PricingSource::parse(&source_str), created_at: row.get("created_at"), updated_at: row.get("updated_at"), }) @@ -221,7 +221,8 @@ impl PostgresModelPricingRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl ModelPricingRepo for PostgresModelPricingRepo { async fn create(&self, input: CreateModelPricing) -> DbResult { let id = Uuid::new_v4(); diff --git a/src/db/postgres/org_rbac_policies.rs b/src/db/postgres/org_rbac_policies.rs index 9b66812..cd419dd 100644 --- a/src/db/postgres/org_rbac_policies.rs +++ b/src/db/postgres/org_rbac_policies.rs @@ -229,7 +229,8 @@ impl PostgresOrgRbacPolicyRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl OrgRbacPolicyRepo for PostgresOrgRbacPolicyRepo { async fn create( &self, diff --git a/src/db/postgres/org_sso_configs.rs b/src/db/postgres/org_sso_configs.rs index f9f8c05..05d1837 100644 --- a/src/db/postgres/org_sso_configs.rs +++ b/src/db/postgres/org_sso_configs.rs @@ -105,7 +105,8 @@ impl PostgresOrgSsoConfigRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl OrgSsoConfigRepo for PostgresOrgSsoConfigRepo { async fn create( &self, diff --git a/src/db/postgres/organizations.rs b/src/db/postgres/organizations.rs index 4f903f0..afe4ae8 100644 --- 
a/src/db/postgres/organizations.rs +++ b/src/db/postgres/organizations.rs @@ -92,7 +92,8 @@ impl PostgresOrganizationRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl OrganizationRepo for PostgresOrganizationRepo { async fn create(&self, input: CreateOrganization) -> DbResult { let id = Uuid::new_v4(); diff --git a/src/db/postgres/projects.rs b/src/db/postgres/projects.rs index 074df1d..15b4a3f 100644 --- a/src/db/postgres/projects.rs +++ b/src/db/postgres/projects.rs @@ -96,7 +96,8 @@ impl PostgresProjectRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl ProjectRepo for PostgresProjectRepo { async fn create(&self, org_id: Uuid, input: CreateProject) -> DbResult { let row = sqlx::query( diff --git a/src/db/postgres/prompts.rs b/src/db/postgres/prompts.rs index 2b07e9e..4e407c4 100644 --- a/src/db/postgres/prompts.rs +++ b/src/db/postgres/prompts.rs @@ -115,7 +115,8 @@ impl PostgresPromptRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl PromptRepo for PostgresPromptRepo { async fn create(&self, input: CreatePrompt) -> DbResult { let id = Uuid::new_v4(); diff --git a/src/db/postgres/providers.rs b/src/db/postgres/providers.rs index 64c64a5..c5ad541 100644 --- a/src/db/postgres/providers.rs +++ b/src/db/postgres/providers.rs @@ -280,7 +280,8 @@ impl PostgresDynamicProviderRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl DynamicProviderRepo for PostgresDynamicProviderRepo { async fn create(&self, id: Uuid, input: CreateDynamicProvider) -> DbResult { let (owner_type, owner_id) = Self::owner_to_parts(&input.owner); diff --git a/src/db/postgres/scim_configs.rs b/src/db/postgres/scim_configs.rs index 
43a1971..3b7cdfa 100644 --- a/src/db/postgres/scim_configs.rs +++ b/src/db/postgres/scim_configs.rs @@ -54,7 +54,8 @@ impl PostgresOrgScimConfigRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl OrgScimConfigRepo for PostgresOrgScimConfigRepo { async fn create( &self, diff --git a/src/db/postgres/scim_group_mappings.rs b/src/db/postgres/scim_group_mappings.rs index 6b668b2..56f098f 100644 --- a/src/db/postgres/scim_group_mappings.rs +++ b/src/db/postgres/scim_group_mappings.rs @@ -68,7 +68,8 @@ impl PostgresScimGroupMappingRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl ScimGroupMappingRepo for PostgresScimGroupMappingRepo { async fn create( &self, diff --git a/src/db/postgres/scim_user_mappings.rs b/src/db/postgres/scim_user_mappings.rs index 9e1bb41..16294c3 100644 --- a/src/db/postgres/scim_user_mappings.rs +++ b/src/db/postgres/scim_user_mappings.rs @@ -68,7 +68,8 @@ impl PostgresScimUserMappingRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl ScimUserMappingRepo for PostgresScimUserMappingRepo { async fn create( &self, diff --git a/src/db/postgres/service_accounts.rs b/src/db/postgres/service_accounts.rs index 8890102..a48de86 100644 --- a/src/db/postgres/service_accounts.rs +++ b/src/db/postgres/service_accounts.rs @@ -91,7 +91,8 @@ impl PostgresServiceAccountRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl ServiceAccountRepo for PostgresServiceAccountRepo { async fn create(&self, org_id: Uuid, input: CreateServiceAccount) -> DbResult { let id = Uuid::new_v4(); diff --git a/src/db/postgres/sso_group_mappings.rs b/src/db/postgres/sso_group_mappings.rs index 91aeced..b500bad 100644 --- 
a/src/db/postgres/sso_group_mappings.rs +++ b/src/db/postgres/sso_group_mappings.rs @@ -119,7 +119,8 @@ impl PostgresSsoGroupMappingRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl SsoGroupMappingRepo for PostgresSsoGroupMappingRepo { async fn create( &self, diff --git a/src/db/postgres/teams.rs b/src/db/postgres/teams.rs index 337d84e..507beb4 100644 --- a/src/db/postgres/teams.rs +++ b/src/db/postgres/teams.rs @@ -153,7 +153,8 @@ impl PostgresTeamRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl TeamRepo for PostgresTeamRepo { async fn create(&self, org_id: Uuid, input: CreateTeam) -> DbResult { let row = sqlx::query( diff --git a/src/db/postgres/usage.rs b/src/db/postgres/usage.rs index d09510f..8c4e23d 100644 --- a/src/db/postgres/usage.rs +++ b/src/db/postgres/usage.rs @@ -45,7 +45,8 @@ impl PostgresUsageRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl UsageRepo for PostgresUsageRepo { async fn log(&self, entry: UsageLogEntry) -> DbResult<()> { let id = Uuid::new_v4(); diff --git a/src/db/postgres/users.rs b/src/db/postgres/users.rs index 996837c..3cd4088 100644 --- a/src/db/postgres/users.rs +++ b/src/db/postgres/users.rs @@ -204,7 +204,8 @@ impl PostgresUserRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl UserRepo for PostgresUserRepo { async fn create(&self, input: CreateUser) -> DbResult { let row = sqlx::query( @@ -773,7 +774,7 @@ impl UserRepo for PostgresUserRepo { org_slug: row.get("org_slug"), org_name: row.get("org_name"), role: row.get("role"), - source: MembershipSource::from_str(&source_str).unwrap_or_default(), + source: 
MembershipSource::parse(&source_str).unwrap_or_default(), joined_at: row.get("joined_at"), } }) @@ -814,7 +815,7 @@ impl UserRepo for PostgresUserRepo { project_name: row.get("project_name"), org_id: row.get("org_id"), role: row.get("role"), - source: MembershipSource::from_str(&source_str).unwrap_or_default(), + source: MembershipSource::parse(&source_str).unwrap_or_default(), joined_at: row.get("joined_at"), } }) @@ -852,7 +853,7 @@ impl UserRepo for PostgresUserRepo { team_name: row.get("team_name"), org_id: row.get("org_id"), role: row.get("role"), - source: MembershipSource::from_str(&source_str).unwrap_or_default(), + source: MembershipSource::parse(&source_str).unwrap_or_default(), joined_at: row.get("joined_at"), } }) diff --git a/src/db/postgres/vector_stores.rs b/src/db/postgres/vector_stores.rs index 5c575f6..de6be5f 100644 --- a/src/db/postgres/vector_stores.rs +++ b/src/db/postgres/vector_stores.rs @@ -140,7 +140,8 @@ impl PostgresVectorStoresRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl VectorStoresRepo for PostgresVectorStoresRepo { // ==================== Vector Stores CRUD ==================== diff --git a/src/db/repos/api_keys.rs b/src/db/repos/api_keys.rs index a87582b..6c354f0 100644 --- a/src/db/repos/api_keys.rs +++ b/src/db/repos/api_keys.rs @@ -8,7 +8,8 @@ use crate::{ models::{ApiKey, ApiKeyWithOwner, CachedApiKey, CreateApiKey}, }; -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait ApiKeyRepo: Send + Sync { async fn create(&self, input: CreateApiKey, key_hash: &str) -> DbResult; async fn get_by_id(&self, id: Uuid) -> DbResult>; diff --git a/src/db/repos/audit_logs.rs b/src/db/repos/audit_logs.rs index 2f7ba68..c2ba445 100644 --- a/src/db/repos/audit_logs.rs +++ b/src/db/repos/audit_logs.rs @@ -8,7 +8,8 @@ use crate::{ models::{AuditLog, AuditLogQuery, 
CreateAuditLog}, }; -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait AuditLogRepo: Send + Sync { /// Create a new audit log entry async fn create(&self, input: CreateAuditLog) -> DbResult; diff --git a/src/db/repos/conversations.rs b/src/db/repos/conversations.rs index 1b28ad5..c843c6c 100644 --- a/src/db/repos/conversations.rs +++ b/src/db/repos/conversations.rs @@ -11,7 +11,8 @@ use crate::{ }, }; -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait ConversationRepo: Send + Sync { /// Create a new conversation async fn create(&self, input: CreateConversation) -> DbResult; diff --git a/src/db/repos/cursor.rs b/src/db/repos/cursor.rs index 17774e9..e53f71d 100644 --- a/src/db/repos/cursor.rs +++ b/src/db/repos/cursor.rs @@ -200,7 +200,11 @@ impl PageCursors { } } -#[cfg(any(feature = "database-sqlite", feature = "database-postgres"))] +#[cfg(any( + feature = "database-sqlite", + feature = "database-postgres", + feature = "database-wasm-sqlite" +))] /// Create a cursor from a row's created_at and id fields. /// /// Convenience function for use in database queries. 
@@ -223,7 +227,11 @@ pub fn cursor_from_row(created_at: DateTime, id: Uuid) -> Cursor { /// // When creating entities that will use cursor pagination: /// let created_at = truncate_to_millis(Utc::now()); /// ``` -#[cfg(any(feature = "database-sqlite", feature = "database-postgres"))] +#[cfg(any( + feature = "database-sqlite", + feature = "database-postgres", + feature = "database-wasm-sqlite" +))] pub fn truncate_to_millis(dt: DateTime) -> DateTime { DateTime::from_timestamp_millis(dt.timestamp_millis()).unwrap_or(dt) } diff --git a/src/db/repos/domain_verifications.rs b/src/db/repos/domain_verifications.rs index 757b860..89edc88 100644 --- a/src/db/repos/domain_verifications.rs +++ b/src/db/repos/domain_verifications.rs @@ -12,7 +12,8 @@ use crate::{ /// Domain verifications track the ownership verification status of email domains /// claimed by an organization's SSO configuration. Domains must be verified via /// DNS TXT record before SSO can be enforced for users with that email domain. -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait DomainVerificationRepo: Send + Sync { /// Create a new domain verification record. 
/// diff --git a/src/db/repos/files.rs b/src/db/repos/files.rs index 8ec174c..2be7918 100644 --- a/src/db/repos/files.rs +++ b/src/db/repos/files.rs @@ -8,7 +8,8 @@ use crate::{ }; /// Repository trait for files (OpenAI Files API) operations -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait FilesRepo: Send + Sync { /// Create a new file async fn create_file(&self, input: CreateFile) -> DbResult; diff --git a/src/db/repos/model_pricing.rs b/src/db/repos/model_pricing.rs index ea1c889..94d1023 100644 --- a/src/db/repos/model_pricing.rs +++ b/src/db/repos/model_pricing.rs @@ -7,7 +7,8 @@ use crate::{ models::{CreateModelPricing, DbModelPricing, PricingOwner, UpdateModelPricing}, }; -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait ModelPricingRepo: Send + Sync { /// Create a new model pricing entry async fn create(&self, input: CreateModelPricing) -> DbResult; diff --git a/src/db/repos/org_rbac_policies.rs b/src/db/repos/org_rbac_policies.rs index 442e5fe..8fd5115 100644 --- a/src/db/repos/org_rbac_policies.rs +++ b/src/db/repos/org_rbac_policies.rs @@ -24,7 +24,8 @@ use crate::{ /// /// All query methods (get, list, count) automatically exclude soft-deleted policies. /// Delete operations set `deleted_at` rather than removing rows, preserving version history. 
-#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait OrgRbacPolicyRepo: Send + Sync { // ========================================================================= // Policy CRUD Operations diff --git a/src/db/repos/org_sso_configs.rs b/src/db/repos/org_sso_configs.rs index c2ff63d..bdaea0a 100644 --- a/src/db/repos/org_sso_configs.rs +++ b/src/db/repos/org_sso_configs.rs @@ -14,7 +14,8 @@ use crate::{ /// /// Note: Client secrets are stored separately in a secret manager. /// The `client_secret_key` field contains a reference to the secret. -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait OrgSsoConfigRepo: Send + Sync { /// Create a new SSO configuration for an organization. /// diff --git a/src/db/repos/organizations.rs b/src/db/repos/organizations.rs index fee9a73..da45c46 100644 --- a/src/db/repos/organizations.rs +++ b/src/db/repos/organizations.rs @@ -7,7 +7,8 @@ use crate::{ models::{CreateOrganization, Organization, UpdateOrganization}, }; -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait OrganizationRepo: Send + Sync { async fn create(&self, input: CreateOrganization) -> DbResult; async fn get_by_id(&self, id: Uuid) -> DbResult>; diff --git a/src/db/repos/projects.rs b/src/db/repos/projects.rs index 69f1a71..579d2f1 100644 --- a/src/db/repos/projects.rs +++ b/src/db/repos/projects.rs @@ -7,7 +7,8 @@ use crate::{ models::{CreateProject, Project, UpdateProject}, }; -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait ProjectRepo: Send + Sync { async fn create(&self, org_id: Uuid, input: CreateProject) -> DbResult; async fn get_by_id(&self, id: Uuid) -> DbResult>; diff --git a/src/db/repos/prompts.rs 
b/src/db/repos/prompts.rs index 7dc49b0..a94368c 100644 --- a/src/db/repos/prompts.rs +++ b/src/db/repos/prompts.rs @@ -7,7 +7,8 @@ use crate::{ models::{CreatePrompt, Prompt, PromptOwnerType, UpdatePrompt}, }; -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait PromptRepo: Send + Sync { /// Create a new prompt. async fn create(&self, input: CreatePrompt) -> DbResult; diff --git a/src/db/repos/providers.rs b/src/db/repos/providers.rs index e195e4a..ff2d80e 100644 --- a/src/db/repos/providers.rs +++ b/src/db/repos/providers.rs @@ -7,7 +7,8 @@ use crate::{ models::{CreateDynamicProvider, DynamicProvider, ProviderOwner, UpdateDynamicProvider}, }; -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait DynamicProviderRepo: Send + Sync { async fn create(&self, id: Uuid, input: CreateDynamicProvider) -> DbResult; async fn get_by_id(&self, id: Uuid) -> DbResult>; diff --git a/src/db/repos/scim_configs.rs b/src/db/repos/scim_configs.rs index 76afab6..4896290 100644 --- a/src/db/repos/scim_configs.rs +++ b/src/db/repos/scim_configs.rs @@ -16,7 +16,8 @@ use crate::{ /// Note: The token_hash is stored in the database, not in a secret manager /// (unlike SSO client secrets), because SCIM tokens need fast lookup for /// every provisioning request. -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait OrgScimConfigRepo: Send + Sync { /// Create a new SCIM configuration for an organization. /// diff --git a/src/db/repos/scim_group_mappings.rs b/src/db/repos/scim_group_mappings.rs index 97ac342..7412e1b 100644 --- a/src/db/repos/scim_group_mappings.rs +++ b/src/db/repos/scim_group_mappings.rs @@ -15,7 +15,8 @@ use crate::{ /// SCIM group mappings link SCIM groups (from the IdP) to Hadrian teams. 
/// When the IdP pushes group membership changes via SCIM, we update /// team memberships in Hadrian accordingly. -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait ScimGroupMappingRepo: Send + Sync { /// Create a new SCIM group mapping. /// diff --git a/src/db/repos/scim_user_mappings.rs b/src/db/repos/scim_user_mappings.rs index 441820a..b79fc18 100644 --- a/src/db/repos/scim_user_mappings.rs +++ b/src/db/repos/scim_user_mappings.rs @@ -17,7 +17,8 @@ use crate::{ /// - Looking up Hadrian users by SCIM ID during provisioning /// - Tracking SCIM "active" status separately from user existence /// - Supporting the same user in multiple organizations with different SCIM IDs -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait ScimUserMappingRepo: Send + Sync { /// Create a new SCIM user mapping. /// diff --git a/src/db/repos/service_accounts.rs b/src/db/repos/service_accounts.rs index 4593687..662aba0 100644 --- a/src/db/repos/service_accounts.rs +++ b/src/db/repos/service_accounts.rs @@ -7,7 +7,8 @@ use crate::{ models::{CreateServiceAccount, ServiceAccount, UpdateServiceAccount}, }; -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait ServiceAccountRepo: Send + Sync { /// Create a new service account within an organization. async fn create(&self, org_id: Uuid, input: CreateServiceAccount) -> DbResult; diff --git a/src/db/repos/sso_group_mappings.rs b/src/db/repos/sso_group_mappings.rs index e2c1fa5..7d5551b 100644 --- a/src/db/repos/sso_group_mappings.rs +++ b/src/db/repos/sso_group_mappings.rs @@ -12,7 +12,8 @@ use crate::{ /// SSO group mappings define how IdP groups are mapped to Hadrian teams and roles /// during JIT (Just-in-Time) provisioning. 
When a user logs in via SSO, their /// IdP groups are looked up in this table to determine team memberships. -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait SsoGroupMappingRepo: Send + Sync { /// Create a new SSO group mapping. async fn create(&self, org_id: Uuid, input: CreateSsoGroupMapping) diff --git a/src/db/repos/teams.rs b/src/db/repos/teams.rs index 9dd3680..4276c05 100644 --- a/src/db/repos/teams.rs +++ b/src/db/repos/teams.rs @@ -9,7 +9,8 @@ use crate::{ }, }; -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait TeamRepo: Send + Sync { /// Create a new team within an organization. async fn create(&self, org_id: Uuid, input: CreateTeam) -> DbResult; diff --git a/src/db/repos/usage.rs b/src/db/repos/usage.rs index 41c3704..9afd29a 100644 --- a/src/db/repos/usage.rs +++ b/src/db/repos/usage.rs @@ -24,7 +24,8 @@ pub struct UsageStats { pub sample_days: i32, } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait UsageRepo: Send + Sync { /// Log a single usage entry. 
async fn log(&self, entry: UsageLogEntry) -> DbResult<()>; diff --git a/src/db/repos/users.rs b/src/db/repos/users.rs index 2a2cfff..41701f3 100644 --- a/src/db/repos/users.rs +++ b/src/db/repos/users.rs @@ -10,7 +10,8 @@ use crate::{ }, }; -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait UserRepo: Send + Sync { async fn create(&self, input: CreateUser) -> DbResult; async fn get_by_id(&self, id: Uuid) -> DbResult>; diff --git a/src/db/repos/vector_stores.rs b/src/db/repos/vector_stores.rs index b2a1671..226d8e2 100644 --- a/src/db/repos/vector_stores.rs +++ b/src/db/repos/vector_stores.rs @@ -12,7 +12,8 @@ use crate::{ }; /// Repository trait for collections (vector stores) operations -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait VectorStoresRepo: Send + Sync { // ==================== Vector Stores CRUD ==================== diff --git a/src/db/sqlite/api_keys.rs b/src/db/sqlite/api_keys.rs index aeb91fe..acd44e7 100644 --- a/src/db/sqlite/api_keys.rs +++ b/src/db/sqlite/api_keys.rs @@ -1,8 +1,8 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; +use super::backend::{Pool, RowExt, begin, map_unique_violation, query, query_scalar}; use crate::{ db::{ error::{DbError, DbResult}, @@ -15,11 +15,11 @@ use crate::{ }; pub struct SqliteApiKeyRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteApiKeyRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } @@ -58,39 +58,42 @@ impl SqliteApiKeyRepo { } } - fn parse_api_key(row: &sqlx::sqlite::SqliteRow) -> DbResult { - let owner = Self::parse_owner(row.get("owner_type"), row.get("owner_id"))?; - let budget_period: Option = row.get("budget_period"); + fn parse_api_key(row: &super::backend::Row) -> DbResult { + let owner_type: String = 
row.col("owner_type"); + let owner_id: String = row.col("owner_id"); + let owner = Self::parse_owner(&owner_type, &owner_id)?; + let budget_period: Option = row.col("budget_period"); // Parse JSON columns - let scopes: Option = row.get("scopes"); - let allowed_models: Option = row.get("allowed_models"); - let ip_allowlist: Option = row.get("ip_allowlist"); + let scopes: Option = row.col("scopes"); + let allowed_models: Option = row.col("allowed_models"); + let ip_allowlist: Option = row.col("ip_allowlist"); Ok(ApiKey { - id: Uuid::parse_str(row.get("id")).map_err(|e| DbError::Internal(e.to_string()))?, - key_prefix: row.get("key_prefix"), - name: row.get("name"), + id: Uuid::parse_str(&row.col::("id")) + .map_err(|e| DbError::Internal(e.to_string()))?, + key_prefix: row.col("key_prefix"), + name: row.col("name"), owner, - budget_limit_cents: row.get("budget_amount"), + budget_limit_cents: row.col("budget_amount"), budget_period: budget_period.and_then(|p| match p.as_str() { "daily" => Some(BudgetPeriod::Daily), "monthly" => Some(BudgetPeriod::Monthly), _ => None, }), - created_at: row.get("created_at"), - expires_at: row.get("expires_at"), - revoked_at: row.get("revoked_at"), - last_used_at: row.get("last_used_at"), + created_at: row.col("created_at"), + expires_at: row.col("expires_at"), + revoked_at: row.col("revoked_at"), + last_used_at: row.col("last_used_at"), scopes: scopes.and_then(|s| serde_json::from_str(&s).ok()), allowed_models: allowed_models.and_then(|s| serde_json::from_str(&s).ok()), ip_allowlist: ip_allowlist.and_then(|s| serde_json::from_str(&s).ok()), - rate_limit_rpm: row.get("rate_limit_rpm"), - rate_limit_tpm: row.get("rate_limit_tpm"), + rate_limit_rpm: row.col("rate_limit_rpm"), + rate_limit_tpm: row.col("rate_limit_tpm"), rotated_from_key_id: row - .get::, _>("rotated_from_key_id") + .col::>("rotated_from_key_id") .and_then(|s| Uuid::parse_str(&s).ok()), - rotation_grace_until: row.get("rotation_grace_until"), + rotation_grace_until: 
row.col("rotation_grace_until"), }) } @@ -106,7 +109,7 @@ impl SqliteApiKeyRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT id, key_prefix, name, owner_type, owner_id, budget_amount, budget_period, expires_at, last_used_at, created_at, revoked_at, @@ -121,7 +124,7 @@ impl SqliteApiKeyRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(org_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -160,7 +163,7 @@ impl SqliteApiKeyRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT id, key_prefix, name, owner_type, owner_id, budget_amount, budget_period, expires_at, last_used_at, created_at, revoked_at, @@ -175,7 +178,7 @@ impl SqliteApiKeyRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(project_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -214,7 +217,7 @@ impl SqliteApiKeyRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT id, key_prefix, name, owner_type, owner_id, budget_amount, budget_period, expires_at, last_used_at, created_at, revoked_at, @@ -229,7 +232,7 @@ impl SqliteApiKeyRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(team_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -268,7 +271,7 @@ impl SqliteApiKeyRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT id, key_prefix, name, owner_type, owner_id, budget_amount, budget_period, expires_at, last_used_at, created_at, revoked_at, @@ -283,7 +286,7 @@ impl 
SqliteApiKeyRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(user_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -322,7 +325,7 @@ impl SqliteApiKeyRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT id, key_prefix, name, owner_type, owner_id, budget_amount, budget_period, expires_at, last_used_at, created_at, revoked_at, @@ -337,7 +340,7 @@ impl SqliteApiKeyRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(service_account_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -365,7 +368,8 @@ impl SqliteApiKeyRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl ApiKeyRepo for SqliteApiKeyRepo { async fn create(&self, input: CreateApiKey, key_hash: &str) -> DbResult { let id = Uuid::new_v4(); @@ -380,7 +384,7 @@ impl ApiKeyRepo for SqliteApiKeyRepo { let (owner_type, owner_id) = Self::owner_to_parts(&input.owner); - sqlx::query( + query( r#" INSERT INTO api_keys ( id, name, key_hash, key_prefix, owner_type, owner_id, @@ -424,12 +428,9 @@ impl ApiKeyRepo for SqliteApiKeyRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict("API key with this hash already exists".to_string()) - } - _ => DbError::from(e), - })?; + .map_err(map_unique_violation( + "API key with this hash already exists", + ))?; Ok(ApiKey { id, @@ -453,7 +454,7 @@ impl ApiKeyRepo for SqliteApiKeyRepo { } async fn get_by_id(&self, id: Uuid) -> DbResult> { - let row = sqlx::query( + let row = query( r#" SELECT id, key_prefix, name, owner_type, owner_id, budget_amount, budget_period, expires_at, last_used_at, created_at, revoked_at, @@ -476,7 +477,7 @@ impl 
ApiKeyRepo for SqliteApiKeyRepo { async fn get_by_hash(&self, key_hash: &str) -> DbResult> { let now = Utc::now(); - let row = sqlx::query( + let row = query( r#" SELECT k.id, k.key_prefix, k.name, k.owner_type, k.owner_id, @@ -515,15 +516,15 @@ impl ApiKeyRepo for SqliteApiKeyRepo { let key = Self::parse_api_key(&row)?; - let org_id: Option = row.get("org_id"); - let team_id: Option = row.get("team_id"); - let project_id: Option = row.get("project_id"); - let user_id: Option = row.get("user_id"); - let service_account_id: Option = row.get("service_account_id"); + let org_id: Option = row.col("org_id"); + let team_id: Option = row.col("team_id"); + let project_id: Option = row.col("project_id"); + let user_id: Option = row.col("user_id"); + let service_account_id: Option = row.col("service_account_id"); // Parse service account roles from JSON TEXT let service_account_roles: Option> = row - .get::, _>("service_account_roles") + .col::>("service_account_roles") .and_then(|s| serde_json::from_str(&s).ok()); Ok(Some(ApiKeyWithOwner { @@ -549,7 +550,7 @@ impl ApiKeyRepo for SqliteApiKeyRepo { } // First page (no cursor provided) - let rows = sqlx::query( + let rows = query( r#" SELECT id, key_prefix, name, owner_type, owner_id, budget_amount, budget_period, expires_at, last_used_at, created_at, revoked_at, @@ -582,13 +583,13 @@ impl ApiKeyRepo for SqliteApiKeyRepo { } async fn count_by_org(&self, org_id: Uuid, _include_deleted: bool) -> DbResult { - let row = sqlx::query( + let row = query( "SELECT COUNT(*) as count FROM api_keys WHERE owner_type = 'organization' AND owner_id = ?", ) .bind(org_id.to_string()) .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn list_by_team( @@ -607,7 +608,7 @@ impl ApiKeyRepo for SqliteApiKeyRepo { } // First page (no cursor provided) - let rows = sqlx::query( + let rows = query( r#" SELECT id, key_prefix, name, owner_type, owner_id, budget_amount, budget_period, expires_at, last_used_at, 
created_at, revoked_at, @@ -640,13 +641,13 @@ impl ApiKeyRepo for SqliteApiKeyRepo { } async fn count_by_team(&self, team_id: Uuid, _include_deleted: bool) -> DbResult { - let row = sqlx::query( + let row = query( "SELECT COUNT(*) as count FROM api_keys WHERE owner_type = 'team' AND owner_id = ?", ) .bind(team_id.to_string()) .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn list_by_project( @@ -665,7 +666,7 @@ impl ApiKeyRepo for SqliteApiKeyRepo { } // First page (no cursor provided) - let rows = sqlx::query( + let rows = query( r#" SELECT id, key_prefix, name, owner_type, owner_id, budget_amount, budget_period, expires_at, last_used_at, created_at, revoked_at, @@ -698,13 +699,13 @@ impl ApiKeyRepo for SqliteApiKeyRepo { } async fn count_by_project(&self, project_id: Uuid, _include_deleted: bool) -> DbResult { - let row = sqlx::query( + let row = query( "SELECT COUNT(*) as count FROM api_keys WHERE owner_type = 'project' AND owner_id = ?", ) .bind(project_id.to_string()) .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn list_by_user( @@ -723,7 +724,7 @@ impl ApiKeyRepo for SqliteApiKeyRepo { } // First page (no cursor provided) - let rows = sqlx::query( + let rows = query( r#" SELECT id, key_prefix, name, owner_type, owner_id, budget_amount, budget_period, expires_at, last_used_at, created_at, revoked_at, @@ -756,17 +757,17 @@ impl ApiKeyRepo for SqliteApiKeyRepo { } async fn count_by_user(&self, user_id: Uuid, _include_deleted: bool) -> DbResult { - let row = sqlx::query( + let row = query( "SELECT COUNT(*) as count FROM api_keys WHERE owner_type = 'user' AND owner_id = ?", ) .bind(user_id.to_string()) .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn revoke(&self, id: Uuid) -> DbResult<()> { - sqlx::query( + query( r#" UPDATE api_keys SET revoked_at = datetime('now'), updated_at = datetime('now') @@ -781,7 +782,7 @@ impl ApiKeyRepo 
for SqliteApiKeyRepo { } async fn update_last_used(&self, id: Uuid) -> DbResult<()> { - sqlx::query( + query( r#" UPDATE api_keys SET last_used_at = datetime('now') @@ -796,7 +797,7 @@ impl ApiKeyRepo for SqliteApiKeyRepo { } async fn revoke_by_user(&self, user_id: Uuid) -> DbResult { - let result = sqlx::query( + let result = query( r#" UPDATE api_keys SET revoked_at = datetime('now'), updated_at = datetime('now') @@ -832,7 +833,7 @@ impl ApiKeyRepo for SqliteApiKeyRepo { } // First page (no cursor provided) - let rows = sqlx::query( + let rows = query( r#" SELECT id, key_prefix, name, owner_type, owner_id, @@ -870,20 +871,20 @@ impl ApiKeyRepo for SqliteApiKeyRepo { service_account_id: Uuid, include_revoked: bool, ) -> DbResult { - let query = if include_revoked { + let sql = if include_revoked { "SELECT COUNT(*) as count FROM api_keys WHERE owner_type = 'service_account' AND owner_id = ?" } else { "SELECT COUNT(*) as count FROM api_keys WHERE owner_type = 'service_account' AND owner_id = ? AND revoked_at IS NULL" }; - let row = sqlx::query(query) + let row = query(sql) .bind(service_account_id.to_string()) .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn revoke_by_service_account(&self, service_account_id: Uuid) -> DbResult { - let result = sqlx::query( + let result = query( r#" UPDATE api_keys SET revoked_at = datetime('now'), updated_at = datetime('now') @@ -916,11 +917,10 @@ impl ApiKeyRepo for SqliteApiKeyRepo { let (owner_type, owner_id) = Self::owner_to_parts(&new_key_input.owner); - // Use a transaction to ensure both operations succeed or fail together - let mut tx = self.pool.begin().await?; + let mut tx = begin(&self.pool).await?; // 1. Update old key with grace period - sqlx::query( + query( r#" UPDATE api_keys SET rotation_grace_until = ?, updated_at = datetime('now') @@ -933,7 +933,7 @@ impl ApiKeyRepo for SqliteApiKeyRepo { .await?; // 2. 
Insert new key with rotated_from_key_id - sqlx::query( + query( r#" INSERT INTO api_keys ( id, name, key_hash, key_prefix, owner_type, owner_id, @@ -979,12 +979,9 @@ impl ApiKeyRepo for SqliteApiKeyRepo { .bind(now) .execute(&mut *tx) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict("API key with this hash already exists".to_string()) - } - _ => DbError::from(e), - })?; + .map_err(map_unique_violation( + "API key with this hash already exists", + ))?; tx.commit().await?; @@ -1013,7 +1010,7 @@ impl ApiKeyRepo for SqliteApiKeyRepo { &self, service_account_id: Uuid, ) -> DbResult> { - let hashes: Vec = sqlx::query_scalar( + let hashes: Vec = query_scalar( r#" SELECT key_hash FROM api_keys @@ -1030,7 +1027,7 @@ impl ApiKeyRepo for SqliteApiKeyRepo { } async fn get_key_hashes_by_user(&self, user_id: Uuid) -> DbResult> { - let hashes: Vec = sqlx::query_scalar( + let hashes: Vec = query_scalar( r#" SELECT key_hash FROM api_keys @@ -1047,7 +1044,7 @@ impl ApiKeyRepo for SqliteApiKeyRepo { } async fn get_by_name_and_org(&self, org_id: Uuid, name: &str) -> DbResult> { - let row = sqlx::query( + let row = query( r#" SELECT id, key_prefix, name, owner_type, owner_id, budget_amount, budget_period, expires_at, last_used_at, created_at, revoked_at, @@ -1072,6 +1069,8 @@ impl ApiKeyRepo for SqliteApiKeyRepo { #[cfg(test)] mod tests { + use sqlx::SqlitePool; + use super::*; use crate::db::repos::ApiKeyRepo; diff --git a/src/db/sqlite/audit_logs.rs b/src/db/sqlite/audit_logs.rs index 79c1f55..77b4767 100644 --- a/src/db/sqlite/audit_logs.rs +++ b/src/db/sqlite/audit_logs.rs @@ -1,9 +1,11 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; +use super::{ + backend::{Pool, RowExt, query}, + common::parse_uuid, +}; use crate::{ db::{ error::DbResult, @@ -16,11 +18,11 @@ use crate::{ }; pub struct SqliteAuditLogRepo { - pool: 
SqlitePool, + pool: Pool, } impl SqliteAuditLogRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } @@ -30,7 +32,8 @@ impl SqliteAuditLogRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl AuditLogRepo for SqliteAuditLogRepo { async fn create(&self, input: CreateAuditLog) -> DbResult { let id = Uuid::new_v4(); @@ -38,7 +41,7 @@ impl AuditLogRepo for SqliteAuditLogRepo { let now = truncate_to_millis(chrono::Utc::now()); let details_json = serde_json::to_string(&input.details)?; - sqlx::query( + query( r#" INSERT INTO audit_logs ( id, timestamp, actor_type, actor_id, action, @@ -80,7 +83,7 @@ impl AuditLogRepo for SqliteAuditLogRepo { } async fn get_by_id(&self, id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, timestamp, actor_type, actor_id, action, resource_type, resource_id, org_id, project_id, @@ -95,43 +98,43 @@ impl AuditLogRepo for SqliteAuditLogRepo { match result { Some(row) => { - let actor_id: Option = row.get("actor_id"); - let org_id: Option = row.get("org_id"); - let project_id: Option = row.get("project_id"); - let details_str: String = row.get("details"); + let actor_id: Option = row.col("actor_id"); + let org_id: Option = row.col("org_id"); + let project_id: Option = row.col("project_id"); + let details_str: String = row.col("details"); Ok(Some(AuditLog { - id: parse_uuid(&row.get::("id"))?, - timestamp: row.get("timestamp"), - actor_type: Self::parse_actor_type(&row.get::("actor_type"))?, + id: parse_uuid(&row.col::("id"))?, + timestamp: row.col("timestamp"), + actor_type: Self::parse_actor_type(&row.col::("actor_type"))?, actor_id: actor_id.map(|s| parse_uuid(&s)).transpose()?, - action: row.get("action"), - resource_type: row.get("resource_type"), - resource_id: parse_uuid(&row.get::("resource_id"))?, + action: row.col("action"), + resource_type: row.col("resource_type"), 
+ resource_id: parse_uuid(&row.col::("resource_id"))?, org_id: org_id.map(|s| parse_uuid(&s)).transpose()?, project_id: project_id.map(|s| parse_uuid(&s)).transpose()?, details: serde_json::from_str(&details_str)?, - ip_address: row.get("ip_address"), - user_agent: row.get("user_agent"), + ip_address: row.col("ip_address"), + user_agent: row.col("user_agent"), })) } None => Ok(None), } } - async fn list(&self, query: AuditLogQuery) -> DbResult> { - let limit = query.limit.unwrap_or(100); + async fn list(&self, filter: AuditLogQuery) -> DbResult> { + let limit = filter.limit.unwrap_or(100); let fetch_limit = limit + 1; // Fetch one extra to determine if there are more items // Parse cursor if provided - let cursor = match &query.cursor { + let cursor = match &filter.cursor { Some(c) => Some(Cursor::decode(c).map_err(|e| { crate::db::error::DbError::Internal(format!("Invalid cursor: {}", e)) })?), None => None, }; - let direction = match query.direction.as_deref() { + let direction = match filter.direction.as_deref() { Some("backward") => CursorDirection::Backward, _ => CursorDirection::Forward, }; @@ -140,39 +143,39 @@ impl AuditLogRepo for SqliteAuditLogRepo { let mut conditions = Vec::new(); let mut params: Vec = Vec::new(); - if let Some(actor_type) = &query.actor_type { + if let Some(actor_type) = &filter.actor_type { conditions.push("actor_type = ?".to_string()); params.push(actor_type.to_string()); } - if let Some(actor_id) = &query.actor_id { + if let Some(actor_id) = &filter.actor_id { conditions.push("actor_id = ?".to_string()); params.push(actor_id.to_string()); } - if let Some(action) = &query.action { + if let Some(action) = &filter.action { conditions.push("action = ?".to_string()); params.push(action.clone()); } - if let Some(resource_type) = &query.resource_type { + if let Some(resource_type) = &filter.resource_type { conditions.push("resource_type = ?".to_string()); params.push(resource_type.clone()); } - if let Some(resource_id) = &query.resource_id 
{ + if let Some(resource_id) = &filter.resource_id { conditions.push("resource_id = ?".to_string()); params.push(resource_id.to_string()); } - if let Some(org_id) = &query.org_id { + if let Some(org_id) = &filter.org_id { conditions.push("org_id = ?".to_string()); params.push(org_id.to_string()); } - if let Some(project_id) = &query.project_id { + if let Some(project_id) = &filter.project_id { conditions.push("project_id = ?".to_string()); params.push(project_id.to_string()); } - if let Some(from) = &query.from { + if let Some(from) = &filter.from { conditions.push("timestamp >= ?".to_string()); params.push(from.to_rfc3339()); } - if let Some(to) = &query.to { + if let Some(to) = &filter.to { conditions.push("timestamp < ?".to_string()); params.push(to.to_rfc3339()); } @@ -234,7 +237,7 @@ impl AuditLogRepo for SqliteAuditLogRepo { ) }; - let mut query_builder = sqlx::query(&sql); + let mut query_builder = query(&sql); for param in ¶ms { query_builder = query_builder.bind(param); } @@ -251,24 +254,24 @@ impl AuditLogRepo for SqliteAuditLogRepo { .into_iter() .take(limit as usize) .map(|row| { - let actor_id: Option = row.get("actor_id"); - let org_id: Option = row.get("org_id"); - let project_id: Option = row.get("project_id"); - let details_str: String = row.get("details"); + let actor_id: Option = row.col("actor_id"); + let org_id: Option = row.col("org_id"); + let project_id: Option = row.col("project_id"); + let details_str: String = row.col("details"); Ok(AuditLog { - id: parse_uuid(&row.get::("id"))?, - timestamp: row.get("timestamp"), - actor_type: Self::parse_actor_type(&row.get::("actor_type"))?, + id: parse_uuid(&row.col::("id"))?, + timestamp: row.col("timestamp"), + actor_type: Self::parse_actor_type(&row.col::("actor_type"))?, actor_id: actor_id.map(|s| parse_uuid(&s)).transpose()?, - action: row.get("action"), - resource_type: row.get("resource_type"), - resource_id: parse_uuid(&row.get::("resource_id"))?, + action: row.col("action"), + resource_type: 
row.col("resource_type"), + resource_id: parse_uuid(&row.col::("resource_id"))?, org_id: org_id.map(|s| parse_uuid(&s)).transpose()?, project_id: project_id.map(|s| parse_uuid(&s)).transpose()?, details: serde_json::from_str(&details_str)?, - ip_address: row.get("ip_address"), - user_agent: row.get("user_agent"), + ip_address: row.col("ip_address"), + user_agent: row.col("user_agent"), }) }) .collect::>>()?; @@ -287,44 +290,44 @@ impl AuditLogRepo for SqliteAuditLogRepo { Ok(ListResult::new(items, has_more, cursors)) } - async fn count(&self, query: AuditLogQuery) -> DbResult { + async fn count(&self, filter: AuditLogQuery) -> DbResult { // Build dynamic WHERE clause let mut conditions = Vec::new(); let mut params: Vec = Vec::new(); - if let Some(actor_type) = &query.actor_type { + if let Some(actor_type) = &filter.actor_type { conditions.push("actor_type = ?"); params.push(actor_type.to_string()); } - if let Some(actor_id) = &query.actor_id { + if let Some(actor_id) = &filter.actor_id { conditions.push("actor_id = ?"); params.push(actor_id.to_string()); } - if let Some(action) = &query.action { + if let Some(action) = &filter.action { conditions.push("action = ?"); params.push(action.clone()); } - if let Some(resource_type) = &query.resource_type { + if let Some(resource_type) = &filter.resource_type { conditions.push("resource_type = ?"); params.push(resource_type.clone()); } - if let Some(resource_id) = &query.resource_id { + if let Some(resource_id) = &filter.resource_id { conditions.push("resource_id = ?"); params.push(resource_id.to_string()); } - if let Some(org_id) = &query.org_id { + if let Some(org_id) = &filter.org_id { conditions.push("org_id = ?"); params.push(org_id.to_string()); } - if let Some(project_id) = &query.project_id { + if let Some(project_id) = &filter.project_id { conditions.push("project_id = ?"); params.push(project_id.to_string()); } - if let Some(from) = &query.from { + if let Some(from) = &filter.from { conditions.push("timestamp >= 
?"); params.push(from.to_rfc3339()); } - if let Some(to) = &query.to { + if let Some(to) = &filter.to { conditions.push("timestamp < ?"); params.push(to.to_rfc3339()); } @@ -337,13 +340,13 @@ impl AuditLogRepo for SqliteAuditLogRepo { let sql = format!("SELECT COUNT(*) as count FROM audit_logs {}", where_clause); - let mut query_builder = sqlx::query(&sql); + let mut query_builder = query(&sql); for param in ¶ms { query_builder = query_builder.bind(param); } let row = query_builder.fetch_one(&self.pool).await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } // ==================== Retention Operations ==================== @@ -365,7 +368,7 @@ impl AuditLogRepo for SqliteAuditLogRepo { let limit = std::cmp::min(batch_size as u64, remaining) as i64; // Delete a batch using subquery to select IDs - let result = sqlx::query( + let result = query( r#" DELETE FROM audit_logs WHERE id IN ( @@ -396,6 +399,7 @@ impl AuditLogRepo for SqliteAuditLogRepo { mod tests { use chrono::Duration; use serde_json::json; + use sqlx::SqlitePool; use super::*; use crate::db::repos::AuditLogRepo; diff --git a/src/db/sqlite/backend.rs b/src/db/sqlite/backend.rs new file mode 100644 index 0000000..11e8d57 --- /dev/null +++ b/src/db/sqlite/backend.rs @@ -0,0 +1,239 @@ +//! Backend abstraction layer that allows the same repo code to work with both +//! native SQLite (via sqlx) and WASM SQLite (via wa-sqlite JS bridge). +//! +//! Provides cfg-switched type aliases, a unified row access trait, query +//! constructor, and error helpers. 
+ +use crate::db::error::DbError; + +// ───────────────────────────────────────────────────────────────────────────── +// Pool type alias +// ───────────────────────────────────────────────────────────────────────────── + +#[cfg(feature = "database-sqlite")] +pub(crate) type Pool = sqlx::SqlitePool; + +#[cfg(feature = "database-wasm-sqlite")] +pub(crate) type Pool = crate::db::wasm_sqlite::WasmSqlitePool; + +// ───────────────────────────────────────────────────────────────────────────── +// Row type alias +// ───────────────────────────────────────────────────────────────────────────── + +#[cfg(feature = "database-sqlite")] +pub(crate) type Row = sqlx::sqlite::SqliteRow; + +#[cfg(feature = "database-wasm-sqlite")] +pub(crate) type Row = crate::db::wasm_sqlite::WasmRow; + +// ───────────────────────────────────────────────────────────────────────────── +// Error type alias +// ───────────────────────────────────────────────────────────────────────────── + +#[cfg(feature = "database-sqlite")] +pub(crate) type BackendError = sqlx::Error; + +#[cfg(feature = "database-wasm-sqlite")] +pub(crate) type BackendError = crate::db::wasm_sqlite::WasmDbError; + +// ───────────────────────────────────────────────────────────────────────────── +// ColDecode — bridging sqlx and WASM decode traits +// ───────────────────────────────────────────────────────────────────────────── + +#[cfg(feature = "database-sqlite")] +pub(crate) trait ColDecode: + for<'r> sqlx::Decode<'r, sqlx::Sqlite> + sqlx::Type +{ +} + +#[cfg(feature = "database-sqlite")] +impl sqlx::Decode<'r, sqlx::Sqlite> + sqlx::Type> ColDecode for T {} + +#[cfg(feature = "database-wasm-sqlite")] +pub(crate) trait ColDecode: crate::db::wasm_sqlite::WasmDecode {} + +#[cfg(feature = "database-wasm-sqlite")] +impl ColDecode for T {} + +// ───────────────────────────────────────────────────────────────────────────── +// RowExt — unified row access with `.col::("name")` +// 
───────────────────────────────────────────────────────────────────────────── + +pub(crate) trait RowExt { + fn col(&self, name: &str) -> T; +} + +#[cfg(feature = "database-sqlite")] +impl RowExt for sqlx::sqlite::SqliteRow { + fn col(&self, name: &str) -> T { + use sqlx::Row; + self.get(name) + } +} + +#[cfg(feature = "database-wasm-sqlite")] +impl RowExt for crate::db::wasm_sqlite::WasmRow { + fn col(&self, name: &str) -> T { + self.get(name) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Query constructor +// ───────────────────────────────────────────────────────────────────────────── + +#[cfg(feature = "database-sqlite")] +pub(crate) fn query( + sql: &str, +) -> sqlx::query::Query<'_, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'_>> { + sqlx::query(sql) +} + +#[cfg(feature = "database-wasm-sqlite")] +pub(crate) fn query(sql: &str) -> crate::db::wasm_sqlite::WasmQuery { + crate::db::wasm_sqlite::query(sql) +} + +// ───────────────────────────────────────────────────────────────────────────── +// Scalar query constructor +// ───────────────────────────────────────────────────────────────────────────── + +/// Create a query that returns a single column, decoded as `T`. +/// Mirrors `sqlx::query_scalar()`. 
+#[cfg(feature = "database-sqlite")] +pub(crate) fn query_scalar( + sql: &str, +) -> sqlx::query::QueryScalar<'_, sqlx::Sqlite, T, sqlx::sqlite::SqliteArguments<'_>> +where + T: sqlx::Type + for<'r> sqlx::Decode<'r, sqlx::Sqlite>, + (T,): for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow>, +{ + sqlx::query_scalar(sql) +} + +#[cfg(feature = "database-wasm-sqlite")] +pub(crate) fn query_scalar( + sql: &str, +) -> crate::db::wasm_sqlite::WasmQueryScalar { + crate::db::wasm_sqlite::WasmQueryScalar::new(sql) +} + +// ───────────────────────────────────────────────────────────────────────────── +// Transaction +// ───────────────────────────────────────────────────────────────────────────── + +/// A database transaction that works for both native and WASM SQLite. +/// +/// Native: wraps `sqlx::Transaction<'a, Sqlite>` with auto-rollback on drop. +/// WASM: sends BEGIN/COMMIT/ROLLBACK through the JS bridge. +/// +/// Use `begin(&pool)` to start, then `query(sql).execute(&mut *tx)` for +/// queries, and `tx.commit()` to finalize. +#[cfg(feature = "database-sqlite")] +pub(crate) struct Transaction<'a>(sqlx::Transaction<'a, sqlx::Sqlite>); + +#[cfg(feature = "database-wasm-sqlite")] +pub(crate) struct Transaction<'a>(Pool, std::marker::PhantomData<&'a ()>); + +/// Begin a new transaction. 
+#[cfg(feature = "database-sqlite")] +pub(crate) async fn begin(pool: &Pool) -> Result, BackendError> { + Ok(Transaction(pool.begin().await?)) +} + +#[cfg(feature = "database-wasm-sqlite")] +pub(crate) async fn begin(pool: &Pool) -> Result, BackendError> { + query("BEGIN").execute(pool).await?; + Ok(Transaction(pool.clone(), std::marker::PhantomData)) +} + +#[cfg(feature = "database-sqlite")] +impl Transaction<'_> { + pub async fn commit(self) -> Result<(), BackendError> { + self.0.commit().await + } +} + +#[cfg(feature = "database-wasm-sqlite")] +impl Transaction<'_> { + pub async fn commit(self) -> Result<(), BackendError> { + query("COMMIT").execute(&self.0).await?; + Ok(()) + } +} + +#[cfg(feature = "database-sqlite")] +impl std::ops::Deref for Transaction<'_> { + type Target = sqlx::SqliteConnection; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[cfg(feature = "database-sqlite")] +impl std::ops::DerefMut for Transaction<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +#[cfg(feature = "database-wasm-sqlite")] +impl std::ops::Deref for Transaction<'_> { + type Target = crate::db::wasm_sqlite::WasmSqlitePool; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[cfg(feature = "database-wasm-sqlite")] +impl std::ops::DerefMut for Transaction<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Error helpers +// ───────────────────────────────────────────────────────────────────────────── + +/// Check if a backend error is a unique constraint violation. 
+#[cfg(feature = "database-sqlite")] +pub(crate) fn is_unique_violation(e: &BackendError) -> bool { + matches!(e, sqlx::Error::Database(db_err) if db_err.is_unique_violation()) +} + +#[cfg(feature = "database-wasm-sqlite")] +pub(crate) fn is_unique_violation(e: &BackendError) -> bool { + e.is_unique_violation() +} + +/// Map a backend error to `DbError`, converting unique violations to `DbError::Conflict`. +pub(crate) fn map_unique_violation(msg: impl Into) -> impl FnOnce(BackendError) -> DbError { + let msg = msg.into(); + move |e: BackendError| { + if is_unique_violation(&e) { + DbError::Conflict(msg) + } else { + DbError::from(e) + } + } +} + +/// Extract the error message from a unique violation, or `None` if not a unique violation. +#[cfg(feature = "database-sqlite")] +pub(crate) fn unique_violation_message(e: &BackendError) -> Option { + match e { + sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { + Some(db_err.message().to_string()) + } + _ => None, + } +} + +#[cfg(feature = "database-wasm-sqlite")] +pub(crate) fn unique_violation_message(e: &BackendError) -> Option { + match e { + crate::db::wasm_sqlite::WasmDbError::UniqueViolation(msg) => Some(msg.clone()), + _ => None, + } +} diff --git a/src/db/sqlite/conversations.rs b/src/db/sqlite/conversations.rs index 3873aec..811cdeb 100644 --- a/src/db/sqlite/conversations.rs +++ b/src/db/sqlite/conversations.rs @@ -1,9 +1,11 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; +use super::{ + backend::{Pool, RowExt, query}, + common::parse_uuid, +}; use crate::{ db::{ error::{DbError, DbResult}, @@ -16,11 +18,11 @@ use crate::{ }; pub struct SqliteConversationRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteConversationRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } @@ -61,7 +63,7 @@ impl SqliteConversationRepo { "AND deleted_at IS NULL" }; - let query = 
format!( + let sql = format!( r#" SELECT id, owner_type, owner_id, title, models, messages, pin_order, created_at, updated_at FROM conversations @@ -74,7 +76,7 @@ impl SqliteConversationRepo { comparison, deleted_filter, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(owner_type.as_str()) .bind(owner_id.to_string()) .bind(cursor.created_at) // cursor.created_at holds updated_at value @@ -88,22 +90,22 @@ impl SqliteConversationRepo { .into_iter() .take(limit as usize) .map(|row| { - let owner_type_str: String = row.get("owner_type"); - let models_json: String = row.get("models"); - let messages_json: String = row.get("messages"); + let owner_type_str: String = row.col("owner_type"); + let models_json: String = row.col("models"); + let messages_json: String = row.col("messages"); Ok(Conversation { - id: parse_uuid(&row.get::("id"))?, + id: parse_uuid(&row.col::("id"))?, owner_type: owner_type_str .parse() .map_err(|e: String| DbError::Internal(e))?, - owner_id: parse_uuid(&row.get::("owner_id"))?, - title: row.get("title"), + owner_id: parse_uuid(&row.col::("owner_id"))?, + title: row.col("title"), models: Self::parse_models(&models_json)?, messages: Self::parse_messages(&messages_json)?, - pin_order: row.get("pin_order"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + pin_order: row.col("pin_order"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) }) .collect::>>()?; @@ -125,7 +127,8 @@ impl SqliteConversationRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl ConversationRepo for SqliteConversationRepo { async fn create(&self, input: CreateConversation) -> DbResult { let id = Uuid::new_v4(); @@ -137,7 +140,7 @@ impl ConversationRepo for SqliteConversationRepo { let messages_json = serde_json::to_string(&input.messages).map_err(|e| DbError::Internal(e.to_string()))?; - sqlx::query( + 
query( r#" INSERT INTO conversations (id, owner_type, owner_id, title, models, messages, pin_order, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, NULL, ?, ?) @@ -168,7 +171,7 @@ impl ConversationRepo for SqliteConversationRepo { } async fn get_by_id(&self, id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, owner_type, owner_id, title, models, messages, pin_order, created_at, updated_at FROM conversations @@ -181,22 +184,22 @@ impl ConversationRepo for SqliteConversationRepo { match result { Some(row) => { - let owner_type_str: String = row.get("owner_type"); - let models_json: String = row.get("models"); - let messages_json: String = row.get("messages"); + let owner_type_str: String = row.col("owner_type"); + let models_json: String = row.col("models"); + let messages_json: String = row.col("messages"); Ok(Some(Conversation { - id: parse_uuid(&row.get::("id"))?, + id: parse_uuid(&row.col::("id"))?, owner_type: owner_type_str .parse() .map_err(|e: String| DbError::Internal(e))?, - owner_id: parse_uuid(&row.get::("owner_id"))?, - title: row.get("title"), + owner_id: parse_uuid(&row.col::("owner_id"))?, + title: row.col("title"), models: Self::parse_models(&models_json)?, messages: Self::parse_messages(&messages_json)?, - pin_order: row.get("pin_order"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + pin_order: row.col("pin_order"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), })) } None => Ok(None), @@ -204,7 +207,7 @@ impl ConversationRepo for SqliteConversationRepo { } async fn get_by_id_and_org(&self, id: Uuid, org_id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT c.id, c.owner_type, c.owner_id, c.title, c.models, c.messages, c.pin_order, c.created_at, c.updated_at FROM conversations c @@ -228,22 +231,22 @@ impl ConversationRepo for SqliteConversationRepo { match result { Some(row) => { - let owner_type_str: String = 
row.get("owner_type"); - let models_json: String = row.get("models"); - let messages_json: String = row.get("messages"); + let owner_type_str: String = row.col("owner_type"); + let models_json: String = row.col("models"); + let messages_json: String = row.col("messages"); Ok(Some(Conversation { - id: parse_uuid(&row.get::("id"))?, + id: parse_uuid(&row.col::("id"))?, owner_type: owner_type_str .parse() .map_err(|e: String| DbError::Internal(e))?, - owner_id: parse_uuid(&row.get::("owner_id"))?, - title: row.get("title"), + owner_id: parse_uuid(&row.col::("owner_id"))?, + title: row.col("title"), models: Self::parse_models(&models_json)?, messages: Self::parse_messages(&messages_json)?, - pin_order: row.get("pin_order"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + pin_order: row.col("pin_order"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), })) } None => Ok(None), @@ -268,7 +271,7 @@ impl ConversationRepo for SqliteConversationRepo { } // First page (no cursor provided) - let query = if params.include_deleted { + let sql = if params.include_deleted { r#" SELECT id, owner_type, owner_id, title, models, messages, pin_order, created_at, updated_at FROM conversations @@ -286,7 +289,7 @@ impl ConversationRepo for SqliteConversationRepo { "# }; - let rows = sqlx::query(query) + let rows = query(sql) .bind(owner_type.as_str()) .bind(owner_id.to_string()) .bind(fetch_limit) @@ -298,22 +301,22 @@ impl ConversationRepo for SqliteConversationRepo { .into_iter() .take(limit as usize) .map(|row| { - let owner_type_str: String = row.get("owner_type"); - let models_json: String = row.get("models"); - let messages_json: String = row.get("messages"); + let owner_type_str: String = row.col("owner_type"); + let models_json: String = row.col("models"); + let messages_json: String = row.col("messages"); Ok(Conversation { - id: parse_uuid(&row.get::("id"))?, + id: parse_uuid(&row.col::("id"))?, owner_type: owner_type_str 
.parse() .map_err(|e: String| DbError::Internal(e))?, - owner_id: parse_uuid(&row.get::("owner_id"))?, - title: row.get("title"), + owner_id: parse_uuid(&row.col::("owner_id"))?, + title: row.col("title"), models: Self::parse_models(&models_json)?, messages: Self::parse_messages(&messages_json)?, - pin_order: row.get("pin_order"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + pin_order: row.col("pin_order"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) }) .collect::>>()?; @@ -336,18 +339,18 @@ impl ConversationRepo for SqliteConversationRepo { owner_id: Uuid, include_deleted: bool, ) -> DbResult { - let query = if include_deleted { + let sql = if include_deleted { "SELECT COUNT(*) as count FROM conversations WHERE owner_type = ? AND owner_id = ?" } else { "SELECT COUNT(*) as count FROM conversations WHERE owner_type = ? AND owner_id = ? AND deleted_at IS NULL" }; - let row = sqlx::query(query) + let row = query(sql) .bind(owner_type.as_str()) .bind(owner_id.to_string()) .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn update(&self, id: Uuid, input: UpdateConversation) -> DbResult { @@ -358,11 +361,11 @@ impl ConversationRepo for SqliteConversationRepo { // other writers until this transaction completes. // Note: SQLite doesn't support FOR UPDATE, so we use BEGIN IMMEDIATE instead. let mut conn = self.pool.acquire().await?; - sqlx::query("BEGIN IMMEDIATE").execute(&mut *conn).await?; + query("BEGIN IMMEDIATE").execute(&mut *conn).await?; let result = async { // Read current state within transaction (with write lock held) - let current_row = sqlx::query( + let current_row = query( r#" SELECT id, owner_type, owner_id, title, models, messages, pin_order, created_at, updated_at FROM conversations @@ -374,16 +377,16 @@ impl ConversationRepo for SqliteConversationRepo { .await? 
.ok_or(DbError::NotFound)?; - let current_owner_type_str: String = current_row.get("owner_type"); + let current_owner_type_str: String = current_row.col("owner_type"); let current_owner_type: ConversationOwnerType = current_owner_type_str .parse() .map_err(|e: String| DbError::Internal(e))?; - let current_owner_id = parse_uuid(¤t_row.get::("owner_id"))?; - let current_title: String = current_row.get("title"); - let current_models_json: String = current_row.get("models"); - let current_messages_json: String = current_row.get("messages"); - let pin_order: Option = current_row.get("pin_order"); - let created_at = current_row.get("created_at"); + let current_owner_id = parse_uuid(¤t_row.col::("owner_id"))?; + let current_title: String = current_row.col("title"); + let current_models_json: String = current_row.col("models"); + let current_messages_json: String = current_row.col("messages"); + let pin_order: Option = current_row.col("pin_order"); + let created_at = current_row.col("created_at"); // Determine new owner (if provided) or keep current let (new_owner_type, new_owner_id) = if let Some(ref owner) = input.owner { @@ -404,7 +407,7 @@ impl ConversationRepo for SqliteConversationRepo { let messages_json = serde_json::to_string(&new_messages) .map_err(|e| DbError::Internal(e.to_string()))?; - let update_result = sqlx::query( + let update_result = query( r#" UPDATE conversations SET owner_type = ?, owner_id = ?, title = ?, models = ?, messages = ?, updated_at = ? @@ -442,10 +445,10 @@ impl ConversationRepo for SqliteConversationRepo { // Commit or rollback based on result match &result { Ok(_) => { - sqlx::query("COMMIT").execute(&mut *conn).await?; + query("COMMIT").execute(&mut *conn).await?; } Err(_) => { - let _ = sqlx::query("ROLLBACK").execute(&mut *conn).await; + let _ = query("ROLLBACK").execute(&mut *conn).await; } } @@ -460,11 +463,11 @@ impl ConversationRepo for SqliteConversationRepo { // other writers until this transaction completes. 
// Note: SQLite doesn't support FOR UPDATE, so we use BEGIN IMMEDIATE instead. let mut conn = self.pool.acquire().await?; - sqlx::query("BEGIN IMMEDIATE").execute(&mut *conn).await?; + query("BEGIN IMMEDIATE").execute(&mut *conn).await?; let result = async { // Get current messages within transaction (with write lock held) - let current_row = sqlx::query( + let current_row = query( r#" SELECT messages FROM conversations @@ -476,7 +479,7 @@ impl ConversationRepo for SqliteConversationRepo { .await? .ok_or(DbError::NotFound)?; - let current_messages_json: String = current_row.get("messages"); + let current_messages_json: String = current_row.col("messages"); let mut messages = Self::parse_messages(¤t_messages_json)?; // Append new messages @@ -485,7 +488,7 @@ impl ConversationRepo for SqliteConversationRepo { let messages_json = serde_json::to_string(&messages).map_err(|e| DbError::Internal(e.to_string()))?; - let update_result = sqlx::query( + let update_result = query( r#" UPDATE conversations SET messages = ?, updated_at = ? @@ -509,10 +512,10 @@ impl ConversationRepo for SqliteConversationRepo { // Commit or rollback based on result match &result { Ok(_) => { - sqlx::query("COMMIT").execute(&mut *conn).await?; + query("COMMIT").execute(&mut *conn).await?; } Err(_) => { - let _ = sqlx::query("ROLLBACK").execute(&mut *conn).await; + let _ = query("ROLLBACK").execute(&mut *conn).await; } } @@ -522,7 +525,7 @@ impl ConversationRepo for SqliteConversationRepo { async fn delete(&self, id: Uuid) -> DbResult<()> { let now = chrono::Utc::now(); - let result = sqlx::query( + let result = query( r#" UPDATE conversations SET deleted_at = ? 
@@ -559,7 +562,7 @@ impl ConversationRepo for SqliteConversationRepo { "AND c.deleted_at IS NULL" }; - let query = format!( + let sql = format!( r#" SELECT c.id, @@ -604,7 +607,7 @@ impl ConversationRepo for SqliteConversationRepo { "# ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(user_id.to_string()) .bind(user_id.to_string()) .bind(limit) @@ -614,26 +617,26 @@ impl ConversationRepo for SqliteConversationRepo { let items: Vec = rows .into_iter() .map(|row| { - let owner_type_str: String = row.get("owner_type"); - let models_json: String = row.get("models"); - let messages_json: String = row.get("messages"); - let project_id: Option = row.get("project_id"); - let project_name: Option = row.get("project_name"); - let project_slug: Option = row.get("project_slug"); + let owner_type_str: String = row.col("owner_type"); + let models_json: String = row.col("models"); + let messages_json: String = row.col("messages"); + let project_id: Option = row.col("project_id"); + let project_name: Option = row.col("project_name"); + let project_slug: Option = row.col("project_slug"); Ok(ConversationWithProject { conversation: Conversation { - id: parse_uuid(&row.get::("id"))?, + id: parse_uuid(&row.col::("id"))?, owner_type: owner_type_str .parse() .map_err(|e: String| DbError::Internal(e))?, - owner_id: parse_uuid(&row.get::("owner_id"))?, - title: row.get("title"), + owner_id: parse_uuid(&row.col::("owner_id"))?, + title: row.col("title"), models: Self::parse_models(&models_json)?, messages: Self::parse_messages(&messages_json)?, - pin_order: row.get("pin_order"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + pin_order: row.col("pin_order"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }, project_id: project_id.map(|s| parse_uuid(&s)).transpose()?, project_name, @@ -650,11 +653,11 @@ impl ConversationRepo for SqliteConversationRepo { // Use IMMEDIATE transaction mode to acquire write lock let mut conn 
= self.pool.acquire().await?; - sqlx::query("BEGIN IMMEDIATE").execute(&mut *conn).await?; + query("BEGIN IMMEDIATE").execute(&mut *conn).await?; let result = async { // Read current state within transaction (with write lock held) - let current_row = sqlx::query( + let current_row = query( r#" SELECT id, owner_type, owner_id, title, models, messages, pin_order, created_at, updated_at FROM conversations @@ -666,17 +669,17 @@ impl ConversationRepo for SqliteConversationRepo { .await? .ok_or(DbError::NotFound)?; - let owner_type_str: String = current_row.get("owner_type"); + let owner_type_str: String = current_row.col("owner_type"); let owner_type: ConversationOwnerType = owner_type_str .parse() .map_err(|e: String| DbError::Internal(e))?; - let owner_id = parse_uuid(¤t_row.get::("owner_id"))?; - let title: String = current_row.get("title"); - let models_json: String = current_row.get("models"); - let messages_json: String = current_row.get("messages"); - let created_at = current_row.get("created_at"); + let owner_id = parse_uuid(¤t_row.col::("owner_id"))?; + let title: String = current_row.col("title"); + let models_json: String = current_row.col("models"); + let messages_json: String = current_row.col("messages"); + let created_at = current_row.col("created_at"); - let update_result = sqlx::query( + let update_result = query( r#" UPDATE conversations SET pin_order = ?, updated_at = ? 
@@ -710,10 +713,10 @@ impl ConversationRepo for SqliteConversationRepo { // Commit or rollback based on result match &result { Ok(_) => { - sqlx::query("COMMIT").execute(&mut *conn).await?; + query("COMMIT").execute(&mut *conn).await?; } Err(_) => { - let _ = sqlx::query("ROLLBACK").execute(&mut *conn).await; + let _ = query("ROLLBACK").execute(&mut *conn).await; } } @@ -739,7 +742,7 @@ impl ConversationRepo for SqliteConversationRepo { let limit = std::cmp::min(batch_size as u64, remaining) as i64; // Hard delete conversations that were soft-deleted before the cutoff - let result = sqlx::query( + let result = query( r#" DELETE FROM conversations WHERE id IN ( @@ -768,6 +771,8 @@ impl ConversationRepo for SqliteConversationRepo { #[cfg(test)] mod tests { + use sqlx::SqlitePool; + use super::*; use crate::{ db::repos::ConversationRepo, diff --git a/src/db/sqlite/domain_verifications.rs b/src/db/sqlite/domain_verifications.rs index 64920fd..c099d41 100644 --- a/src/db/sqlite/domain_verifications.rs +++ b/src/db/sqlite/domain_verifications.rs @@ -1,8 +1,10 @@ use async_trait::async_trait; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; +use super::{ + backend::{Pool, Row, RowExt, query}, + common::parse_uuid, +}; use crate::{ db::{ error::{DbError, DbResult}, @@ -15,39 +17,40 @@ use crate::{ }; pub struct SqliteDomainVerificationRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteDomainVerificationRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } /// Parse a DomainVerification from a database row. 
- fn parse_verification(row: &sqlx::sqlite::SqliteRow) -> DbResult { - let status_str: String = row.get("status"); + fn parse_verification(row: &Row) -> DbResult { + let status_str: String = row.col("status"); let status = status_str .parse::() .unwrap_or_default(); Ok(DomainVerification { - id: parse_uuid(&row.get::("id"))?, - org_sso_config_id: parse_uuid(&row.get::("org_sso_config_id"))?, - domain: row.get("domain"), - verification_token: row.get("verification_token"), + id: parse_uuid(&row.col::("id"))?, + org_sso_config_id: parse_uuid(&row.col::("org_sso_config_id"))?, + domain: row.col("domain"), + verification_token: row.col("verification_token"), status, - dns_txt_record: row.get("dns_txt_record"), - verification_attempts: row.get("verification_attempts"), - last_attempt_at: row.get("last_attempt_at"), - verified_at: row.get("verified_at"), - expires_at: row.get("expires_at"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + dns_txt_record: row.col("dns_txt_record"), + verification_attempts: row.col("verification_attempts"), + last_attempt_at: row.col("last_attempt_at"), + verified_at: row.col("verified_at"), + expires_at: row.col("expires_at"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl DomainVerificationRepo for SqliteDomainVerificationRepo { async fn create( &self, @@ -58,7 +61,7 @@ impl DomainVerificationRepo for SqliteDomainVerificationRepo { let id = Uuid::new_v4(); let now = chrono::Utc::now(); - sqlx::query( + query( r#" INSERT INTO domain_verifications ( id, org_sso_config_id, domain, verification_token, status, @@ -75,15 +78,18 @@ impl DomainVerificationRepo for SqliteDomainVerificationRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict(format!( + 
.map_err(|e| { + if super::backend::is_unique_violation(&e) { + return DbError::Conflict(format!( "Domain '{}' already exists for this SSO configuration", input.domain - )) + )); + } + #[cfg(feature = "database-sqlite")] + if matches!(&e, sqlx::Error::Database(db_err) if db_err.is_foreign_key_violation()) { + return DbError::NotFound; } - sqlx::Error::Database(db_err) if db_err.is_foreign_key_violation() => DbError::NotFound, - _ => DbError::from(e), + DbError::from(e) })?; Ok(DomainVerification { @@ -103,7 +109,7 @@ impl DomainVerificationRepo for SqliteDomainVerificationRepo { } async fn get_by_id(&self, id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, org_sso_config_id, domain, verification_token, status, dns_txt_record, verification_attempts, last_attempt_at, verified_at, @@ -127,7 +133,7 @@ impl DomainVerificationRepo for SqliteDomainVerificationRepo { org_sso_config_id: Uuid, domain: &str, ) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, org_sso_config_id, domain, verification_token, status, dns_txt_record, verification_attempts, last_attempt_at, verified_at, @@ -154,7 +160,7 @@ impl DomainVerificationRepo for SqliteDomainVerificationRepo { ) -> DbResult> { let limit = params.limit.unwrap_or(100); - let rows = sqlx::query( + let rows = query( r#" SELECT id, org_sso_config_id, domain, verification_token, status, dns_txt_record, verification_attempts, last_attempt_at, verified_at, @@ -176,14 +182,13 @@ impl DomainVerificationRepo for SqliteDomainVerificationRepo { } async fn count_by_config(&self, org_sso_config_id: Uuid) -> DbResult { - let row = sqlx::query( - "SELECT COUNT(*) as count FROM domain_verifications WHERE org_sso_config_id = ?", - ) - .bind(org_sso_config_id.to_string()) - .fetch_one(&self.pool) - .await?; + let row = + query("SELECT COUNT(*) as count FROM domain_verifications WHERE org_sso_config_id = ?") + .bind(org_sso_config_id.to_string()) + .fetch_one(&self.pool) + 
.await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn update( @@ -196,7 +201,7 @@ impl DomainVerificationRepo for SqliteDomainVerificationRepo { // Fetch existing record let existing = self.get_by_id(id).await?.ok_or(DbError::NotFound)?; - sqlx::query( + query( r#" UPDATE domain_verifications SET status = ?, @@ -228,7 +233,7 @@ impl DomainVerificationRepo for SqliteDomainVerificationRepo { } async fn delete(&self, id: Uuid) -> DbResult<()> { - let result = sqlx::query("DELETE FROM domain_verifications WHERE id = ?") + let result = query("DELETE FROM domain_verifications WHERE id = ?") .bind(id.to_string()) .execute(&self.pool) .await?; @@ -241,7 +246,7 @@ impl DomainVerificationRepo for SqliteDomainVerificationRepo { } async fn find_verified_by_domain(&self, domain: &str) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT dv.id, dv.org_sso_config_id, dv.domain, dv.verification_token, dv.status, dv.dns_txt_record, dv.verification_attempts, dv.last_attempt_at, dv.verified_at, @@ -269,7 +274,7 @@ impl DomainVerificationRepo for SqliteDomainVerificationRepo { &self, org_sso_config_id: Uuid, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT id, org_sso_config_id, domain, verification_token, status, dns_txt_record, verification_attempts, last_attempt_at, verified_at, @@ -291,7 +296,7 @@ impl DomainVerificationRepo for SqliteDomainVerificationRepo { } async fn has_verified_domain(&self, org_sso_config_id: Uuid) -> DbResult { - let row = sqlx::query( + let row = query( r#" SELECT EXISTS( SELECT 1 FROM domain_verifications @@ -306,7 +311,7 @@ impl DomainVerificationRepo for SqliteDomainVerificationRepo { .await?; // SQLite returns 0 or 1 for EXISTS - Ok(row.get::("has_verified") != 0) + Ok(row.col::("has_verified") != 0) } async fn create_auto_verified( @@ -318,7 +323,7 @@ impl DomainVerificationRepo for SqliteDomainVerificationRepo { let id = Uuid::new_v4(); let now = chrono::Utc::now(); - sqlx::query( + 
query( r#" INSERT INTO domain_verifications ( id, org_sso_config_id, domain, verification_token, status, @@ -336,15 +341,18 @@ impl DomainVerificationRepo for SqliteDomainVerificationRepo { .bind(now) // updated_at .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict(format!( + .map_err(|e| { + if super::backend::is_unique_violation(&e) { + return DbError::Conflict(format!( "Domain '{}' already exists for this SSO configuration", input.domain - )) + )); } - sqlx::Error::Database(db_err) if db_err.is_foreign_key_violation() => DbError::NotFound, - _ => DbError::from(e), + #[cfg(feature = "database-sqlite")] + if matches!(&e, sqlx::Error::Database(db_err) if db_err.is_foreign_key_violation()) { + return DbError::NotFound; + } + DbError::from(e) })?; Ok(DomainVerification { @@ -366,6 +374,8 @@ impl DomainVerificationRepo for SqliteDomainVerificationRepo { #[cfg(test)] mod tests { + use sqlx::SqlitePool; + use super::*; async fn create_test_pool() -> SqlitePool { diff --git a/src/db/sqlite/files.rs b/src/db/sqlite/files.rs index 51821b2..a783402 100644 --- a/src/db/sqlite/files.rs +++ b/src/db/sqlite/files.rs @@ -1,8 +1,10 @@ use async_trait::async_trait; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; +use super::{ + backend::{Pool, RowExt, query}, + common::parse_uuid, +}; use crate::{ db::{ error::{DbError, DbResult}, @@ -12,11 +14,11 @@ use crate::{ }; pub struct SqliteFilesRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteFilesRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } @@ -25,13 +27,14 @@ impl SqliteFilesRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl FilesRepo for SqliteFilesRepo { async fn create_file(&self, input: CreateFile) -> DbResult { let id = Uuid::new_v4(); let now = 
chrono::Utc::now(); - sqlx::query( + query( r#" INSERT INTO files (id, owner_type, owner_id, filename, purpose, content_type, size_bytes, status, content_hash, storage_backend, file_data, storage_path, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) @@ -73,7 +76,7 @@ impl FilesRepo for SqliteFilesRepo { } async fn get_file(&self, id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, owner_type, owner_id, filename, purpose, content_type, size_bytes, status, status_details, content_hash, storage_backend, storage_path, created_at, expires_at @@ -87,35 +90,35 @@ impl FilesRepo for SqliteFilesRepo { match result { Some(row) => { - let owner_type_str: String = row.get("owner_type"); - let purpose_str: String = row.get("purpose"); - let status_str: String = row.get("status"); - let storage_backend_str: String = row.get("storage_backend"); + let owner_type_str: String = row.col("owner_type"); + let purpose_str: String = row.col("purpose"); + let status_str: String = row.col("status"); + let storage_backend_str: String = row.col("storage_backend"); Ok(Some(File { - id: parse_uuid(&row.get::("id"))?, + id: parse_uuid(&row.col::("id"))?, object: OBJECT_TYPE_FILE.to_string(), owner_type: owner_type_str .parse() .map_err(|e: String| DbError::Internal(e))?, - owner_id: parse_uuid(&row.get::("owner_id"))?, - filename: row.get("filename"), + owner_id: parse_uuid(&row.col::("owner_id"))?, + filename: row.col("filename"), purpose: purpose_str .parse() .map_err(|e: String| DbError::Internal(e))?, - content_type: row.get("content_type"), - size_bytes: row.get("size_bytes"), + content_type: row.col("content_type"), + size_bytes: row.col("size_bytes"), status: status_str .parse() .map_err(|e: String| DbError::Internal(e))?, - status_details: row.get("status_details"), - content_hash: row.get("content_hash"), + status_details: row.col("status_details"), + content_hash: row.col("content_hash"), storage_backend: storage_backend_str .parse() 
.map_err(|e: String| DbError::Internal(e))?, - storage_path: row.get("storage_path"), - created_at: row.get("created_at"), - expires_at: row.get("expires_at"), + storage_path: row.col("storage_path"), + created_at: row.col("created_at"), + expires_at: row.col("expires_at"), })) } None => Ok(None), @@ -123,7 +126,7 @@ impl FilesRepo for SqliteFilesRepo { } async fn get_file_data(&self, id: Uuid) -> DbResult>> { - let result = sqlx::query( + let result = query( r#" SELECT file_data FROM files @@ -134,7 +137,7 @@ impl FilesRepo for SqliteFilesRepo { .fetch_optional(&self.pool) .await?; - Ok(result.and_then(|row| row.get::>, _>("file_data"))) + Ok(result.and_then(|row| row.col::>>("file_data"))) } async fn list_files( @@ -152,7 +155,7 @@ impl FilesRepo for SqliteFilesRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let (query, bind_purpose) = match &purpose { + let (sql, bind_purpose) = match &purpose { Some(p) => ( format!( r#" @@ -186,7 +189,7 @@ impl FilesRepo for SqliteFilesRepo { }; let rows = if let Some(purpose_str) = bind_purpose { - sqlx::query(&query) + query(&sql) .bind(owner_type.as_str()) .bind(owner_id.to_string()) .bind(purpose_str) @@ -196,7 +199,7 @@ impl FilesRepo for SqliteFilesRepo { .fetch_all(&self.pool) .await? 
} else { - sqlx::query(&query) + query(&sql) .bind(owner_type.as_str()) .bind(owner_id.to_string()) .bind(cursor.created_at) @@ -211,35 +214,35 @@ impl FilesRepo for SqliteFilesRepo { .into_iter() .take(limit as usize) .map(|row| { - let owner_type_str: String = row.get("owner_type"); - let purpose_str: String = row.get("purpose"); - let status_str: String = row.get("status"); - let storage_backend_str: String = row.get("storage_backend"); + let owner_type_str: String = row.col("owner_type"); + let purpose_str: String = row.col("purpose"); + let status_str: String = row.col("status"); + let storage_backend_str: String = row.col("storage_backend"); Ok(File { - id: parse_uuid(&row.get::("id"))?, + id: parse_uuid(&row.col::("id"))?, object: OBJECT_TYPE_FILE.to_string(), owner_type: owner_type_str .parse() .map_err(|e: String| DbError::Internal(e))?, - owner_id: parse_uuid(&row.get::("owner_id"))?, - filename: row.get("filename"), + owner_id: parse_uuid(&row.col::("owner_id"))?, + filename: row.col("filename"), purpose: purpose_str .parse() .map_err(|e: String| DbError::Internal(e))?, - content_type: row.get("content_type"), - size_bytes: row.get("size_bytes"), + content_type: row.col("content_type"), + size_bytes: row.col("size_bytes"), status: status_str .parse() .map_err(|e: String| DbError::Internal(e))?, - status_details: row.get("status_details"), - content_hash: row.get("content_hash"), + status_details: row.col("status_details"), + content_hash: row.col("content_hash"), storage_backend: storage_backend_str .parse() .map_err(|e: String| DbError::Internal(e))?, - storage_path: row.get("storage_path"), - created_at: row.get("created_at"), - expires_at: row.get("expires_at"), + storage_path: row.col("storage_path"), + created_at: row.col("created_at"), + expires_at: row.col("expires_at"), }) }) .collect::>>()?; @@ -261,7 +264,7 @@ impl FilesRepo for SqliteFilesRepo { // First page (no cursor) let order = params.sort_order.as_sql(); - let (query, bind_purpose) = 
match &purpose { + let (sql, bind_purpose) = match &purpose { Some(p) => ( format!( r#" @@ -293,7 +296,7 @@ impl FilesRepo for SqliteFilesRepo { }; let rows = if let Some(purpose_str) = bind_purpose { - sqlx::query(&query) + query(&sql) .bind(owner_type.as_str()) .bind(owner_id.to_string()) .bind(purpose_str) @@ -301,7 +304,7 @@ impl FilesRepo for SqliteFilesRepo { .fetch_all(&self.pool) .await? } else { - sqlx::query(&query) + query(&sql) .bind(owner_type.as_str()) .bind(owner_id.to_string()) .bind(fetch_limit) @@ -314,35 +317,35 @@ impl FilesRepo for SqliteFilesRepo { .into_iter() .take(limit as usize) .map(|row| { - let owner_type_str: String = row.get("owner_type"); - let purpose_str: String = row.get("purpose"); - let status_str: String = row.get("status"); - let storage_backend_str: String = row.get("storage_backend"); + let owner_type_str: String = row.col("owner_type"); + let purpose_str: String = row.col("purpose"); + let status_str: String = row.col("status"); + let storage_backend_str: String = row.col("storage_backend"); Ok(File { - id: parse_uuid(&row.get::("id"))?, + id: parse_uuid(&row.col::("id"))?, object: OBJECT_TYPE_FILE.to_string(), owner_type: owner_type_str .parse() .map_err(|e: String| DbError::Internal(e))?, - owner_id: parse_uuid(&row.get::("owner_id"))?, - filename: row.get("filename"), + owner_id: parse_uuid(&row.col::("owner_id"))?, + filename: row.col("filename"), purpose: purpose_str .parse() .map_err(|e: String| DbError::Internal(e))?, - content_type: row.get("content_type"), - size_bytes: row.get("size_bytes"), + content_type: row.col("content_type"), + size_bytes: row.col("size_bytes"), status: status_str .parse() .map_err(|e: String| DbError::Internal(e))?, - status_details: row.get("status_details"), - content_hash: row.get("content_hash"), + status_details: row.col("status_details"), + content_hash: row.col("content_hash"), storage_backend: storage_backend_str .parse() .map_err(|e: String| DbError::Internal(e))?, - storage_path: 
row.get("storage_path"), - created_at: row.get("created_at"), - expires_at: row.get("expires_at"), + storage_path: row.col("storage_path"), + created_at: row.col("created_at"), + expires_at: row.col("expires_at"), }) }) .collect::>>()?; @@ -359,7 +362,7 @@ impl FilesRepo for SqliteFilesRepo { } async fn delete_file(&self, id: Uuid) -> DbResult<()> { - let result = sqlx::query( + let result = query( r#" DELETE FROM files WHERE id = ? @@ -382,7 +385,7 @@ impl FilesRepo for SqliteFilesRepo { status: FileStatus, status_details: Option, ) -> DbResult<()> { - let result = sqlx::query( + let result = query( r#" UPDATE files SET status = ?, status_details = ? @@ -403,7 +406,7 @@ impl FilesRepo for SqliteFilesRepo { } async fn count_file_references(&self, file_id: Uuid) -> DbResult { - let result = sqlx::query( + let result = query( r#" SELECT COUNT(*) as count FROM vector_store_files @@ -414,6 +417,6 @@ impl FilesRepo for SqliteFilesRepo { .fetch_one(&self.pool) .await?; - Ok(result.get("count")) + Ok(result.col("count")) } } diff --git a/src/db/sqlite/mod.rs b/src/db/sqlite/mod.rs index bd4a736..4a15277 100644 --- a/src/db/sqlite/mod.rs +++ b/src/db/sqlite/mod.rs @@ -1,5 +1,6 @@ mod api_keys; mod audit_logs; +pub(crate) mod backend; mod common; mod conversations; #[cfg(feature = "sso")] diff --git a/src/db/sqlite/model_pricing.rs b/src/db/sqlite/model_pricing.rs index 67de854..27122db 100644 --- a/src/db/sqlite/model_pricing.rs +++ b/src/db/sqlite/model_pricing.rs @@ -1,8 +1,10 @@ use async_trait::async_trait; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; +use super::{ + backend::{Pool, RowExt, begin, map_unique_violation, query}, + common::parse_uuid, +}; use crate::{ db::{ error::{DbError, DbResult}, @@ -15,11 +17,11 @@ use crate::{ }; pub struct SqliteModelPricingRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteModelPricingRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } @@ -52,28 
+54,28 @@ impl SqliteModelPricingRepo { } } - fn row_to_pricing(row: &sqlx::sqlite::SqliteRow) -> DbResult { - let owner_type: Option = row.get("owner_type"); - let owner_id: Option = row.get("owner_id"); - let source_str: String = row.get("source"); + fn row_to_pricing(row: &super::backend::Row) -> DbResult { + let owner_type: Option = row.col("owner_type"); + let owner_id: Option = row.col("owner_id"); + let source_str: String = row.col("source"); Ok(DbModelPricing { - id: parse_uuid(&row.get::("id"))?, + id: parse_uuid(&row.col::("id"))?, owner: Self::parse_owner(owner_type.as_deref(), owner_id.as_deref())?, - provider: row.get("provider"), - model: row.get("model"), - input_per_1m_tokens: row.get("input_per_1m_tokens"), - output_per_1m_tokens: row.get("output_per_1m_tokens"), - per_image: row.get("per_image"), - per_request: row.get("per_request"), - cached_input_per_1m_tokens: row.get("cached_input_per_1m_tokens"), - cache_write_per_1m_tokens: row.get("cache_write_per_1m_tokens"), - reasoning_per_1m_tokens: row.get("reasoning_per_1m_tokens"), - per_second: row.get("per_second"), - per_1m_characters: row.get("per_1m_characters"), - source: PricingSource::from_str(&source_str), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + provider: row.col("provider"), + model: row.col("model"), + input_per_1m_tokens: row.col("input_per_1m_tokens"), + output_per_1m_tokens: row.col("output_per_1m_tokens"), + per_image: row.col("per_image"), + per_request: row.col("per_request"), + cached_input_per_1m_tokens: row.col("cached_input_per_1m_tokens"), + cache_write_per_1m_tokens: row.col("cache_write_per_1m_tokens"), + reasoning_per_1m_tokens: row.col("reasoning_per_1m_tokens"), + per_second: row.col("per_second"), + per_1m_characters: row.col("per_1m_characters"), + source: PricingSource::parse(&source_str), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) } @@ -99,7 +101,7 @@ impl SqliteModelPricingRepo { ) }; - let query = 
format!( + let sql = format!( r#" SELECT id, owner_type, owner_id, provider, model, input_per_1m_tokens, output_per_1m_tokens, per_image, per_request, @@ -113,7 +115,7 @@ impl SqliteModelPricingRepo { cursor_condition, order, order ); - let mut query_builder = sqlx::query(&query); + let mut query_builder = query(&sql); for bind in &binds { query_builder = query_builder.bind(bind); } @@ -156,7 +158,7 @@ impl SqliteModelPricingRepo { ) -> DbResult> { let fetch_limit = limit + 1; - let query = format!( + let sql = format!( r#" SELECT id, owner_type, owner_id, provider, model, input_per_1m_tokens, output_per_1m_tokens, per_image, per_request, @@ -170,7 +172,7 @@ impl SqliteModelPricingRepo { where_clause ); - let mut query_builder = sqlx::query(&query); + let mut query_builder = query(&sql); for bind in &binds { query_builder = query_builder.bind(bind); } @@ -198,14 +200,15 @@ impl SqliteModelPricingRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl ModelPricingRepo for SqliteModelPricingRepo { async fn create(&self, input: CreateModelPricing) -> DbResult { let id = Uuid::new_v4(); let now = chrono::Utc::now(); let (owner_type, owner_id) = Self::owner_to_parts(&input.owner); - sqlx::query( + query( r#" INSERT INTO model_pricing ( id, owner_type, owner_id, provider, model, @@ -236,15 +239,10 @@ impl ModelPricingRepo for SqliteModelPricingRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict(format!( - "Pricing for provider '{}' model '{}' already exists", - input.provider, input.model - )) - } - _ => DbError::from(e), - })?; + .map_err(map_unique_violation(format!( + "Pricing for provider '{}' model '{}' already exists", + input.provider, input.model + )))?; Ok(DbModelPricing { id, @@ -267,7 +265,7 @@ impl ModelPricingRepo for SqliteModelPricingRepo { } async fn 
get_by_id(&self, id: Uuid) -> DbResult> { - let row = sqlx::query( + let row = query( r#" SELECT id, owner_type, owner_id, provider, model, input_per_1m_tokens, output_per_1m_tokens, per_image, per_request, @@ -293,7 +291,7 @@ impl ModelPricingRepo for SqliteModelPricingRepo { let (owner_type, owner_id) = Self::owner_to_parts(owner); let row = if owner_type.is_none() { - sqlx::query( + query( r#" SELECT id, owner_type, owner_id, provider, model, input_per_1m_tokens, output_per_1m_tokens, per_image, per_request, @@ -308,7 +306,7 @@ impl ModelPricingRepo for SqliteModelPricingRepo { .fetch_optional(&self.pool) .await? } else { - sqlx::query( + query( r#" SELECT id, owner_type, owner_id, provider, model, input_per_1m_tokens, output_per_1m_tokens, per_image, per_request, @@ -339,7 +337,7 @@ impl ModelPricingRepo for SqliteModelPricingRepo { ) -> DbResult> { // Single query with priority ordering: user > project > org > global // Uses CASE expression to assign priority and LIMIT 1 to get highest priority match - let row = sqlx::query( + let row = query( r#" SELECT id, owner_type, owner_id, provider, model, input_per_1m_tokens, output_per_1m_tokens, per_image, per_request, @@ -396,7 +394,7 @@ impl ModelPricingRepo for SqliteModelPricingRepo { } async fn count_by_org(&self, org_id: Uuid) -> DbResult { - let row = sqlx::query( + let row = query( r#" SELECT COUNT(*) as count FROM model_pricing @@ -407,7 +405,7 @@ impl ModelPricingRepo for SqliteModelPricingRepo { .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn list_by_project( @@ -430,7 +428,7 @@ impl ModelPricingRepo for SqliteModelPricingRepo { } async fn count_by_project(&self, project_id: Uuid) -> DbResult { - let row = sqlx::query( + let row = query( r#" SELECT COUNT(*) as count FROM model_pricing @@ -441,7 +439,7 @@ impl ModelPricingRepo for SqliteModelPricingRepo { .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn list_by_user( @@ 
-464,7 +462,7 @@ impl ModelPricingRepo for SqliteModelPricingRepo { } async fn count_by_user(&self, user_id: Uuid) -> DbResult { - let row = sqlx::query( + let row = query( r#" SELECT COUNT(*) as count FROM model_pricing @@ -475,7 +473,7 @@ impl ModelPricingRepo for SqliteModelPricingRepo { .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn list_global(&self, params: ListParams) -> DbResult> { @@ -494,7 +492,7 @@ impl ModelPricingRepo for SqliteModelPricingRepo { } async fn count_global(&self) -> DbResult { - let row = sqlx::query( + let row = query( r#" SELECT COUNT(*) as count FROM model_pricing @@ -504,7 +502,7 @@ impl ModelPricingRepo for SqliteModelPricingRepo { .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn list_by_provider( @@ -527,7 +525,7 @@ impl ModelPricingRepo for SqliteModelPricingRepo { } async fn count_by_provider(&self, provider: &str) -> DbResult { - let row = sqlx::query( + let row = query( r#" SELECT COUNT(*) as count FROM model_pricing @@ -538,13 +536,13 @@ impl ModelPricingRepo for SqliteModelPricingRepo { .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn update(&self, id: Uuid, input: UpdateModelPricing) -> DbResult { let now = chrono::Utc::now(); - sqlx::query( + query( r#" UPDATE model_pricing SET input_per_1m_tokens = COALESCE(?, input_per_1m_tokens), @@ -576,7 +574,7 @@ impl ModelPricingRepo for SqliteModelPricingRepo { } async fn delete(&self, id: Uuid) -> DbResult<()> { - sqlx::query("DELETE FROM model_pricing WHERE id = ?") + query("DELETE FROM model_pricing WHERE id = ?") .bind(id.to_string()) .execute(&self.pool) .await?; @@ -593,7 +591,7 @@ impl ModelPricingRepo for SqliteModelPricingRepo { // Different conflict targets for global vs scoped pricing due to partial indexes if owner_type.is_none() { // Global pricing: conflict on (provider, model) where owner_type IS NULL - sqlx::query( + query( r#" INSERT 
INTO model_pricing ( id, owner_type, owner_id, provider, model, @@ -636,7 +634,7 @@ impl ModelPricingRepo for SqliteModelPricingRepo { .await?; } else { // Scoped pricing: conflict on (owner_type, owner_id, provider, model) - sqlx::query( + query( r#" INSERT INTO model_pricing ( id, owner_type, owner_id, provider, model, @@ -695,16 +693,14 @@ impl ModelPricingRepo for SqliteModelPricingRepo { let now = chrono::Utc::now(); let count = entries.len(); - // Process all entries in a single transaction for atomicity - // If any entry fails, the entire batch is rolled back - let mut tx = self.pool.begin().await?; + let mut tx = begin(&self.pool).await?; for entry in entries { let id = Uuid::new_v4(); let (owner_type, owner_id) = Self::owner_to_parts(&entry.owner); if owner_type.is_none() { - sqlx::query( + query( r#" INSERT INTO model_pricing ( id, owner_type, owner_id, provider, model, @@ -746,7 +742,7 @@ impl ModelPricingRepo for SqliteModelPricingRepo { .execute(&mut *tx) .await?; } else { - sqlx::query( + query( r#" INSERT INTO model_pricing ( id, owner_type, owner_id, provider, model, @@ -793,12 +789,15 @@ impl ModelPricingRepo for SqliteModelPricingRepo { } tx.commit().await?; + Ok(count) } } #[cfg(test)] mod tests { + use sqlx::SqlitePool; + use super::*; use crate::db::repos::{ListParams, ModelPricingRepo}; diff --git a/src/db/sqlite/org_rbac_policies.rs b/src/db/sqlite/org_rbac_policies.rs index c8684aa..4cc942f 100644 --- a/src/db/sqlite/org_rbac_policies.rs +++ b/src/db/sqlite/org_rbac_policies.rs @@ -1,9 +1,11 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; +use super::{ + backend::{Pool, RowExt, Transaction, begin, map_unique_violation, query}, + common::parse_uuid, +}; use crate::{ db::{ error::{DbError, DbResult}, @@ -19,64 +21,64 @@ use crate::{ }; pub struct SqliteOrgRbacPolicyRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteOrgRbacPolicyRepo { - pub fn 
new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } - fn parse_policy(row: &sqlx::sqlite::SqliteRow) -> DbResult { - let effect_str: String = row.get("effect"); + fn parse_policy(row: &super::backend::Row) -> DbResult { + let effect_str: String = row.col("effect"); let effect: RbacPolicyEffect = effect_str .parse() .map_err(|e: String| DbError::Internal(e))?; - let enabled: i32 = row.get("enabled"); + let enabled: i32 = row.col("enabled"); Ok(OrgRbacPolicy { - id: parse_uuid(row.get("id"))?, - org_id: parse_uuid(row.get("org_id"))?, - name: row.get("name"), - description: row.get("description"), - resource: row.get("resource"), - action: row.get("action"), - condition: row.get("condition"), + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, + name: row.col("name"), + description: row.col("description"), + resource: row.col("resource"), + action: row.col("action"), + condition: row.col("condition"), effect, - priority: row.get("priority"), + priority: row.col("priority"), enabled: enabled != 0, - version: row.get("version"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), - deleted_at: row.get("deleted_at"), + version: row.col("version"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), + deleted_at: row.col("deleted_at"), }) } - fn parse_version(row: &sqlx::sqlite::SqliteRow) -> DbResult { - let effect_str: String = row.get("effect"); + fn parse_version(row: &super::backend::Row) -> DbResult { + let effect_str: String = row.col("effect"); let effect: RbacPolicyEffect = effect_str .parse() .map_err(|e: String| DbError::Internal(e))?; - let enabled: i32 = row.get("enabled"); - let created_by: Option = row.get("created_by"); + let enabled: i32 = row.col("enabled"); + let created_by: Option = row.col("created_by"); Ok(OrgRbacPolicyVersion { - id: parse_uuid(row.get("id"))?, - policy_id: parse_uuid(row.get("policy_id"))?, - version: row.get("version"), - 
name: row.get("name"), - description: row.get("description"), - resource: row.get("resource"), - action: row.get("action"), - condition: row.get("condition"), + id: parse_uuid(&row.col::("id"))?, + policy_id: parse_uuid(&row.col::("policy_id"))?, + version: row.col("version"), + name: row.col("name"), + description: row.col("description"), + resource: row.col("resource"), + action: row.col("action"), + condition: row.col("condition"), effect, - priority: row.get("priority"), + priority: row.col("priority"), enabled: enabled != 0, created_by: created_by.and_then(|s| Uuid::parse_str(&s).ok()), - reason: row.get("reason"), - created_at: row.get("created_at"), + reason: row.col("reason"), + created_at: row.col("created_at"), }) } @@ -98,7 +100,7 @@ impl SqliteOrgRbacPolicyRepo { "AND deleted_at IS NULL" }; - let query = format!( + let sql = format!( r#" SELECT id, org_id, name, description, resource, action, condition, effect, priority, enabled, version, created_at, updated_at, deleted_at @@ -111,7 +113,7 @@ impl SqliteOrgRbacPolicyRepo { comparison, deleted_filter, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(org_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -150,7 +152,7 @@ impl SqliteOrgRbacPolicyRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT id, policy_id, version, name, description, resource, action, condition, effect, priority, enabled, created_by, reason, created_at @@ -162,7 +164,7 @@ impl SqliteOrgRbacPolicyRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(policy_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -192,17 +194,17 @@ impl SqliteOrgRbacPolicyRepo { Ok(ListResult::new(items, has_more, cursors)) } - /// Create a version record from a policy snapshot + /// Create a version record from a policy snapshot 
within a transaction. async fn create_version_record( &self, - tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, + tx: &mut Transaction<'_>, policy: &OrgRbacPolicy, created_by: Option, reason: Option, ) -> DbResult<()> { let version_id = Uuid::new_v4(); - sqlx::query( + query( r#" INSERT INTO org_rbac_policy_versions ( id, policy_id, version, name, description, resource, action, @@ -232,7 +234,8 @@ impl SqliteOrgRbacPolicyRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { async fn create( &self, @@ -243,42 +246,6 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { let id = Uuid::new_v4(); let now: DateTime = Utc::now(); - let mut tx = self.pool.begin().await?; - - // Insert the policy - sqlx::query( - r#" - INSERT INTO org_rbac_policies ( - id, org_id, name, description, resource, action, condition, - effect, priority, enabled, version, created_at, updated_at - ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 1, ?, ?) 
- "#, - ) - .bind(id.to_string()) - .bind(org_id.to_string()) - .bind(&input.name) - .bind(&input.description) - .bind(&input.resource) - .bind(&input.action) - .bind(&input.condition) - .bind(input.effect.to_string()) - .bind(input.priority) - .bind(if input.enabled { 1 } else { 0 }) - .bind(now) - .bind(now) - .execute(&mut *tx) - .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict(format!( - "Policy with name '{}' already exists in this organization", - input.name - )) - } - _ => DbError::from(e), - })?; - let policy = OrgRbacPolicy { id, org_id, @@ -296,7 +263,36 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { deleted_at: None, }; - // Create version 1 record + let mut tx = begin(&self.pool).await?; + + query( + r#" + INSERT INTO org_rbac_policies ( + id, org_id, name, description, resource, action, condition, + effect, priority, enabled, version, created_at, updated_at + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 1, ?, ?) 
+ "#, + ) + .bind(id.to_string()) + .bind(org_id.to_string()) + .bind(&policy.name) + .bind(&policy.description) + .bind(&policy.resource) + .bind(&policy.action) + .bind(&policy.condition) + .bind(policy.effect.to_string()) + .bind(policy.priority) + .bind(if policy.enabled { 1 } else { 0 }) + .bind(now) + .bind(now) + .execute(&mut *tx) + .await + .map_err(map_unique_violation(format!( + "Policy with name '{}' already exists in this organization", + policy.name + )))?; + self.create_version_record(&mut tx, &policy, created_by, input.reason) .await?; @@ -306,7 +302,7 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { } async fn get_by_id(&self, id: Uuid) -> DbResult> { - let row = sqlx::query( + let row = query( r#" SELECT id, org_id, name, description, resource, action, condition, effect, priority, enabled, version, created_at, updated_at, deleted_at @@ -329,7 +325,7 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { org_id: Uuid, name: &str, ) -> DbResult> { - let row = sqlx::query( + let row = query( r#" SELECT id, org_id, name, description, resource, action, condition, effect, priority, enabled, version, created_at, updated_at, deleted_at @@ -349,7 +345,7 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { } async fn list_by_org(&self, org_id: Uuid) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT id, org_id, name, description, resource, action, condition, effect, priority, enabled, version, created_at, updated_at, deleted_at @@ -380,7 +376,7 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { } // First page (no cursor) - let query = if params.include_deleted { + let sql = if params.include_deleted { r#" SELECT id, org_id, name, description, resource, action, condition, effect, priority, enabled, version, created_at, updated_at, deleted_at @@ -400,7 +396,7 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { "# }; - let rows = sqlx::query(query) + let rows = query(sql) .bind(org_id.to_string()) .bind(fetch_limit) 
.fetch_all(&self.pool) @@ -422,7 +418,7 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { } async fn list_enabled_by_org(&self, org_id: Uuid) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT id, org_id, name, description, resource, action, condition, effect, priority, enabled, version, created_at, updated_at, deleted_at @@ -439,7 +435,7 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { } async fn list_all_enabled(&self) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT id, org_id, name, description, resource, action, condition, effect, priority, enabled, version, created_at, updated_at, deleted_at @@ -460,10 +456,8 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { input: UpdateOrgRbacPolicy, updated_by: Option, ) -> DbResult { - let mut tx = self.pool.begin().await?; - // Fetch current policy (excluding soft-deleted) - let row = sqlx::query( + let row = query( r#" SELECT id, org_id, name, description, resource, action, condition, effect, priority, enabled, version, created_at, updated_at, deleted_at @@ -472,7 +466,7 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { "#, ) .bind(id.to_string()) - .fetch_optional(&mut *tx) + .fetch_optional(&self.pool) .await?; let Some(row) = row else { @@ -514,8 +508,9 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { policy.version += 1; policy.updated_at = Utc::now(); - // Update the policy with optimistic locking (check original version and not deleted) - let result = sqlx::query( + let mut tx = begin(&self.pool).await?; + + let result = query( r#" UPDATE org_rbac_policies SET name = ?, description = ?, resource = ?, action = ?, condition = ?, @@ -537,24 +532,17 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { .bind(original_version) .execute(&mut *tx) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict(format!( - "Policy with name '{}' already exists in this organization", - 
policy.name - )) - } - _ => DbError::from(e), - })?; - - // Check for concurrent modification (optimistic locking) + .map_err(map_unique_violation(format!( + "Policy with name '{}' already exists in this organization", + policy.name + )))?; + if result.rows_affected() == 0 { return Err(DbError::Conflict( "Policy was modified concurrently. Please refresh and try again.".to_string(), )); } - // Create version record self.create_version_record(&mut tx, &policy, updated_by, input.reason) .await?; @@ -567,7 +555,7 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { let now = Utc::now(); // Soft-delete by setting deleted_at timestamp - let result = sqlx::query( + let result = query( r#" UPDATE org_rbac_policies SET deleted_at = ? @@ -592,10 +580,8 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { input: RollbackOrgRbacPolicy, rolled_back_by: Option, ) -> DbResult { - let mut tx = self.pool.begin().await?; - // Fetch the target version - let version_row = sqlx::query( + let version_row = query( r#" SELECT id, policy_id, version, name, description, resource, action, condition, effect, priority, enabled, created_by, reason, created_at @@ -605,7 +591,7 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { ) .bind(id.to_string()) .bind(input.target_version) - .fetch_optional(&mut *tx) + .fetch_optional(&self.pool) .await?; let Some(version_row) = version_row else { @@ -615,7 +601,7 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { let target_version = Self::parse_version(&version_row)?; // Fetch current policy to get current version number and timestamps (excluding soft-deleted) - let policy_row = sqlx::query( + let policy_row = query( r#" SELECT id, org_id, name, description, resource, action, condition, effect, priority, enabled, version, created_at, updated_at, deleted_at @@ -624,7 +610,7 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { "#, ) .bind(id.to_string()) - .fetch_optional(&mut *tx) + .fetch_optional(&self.pool) .await?; let 
Some(policy_row) = policy_row else { @@ -654,8 +640,13 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { deleted_at: None, }; - // Update the policy with rolled-back values (with optimistic locking and not deleted) - let result = sqlx::query( + let reason = input + .reason + .unwrap_or_else(|| format!("Rolled back to version {}", input.target_version)); + + let mut tx = begin(&self.pool).await?; + + let result = query( r#" UPDATE org_rbac_policies SET name = ?, description = ?, resource = ?, action = ?, condition = ?, @@ -674,21 +665,16 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { .bind(policy.version) .bind(policy.updated_at) .bind(id.to_string()) - .bind(current_policy.version) // Original version for optimistic locking + .bind(current_policy.version) .execute(&mut *tx) .await?; - // Check for concurrent modification (optimistic locking) if result.rows_affected() == 0 { return Err(DbError::Conflict( "Policy was modified concurrently. Please refresh and try again.".to_string(), )); } - // Create version record for the rollback - let reason = input - .reason - .unwrap_or_else(|| format!("Rolled back to version {}", input.target_version)); self.create_version_record(&mut tx, &policy, rolled_back_by, Some(reason)) .await?; @@ -702,7 +688,7 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { policy_id: Uuid, version: i32, ) -> DbResult> { - let row = sqlx::query( + let row = query( r#" SELECT id, policy_id, version, name, description, resource, action, condition, effect, priority, enabled, created_by, reason, created_at @@ -722,7 +708,7 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { } async fn list_versions(&self, policy_id: Uuid) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT id, policy_id, version, name, description, resource, action, condition, effect, priority, enabled, created_by, reason, created_at @@ -744,7 +730,7 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { limit: u32, offset: u32, ) -> 
DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT id, policy_id, version, name, description, resource, action, condition, effect, priority, enabled, created_by, reason, created_at @@ -778,7 +764,7 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { } // First page (no cursor) - let rows = sqlx::query( + let rows = query( r#" SELECT id, policy_id, version, name, description, resource, action, condition, effect, priority, enabled, created_by, reason, created_at @@ -812,39 +798,39 @@ impl OrgRbacPolicyRepo for SqliteOrgRbacPolicyRepo { } async fn count_versions(&self, policy_id: Uuid) -> DbResult { - let row = sqlx::query( - "SELECT COUNT(*) as count FROM org_rbac_policy_versions WHERE policy_id = ?", - ) - .bind(policy_id.to_string()) - .fetch_one(&self.pool) - .await?; + let row = + query("SELECT COUNT(*) as count FROM org_rbac_policy_versions WHERE policy_id = ?") + .bind(policy_id.to_string()) + .fetch_one(&self.pool) + .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn count_by_org(&self, org_id: Uuid) -> DbResult { - let row = sqlx::query( + let row = query( "SELECT COUNT(*) as count FROM org_rbac_policies WHERE org_id = ? 
AND deleted_at IS NULL", ) .bind(org_id.to_string()) .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn count_all(&self) -> DbResult { - let row = - sqlx::query("SELECT COUNT(*) as count FROM org_rbac_policies WHERE deleted_at IS NULL") - .fetch_one(&self.pool) - .await?; + let row = query("SELECT COUNT(*) as count FROM org_rbac_policies WHERE deleted_at IS NULL") + .fetch_one(&self.pool) + .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } } #[cfg(test)] mod tests { + use sqlx::SqlitePool; + use super::*; async fn create_test_pool() -> SqlitePool { diff --git a/src/db/sqlite/org_sso_configs.rs b/src/db/sqlite/org_sso_configs.rs index a746988..57d7c2e 100644 --- a/src/db/sqlite/org_sso_configs.rs +++ b/src/db/sqlite/org_sso_configs.rs @@ -1,8 +1,10 @@ use async_trait::async_trait; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; +use super::{ + backend::{Pool, Row, RowExt, map_unique_violation, query}, + common::parse_uuid, +}; use crate::{ db::{ error::{DbError, DbResult}, @@ -15,87 +17,87 @@ use crate::{ }; pub struct SqliteOrgSsoConfigRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteOrgSsoConfigRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } /// Parse an OrgSsoConfig from a database row. 
- fn parse_config(row: &sqlx::sqlite::SqliteRow) -> DbResult { - let default_team_id: Option = row.get("default_team_id"); + fn parse_config(row: &Row) -> DbResult { + let default_team_id: Option = row.col("default_team_id"); let default_team_id = default_team_id.map(|s| parse_uuid(&s)).transpose()?; - let scopes_str: String = row.get("scopes"); + let scopes_str: String = row.col("scopes"); let scopes: Vec = scopes_str.split_whitespace().map(String::from).collect(); - let allowed_domains_json: Option = row.get("allowed_email_domains"); + let allowed_domains_json: Option = row.col("allowed_email_domains"); let allowed_email_domains: Vec = allowed_domains_json .map(|json| serde_json::from_str(&json).unwrap_or_default()) .unwrap_or_default(); - let provider_type_str: String = row.get("provider_type"); + let provider_type_str: String = row.col("provider_type"); let provider_type = provider_type_str .parse::() .unwrap_or_default(); - let enforcement_mode_str: String = row.get("enforcement_mode"); + let enforcement_mode_str: String = row.col("enforcement_mode"); let enforcement_mode = enforcement_mode_str .parse::() .unwrap_or_default(); Ok(OrgSsoConfig { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, provider_type, // OIDC fields - issuer: row.get("issuer"), - discovery_url: row.get("discovery_url"), - client_id: row.get("client_id"), - redirect_uri: row.get("redirect_uri"), + issuer: row.col("issuer"), + discovery_url: row.col("discovery_url"), + client_id: row.col("client_id"), + redirect_uri: row.col("redirect_uri"), scopes, - identity_claim: row.get("identity_claim"), - org_claim: row.get("org_claim"), - groups_claim: row.get("groups_claim"), + identity_claim: row.col("identity_claim"), + org_claim: row.col("org_claim"), + groups_claim: row.col("groups_claim"), // SAML fields - saml_metadata_url: row.get("saml_metadata_url"), - saml_idp_entity_id: 
row.get("saml_idp_entity_id"), - saml_idp_sso_url: row.get("saml_idp_sso_url"), - saml_idp_slo_url: row.get("saml_idp_slo_url"), - saml_idp_certificate: row.get("saml_idp_certificate"), - saml_sp_entity_id: row.get("saml_sp_entity_id"), - saml_name_id_format: row.get("saml_name_id_format"), - saml_sign_requests: row.get::("saml_sign_requests") != 0, - saml_sp_certificate: row.get("saml_sp_certificate"), - saml_force_authn: row.get::("saml_force_authn") != 0, - saml_authn_context_class_ref: row.get("saml_authn_context_class_ref"), - saml_identity_attribute: row.get("saml_identity_attribute"), - saml_email_attribute: row.get("saml_email_attribute"), - saml_name_attribute: row.get("saml_name_attribute"), - saml_groups_attribute: row.get("saml_groups_attribute"), + saml_metadata_url: row.col("saml_metadata_url"), + saml_idp_entity_id: row.col("saml_idp_entity_id"), + saml_idp_sso_url: row.col("saml_idp_sso_url"), + saml_idp_slo_url: row.col("saml_idp_slo_url"), + saml_idp_certificate: row.col("saml_idp_certificate"), + saml_sp_entity_id: row.col("saml_sp_entity_id"), + saml_name_id_format: row.col("saml_name_id_format"), + saml_sign_requests: row.col::("saml_sign_requests") != 0, + saml_sp_certificate: row.col("saml_sp_certificate"), + saml_force_authn: row.col::("saml_force_authn") != 0, + saml_authn_context_class_ref: row.col("saml_authn_context_class_ref"), + saml_identity_attribute: row.col("saml_identity_attribute"), + saml_email_attribute: row.col("saml_email_attribute"), + saml_name_attribute: row.col("saml_name_attribute"), + saml_groups_attribute: row.col("saml_groups_attribute"), // JIT provisioning - provisioning_enabled: row.get::("provisioning_enabled") != 0, - create_users: row.get::("create_users") != 0, + provisioning_enabled: row.col::("provisioning_enabled") != 0, + create_users: row.col::("create_users") != 0, default_team_id, - default_org_role: row.get("default_org_role"), - default_team_role: row.get("default_team_role"), + default_org_role: 
row.col("default_org_role"), + default_team_role: row.col("default_team_role"), allowed_email_domains, - sync_attributes_on_login: row.get::("sync_attributes_on_login") != 0, - sync_memberships_on_login: row.get::("sync_memberships_on_login") != 0, + sync_attributes_on_login: row.col::("sync_attributes_on_login") != 0, + sync_memberships_on_login: row.col::("sync_memberships_on_login") != 0, enforcement_mode, - enabled: row.get::("enabled") != 0, - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + enabled: row.col::("enabled") != 0, + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) } /// Parse an OrgSsoConfigWithSecret from a database row. - fn parse_config_with_secret(row: &sqlx::sqlite::SqliteRow) -> DbResult { + fn parse_config_with_secret(row: &Row) -> DbResult { let config = Self::parse_config(row)?; - let client_secret_key: Option = row.get("client_secret_key"); - let saml_sp_private_key_ref: Option = row.get("saml_sp_private_key_ref"); + let client_secret_key: Option = row.col("client_secret_key"); + let saml_sp_private_key_ref: Option = row.col("saml_sp_private_key_ref"); Ok(OrgSsoConfigWithSecret { config, client_secret_key, @@ -104,7 +106,8 @@ impl SqliteOrgSsoConfigRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl OrgSsoConfigRepo for SqliteOrgSsoConfigRepo { async fn create( &self, @@ -123,7 +126,7 @@ impl OrgSsoConfigRepo for SqliteOrgSsoConfigRepo { Some(serde_json::to_string(&input.allowed_email_domains).unwrap_or_default()) }; - sqlx::query( + query( r#" INSERT INTO org_sso_configs ( id, org_id, provider_type, @@ -189,12 +192,9 @@ impl OrgSsoConfigRepo for SqliteOrgSsoConfigRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict("Organization already has an SSO configuration".into()) - } - _ => 
DbError::from(e), - })?; + .map_err(map_unique_violation( + "Organization already has an SSO configuration", + ))?; Ok(OrgSsoConfig { id, @@ -242,7 +242,7 @@ impl OrgSsoConfigRepo for SqliteOrgSsoConfigRepo { } async fn get_by_id(&self, id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, org_id, provider_type, issuer, discovery_url, client_id, client_secret_key, @@ -270,7 +270,7 @@ impl OrgSsoConfigRepo for SqliteOrgSsoConfigRepo { } async fn get_by_org_id(&self, org_id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, org_id, provider_type, issuer, discovery_url, client_id, client_secret_key, @@ -298,7 +298,7 @@ impl OrgSsoConfigRepo for SqliteOrgSsoConfigRepo { } async fn get_with_secret(&self, id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, org_id, provider_type, issuer, discovery_url, client_id, client_secret_key, @@ -329,7 +329,7 @@ impl OrgSsoConfigRepo for SqliteOrgSsoConfigRepo { &self, org_id: Uuid, ) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, org_id, provider_type, issuer, discovery_url, client_id, client_secret_key, @@ -393,7 +393,7 @@ impl OrgSsoConfigRepo for SqliteOrgSsoConfigRepo { } }); - sqlx::query( + query( r#" UPDATE org_sso_configs SET provider_type = ?, issuer = ?, discovery_url = ?, client_id = ?, client_secret_key = ?, @@ -486,7 +486,7 @@ impl OrgSsoConfigRepo for SqliteOrgSsoConfigRepo { } async fn delete(&self, id: Uuid) -> DbResult<()> { - let result = sqlx::query("DELETE FROM org_sso_configs WHERE id = ?") + let result = query("DELETE FROM org_sso_configs WHERE id = ?") .bind(id.to_string()) .execute(&self.pool) .await?; @@ -499,7 +499,7 @@ impl OrgSsoConfigRepo for SqliteOrgSsoConfigRepo { } async fn find_enabled_oidc_by_issuer(&self, issuer: &str) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT id, org_id, provider_type, issuer, discovery_url, client_id, 
client_secret_key, @@ -526,7 +526,7 @@ impl OrgSsoConfigRepo for SqliteOrgSsoConfigRepo { async fn find_by_email_domain(&self, domain: &str) -> DbResult> { // Search for configs where the domain is in the allowed_email_domains JSON array // SQLite uses json_each to search JSON arrays - let result = sqlx::query( + let result = query( r#" SELECT c.id, c.org_id, c.provider_type, c.issuer, c.discovery_url, c.client_id, c.client_secret_key, @@ -555,7 +555,7 @@ impl OrgSsoConfigRepo for SqliteOrgSsoConfigRepo { } async fn list_enabled(&self) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT id, org_id, provider_type, issuer, discovery_url, client_id, client_secret_key, @@ -582,16 +582,18 @@ impl OrgSsoConfigRepo for SqliteOrgSsoConfigRepo { } async fn any_enabled(&self) -> DbResult { - let result: (i32,) = - sqlx::query_as("SELECT EXISTS(SELECT 1 FROM org_sso_configs WHERE enabled = 1)") + let row = + query("SELECT EXISTS(SELECT 1 FROM org_sso_configs WHERE enabled = 1) as has_any") .fetch_one(&self.pool) .await?; - Ok(result.0 != 0) + Ok(row.col::("has_any") != 0) } } #[cfg(test)] mod tests { + use sqlx::SqlitePool; + use super::*; async fn create_test_pool() -> SqlitePool { diff --git a/src/db/sqlite/organizations.rs b/src/db/sqlite/organizations.rs index 5b3713c..c307288 100644 --- a/src/db/sqlite/organizations.rs +++ b/src/db/sqlite/organizations.rs @@ -1,8 +1,10 @@ use async_trait::async_trait; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; +use super::{ + backend::{Pool, RowExt, map_unique_violation, query}, + common::parse_uuid, +}; use crate::{ db::{ error::{DbError, DbResult}, @@ -15,11 +17,11 @@ use crate::{ }; pub struct SqliteOrganizationRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteOrganizationRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } @@ -42,7 +44,7 @@ impl SqliteOrganizationRepo { "AND deleted_at IS NULL" }; - let query = format!( + let sql 
= format!( r#" SELECT id, slug, name, created_at, updated_at FROM organizations @@ -54,7 +56,7 @@ impl SqliteOrganizationRepo { comparison, deleted_filter, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(cursor.created_at) .bind(cursor.id.to_string()) .bind(fetch_limit) @@ -67,11 +69,11 @@ impl SqliteOrganizationRepo { .take(limit as usize) .map(|row| { Ok(Organization { - id: parse_uuid(&row.get::("id"))?, - slug: row.get("slug"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + slug: row.col("slug"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) }) .collect::>>()?; @@ -90,13 +92,14 @@ impl SqliteOrganizationRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl OrganizationRepo for SqliteOrganizationRepo { async fn create(&self, input: CreateOrganization) -> DbResult { let id = Uuid::new_v4(); let now = chrono::Utc::now(); - sqlx::query( + query( r#" INSERT INTO organizations (id, slug, name, created_at, updated_at) VALUES (?, ?, ?, ?, ?) 
@@ -109,12 +112,10 @@ impl OrganizationRepo for SqliteOrganizationRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => DbError::Conflict( - format!("Organization with slug '{}' already exists", input.slug), - ), - _ => DbError::from(e), - })?; + .map_err(map_unique_violation(format!( + "Organization with slug '{}' already exists", + input.slug + )))?; Ok(Organization { id, @@ -126,7 +127,7 @@ impl OrganizationRepo for SqliteOrganizationRepo { } async fn get_by_id(&self, id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, slug, name, created_at, updated_at FROM organizations @@ -139,18 +140,18 @@ impl OrganizationRepo for SqliteOrganizationRepo { match result { Some(row) => Ok(Some(Organization { - id: parse_uuid(&row.get::("id"))?, - slug: row.get("slug"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + slug: row.col("slug"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), })), None => Ok(None), } } async fn get_by_slug(&self, slug: &str) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, slug, name, created_at, updated_at FROM organizations @@ -163,11 +164,11 @@ impl OrganizationRepo for SqliteOrganizationRepo { match result { Some(row) => Ok(Some(Organization { - id: parse_uuid(&row.get::("id"))?, - slug: row.get("slug"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + slug: row.col("slug"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), })), None => Ok(None), } @@ -186,7 +187,7 @@ impl OrganizationRepo for SqliteOrganizationRepo { } // First page (no cursor provided) - let query = if params.include_deleted { + let sql = if 
params.include_deleted { r#" SELECT id, slug, name, created_at, updated_at FROM organizations @@ -203,10 +204,7 @@ impl OrganizationRepo for SqliteOrganizationRepo { "# }; - let rows = sqlx::query(query) - .bind(fetch_limit) - .fetch_all(&self.pool) - .await?; + let rows = query(sql).bind(fetch_limit).fetch_all(&self.pool).await?; let has_more = rows.len() as i64 > limit; let items: Vec = rows @@ -214,11 +212,11 @@ impl OrganizationRepo for SqliteOrganizationRepo { .take(limit as usize) .map(|row| { Ok(Organization { - id: parse_uuid(&row.get::("id"))?, - slug: row.get("slug"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + slug: row.col("slug"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) }) .collect::>>()?; @@ -233,21 +231,21 @@ impl OrganizationRepo for SqliteOrganizationRepo { } async fn count(&self, include_deleted: bool) -> DbResult { - let query = if include_deleted { + let sql = if include_deleted { "SELECT COUNT(*) as count FROM organizations" } else { "SELECT COUNT(*) as count FROM organizations WHERE deleted_at IS NULL" }; - let row = sqlx::query(query).fetch_one(&self.pool).await?; - Ok(row.get::("count")) + let row = query(sql).fetch_one(&self.pool).await?; + Ok(row.col::("count")) } async fn update(&self, id: Uuid, input: UpdateOrganization) -> DbResult { if let Some(name) = input.name { let now = chrono::Utc::now(); - let result = sqlx::query( + let result = query( r#" UPDATE organizations SET name = ?, updated_at = ? @@ -274,7 +272,7 @@ impl OrganizationRepo for SqliteOrganizationRepo { async fn delete(&self, id: Uuid) -> DbResult<()> { let now = chrono::Utc::now(); - let result = sqlx::query( + let result = query( r#" UPDATE organizations SET deleted_at = ? 
@@ -296,6 +294,8 @@ impl OrganizationRepo for SqliteOrganizationRepo { #[cfg(test)] mod tests { + use sqlx::SqlitePool; + use super::*; use crate::db::repos::OrganizationRepo; diff --git a/src/db/sqlite/projects.rs b/src/db/sqlite/projects.rs index 6d671e0..8ca9dfd 100644 --- a/src/db/sqlite/projects.rs +++ b/src/db/sqlite/projects.rs @@ -1,8 +1,10 @@ use async_trait::async_trait; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; +use super::{ + backend::{Pool, RowExt, map_unique_violation, query}, + common::parse_uuid, +}; use crate::{ db::{ error::{DbError, DbResult}, @@ -15,11 +17,11 @@ use crate::{ }; pub struct SqliteProjectRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteProjectRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } @@ -43,7 +45,7 @@ impl SqliteProjectRepo { "AND deleted_at IS NULL" }; - let query = format!( + let sql = format!( r#" SELECT id, org_id, team_id, slug, name, created_at, updated_at FROM projects @@ -55,7 +57,7 @@ impl SqliteProjectRepo { comparison, deleted_filter, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(org_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -68,15 +70,15 @@ impl SqliteProjectRepo { .into_iter() .take(limit as usize) .map(|row| { - let team_id: Option = row.get("team_id"); + let team_id: Option = row.col("team_id"); Ok(Project { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, team_id: team_id.as_deref().map(parse_uuid).transpose()?, - slug: row.get("slug"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + slug: row.col("slug"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) }) .collect::>>()?; @@ -95,13 +97,14 @@ impl SqliteProjectRepo { } } -#[async_trait] 
+#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl ProjectRepo for SqliteProjectRepo { async fn create(&self, org_id: Uuid, input: CreateProject) -> DbResult { let id = Uuid::new_v4(); let now = chrono::Utc::now(); - sqlx::query( + query( r#" INSERT INTO projects (id, org_id, team_id, slug, name, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?) @@ -116,15 +119,10 @@ impl ProjectRepo for SqliteProjectRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict(format!( - "Project with slug '{}' already exists in this organization", - input.slug - )) - } - _ => DbError::from(e), - })?; + .map_err(map_unique_violation(format!( + "Project with slug '{}' already exists in this organization", + input.slug + )))?; Ok(Project { id, @@ -138,7 +136,7 @@ impl ProjectRepo for SqliteProjectRepo { } async fn get_by_id(&self, id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, org_id, team_id, slug, name, created_at, updated_at FROM projects @@ -151,15 +149,15 @@ impl ProjectRepo for SqliteProjectRepo { match result { Some(row) => { - let team_id: Option = row.get("team_id"); + let team_id: Option = row.col("team_id"); Ok(Some(Project { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, team_id: team_id.as_deref().map(parse_uuid).transpose()?, - slug: row.get("slug"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + slug: row.col("slug"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), })) } None => Ok(None), @@ -167,7 +165,7 @@ impl ProjectRepo for SqliteProjectRepo { } async fn get_by_id_and_org(&self, id: Uuid, org_id: Uuid) -> DbResult> { - let result = 
sqlx::query( + let result = query( r#" SELECT id, org_id, team_id, slug, name, created_at, updated_at FROM projects @@ -181,15 +179,15 @@ impl ProjectRepo for SqliteProjectRepo { match result { Some(row) => { - let team_id: Option = row.get("team_id"); + let team_id: Option = row.col("team_id"); Ok(Some(Project { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, team_id: team_id.as_deref().map(parse_uuid).transpose()?, - slug: row.get("slug"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + slug: row.col("slug"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), })) } None => Ok(None), @@ -197,7 +195,7 @@ impl ProjectRepo for SqliteProjectRepo { } async fn get_by_slug(&self, org_id: Uuid, slug: &str) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, org_id, team_id, slug, name, created_at, updated_at FROM projects @@ -211,15 +209,15 @@ impl ProjectRepo for SqliteProjectRepo { match result { Some(row) => { - let team_id: Option = row.get("team_id"); + let team_id: Option = row.col("team_id"); Ok(Some(Project { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, team_id: team_id.as_deref().map(parse_uuid).transpose()?, - slug: row.get("slug"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + slug: row.col("slug"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), })) } None => Ok(None), @@ -239,7 +237,7 @@ impl ProjectRepo for SqliteProjectRepo { } // First page (no cursor provided) - let query = if params.include_deleted { + let sql = if params.include_deleted { r#" SELECT id, org_id, team_id, slug, name, 
created_at, updated_at FROM projects @@ -257,7 +255,7 @@ impl ProjectRepo for SqliteProjectRepo { "# }; - let rows = sqlx::query(query) + let rows = query(sql) .bind(org_id.to_string()) .bind(fetch_limit) .fetch_all(&self.pool) @@ -268,15 +266,15 @@ impl ProjectRepo for SqliteProjectRepo { .into_iter() .take(limit as usize) .map(|row| { - let team_id: Option = row.get("team_id"); + let team_id: Option = row.col("team_id"); Ok(Project { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, team_id: team_id.as_deref().map(parse_uuid).transpose()?, - slug: row.get("slug"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + slug: row.col("slug"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) }) .collect::>>()?; @@ -291,17 +289,17 @@ impl ProjectRepo for SqliteProjectRepo { } async fn count_by_org(&self, org_id: Uuid, include_deleted: bool) -> DbResult { - let query = if include_deleted { + let sql = if include_deleted { "SELECT COUNT(*) as count FROM projects WHERE org_id = ?" } else { "SELECT COUNT(*) as count FROM projects WHERE org_id = ? AND deleted_at IS NULL" }; - let row = sqlx::query(query) + let row = query(sql) .bind(org_id.to_string()) .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn update(&self, id: Uuid, input: UpdateProject) -> DbResult { @@ -323,12 +321,12 @@ impl ProjectRepo for SqliteProjectRepo { set_clauses.push("team_id = ?"); } - let query = format!( + let sql = format!( "UPDATE projects SET {} WHERE id = ? 
AND deleted_at IS NULL", set_clauses.join(", ") ); - let mut query_builder = sqlx::query(&query).bind(now); + let mut query_builder = query(&sql).bind(now); if let Some(ref name) = input.name { query_builder = query_builder.bind(name); @@ -352,7 +350,7 @@ impl ProjectRepo for SqliteProjectRepo { async fn delete(&self, id: Uuid) -> DbResult<()> { let now = chrono::Utc::now(); - let result = sqlx::query( + let result = query( r#" UPDATE projects SET deleted_at = ? @@ -374,6 +372,8 @@ impl ProjectRepo for SqliteProjectRepo { #[cfg(test)] mod tests { + use sqlx::SqlitePool; + use super::*; use crate::db::repos::ProjectRepo; diff --git a/src/db/sqlite/prompts.rs b/src/db/sqlite/prompts.rs index 1aedf19..f9c1654 100644 --- a/src/db/sqlite/prompts.rs +++ b/src/db/sqlite/prompts.rs @@ -1,10 +1,12 @@ use std::collections::HashMap; use async_trait::async_trait; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; +use super::{ + backend::{Pool, RowExt, map_unique_violation, query}, + common::parse_uuid, +}; use crate::{ db::{ error::{DbError, DbResult}, @@ -17,37 +19,37 @@ use crate::{ }; pub struct SqlitePromptRepo { - pool: SqlitePool, + pool: Pool, } impl SqlitePromptRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } /// Parse a Prompt from a database row. 
- fn parse_prompt(row: &sqlx::sqlite::SqliteRow) -> DbResult { - let owner_type_str: String = row.get("owner_type"); + fn parse_prompt(row: &super::backend::Row) -> DbResult { + let owner_type_str: String = row.col("owner_type"); let owner_type: PromptOwnerType = owner_type_str .parse() .map_err(|e: String| DbError::Internal(e))?; - let metadata: Option = row.get("metadata"); + let metadata: Option = row.col("metadata"); let metadata: Option> = metadata .map(|s| serde_json::from_str(&s)) .transpose() .map_err(|e| DbError::Internal(format!("Failed to parse metadata: {}", e)))?; Ok(Prompt { - id: parse_uuid(&row.get::("id"))?, + id: parse_uuid(&row.col::("id"))?, owner_type, - owner_id: parse_uuid(&row.get::("owner_id"))?, - name: row.get("name"), - description: row.get("description"), - content: row.get("content"), + owner_id: parse_uuid(&row.col::("owner_id"))?, + name: row.col("name"), + description: row.col("description"), + content: row.col("content"), metadata, - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) } @@ -70,7 +72,7 @@ impl SqlitePromptRepo { "AND deleted_at IS NULL" }; - let query = format!( + let sql = format!( r#" SELECT id, owner_type, owner_id, name, description, content, metadata, created_at, updated_at FROM prompts @@ -82,7 +84,7 @@ impl SqlitePromptRepo { comparison, deleted_filter, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(owner_type.as_str()) .bind(owner_id.to_string()) .bind(cursor.created_at) @@ -111,7 +113,8 @@ impl SqlitePromptRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl PromptRepo for SqlitePromptRepo { async fn create(&self, input: CreatePrompt) -> DbResult { let id = Uuid::new_v4(); @@ -126,7 +129,7 @@ impl PromptRepo for SqlitePromptRepo { .transpose() .map_err(|e| DbError::Internal(format!("Failed 
to serialize metadata: {}", e)))?; - sqlx::query( + query( r#" INSERT INTO prompts (id, owner_type, owner_id, name, description, content, metadata, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) @@ -143,15 +146,10 @@ impl PromptRepo for SqlitePromptRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict(format!( - "Prompt with name '{}' already exists for this owner", - input.name - )) - } - _ => DbError::from(e), - })?; + .map_err(map_unique_violation(format!( + "Prompt with name '{}' already exists for this owner", + input.name + )))?; Ok(Prompt { id, @@ -167,7 +165,7 @@ impl PromptRepo for SqlitePromptRepo { } async fn get_by_id(&self, id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, owner_type, owner_id, name, description, content, metadata, created_at, updated_at FROM prompts @@ -185,7 +183,7 @@ impl PromptRepo for SqlitePromptRepo { } async fn get_by_id_and_org(&self, id: Uuid, org_id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT p.id, p.owner_type, p.owner_id, p.name, p.description, p.content, p.metadata, p.created_at, p.updated_at FROM prompts p @@ -236,7 +234,7 @@ impl PromptRepo for SqlitePromptRepo { .await; } - let query = if params.include_deleted { + let sql = if params.include_deleted { r#" SELECT id, owner_type, owner_id, name, description, content, metadata, created_at, updated_at FROM prompts @@ -254,7 +252,7 @@ impl PromptRepo for SqlitePromptRepo { "# }; - let rows = sqlx::query(query) + let rows = query(sql) .bind(owner_type.as_str()) .bind(owner_id.to_string()) .bind(fetch_limit) @@ -282,19 +280,19 @@ impl PromptRepo for SqlitePromptRepo { owner_id: Uuid, include_deleted: bool, ) -> DbResult { - let query = if include_deleted { + let sql = if include_deleted { "SELECT COUNT(*) as count FROM prompts WHERE owner_type = ? AND owner_id = ?" 
} else { "SELECT COUNT(*) as count FROM prompts WHERE owner_type = ? AND owner_id = ? AND deleted_at IS NULL" }; - let row = sqlx::query(query) + let row = query(sql) .bind(owner_type.as_str()) .bind(owner_id.to_string()) .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn update(&self, id: Uuid, input: UpdatePrompt) -> DbResult { @@ -323,12 +321,12 @@ impl PromptRepo for SqlitePromptRepo { set_clauses.push("metadata = ?"); } - let query = format!( + let sql = format!( "UPDATE prompts SET {} WHERE id = ? AND deleted_at IS NULL", set_clauses.join(", ") ); - let mut query_builder = sqlx::query(&query).bind(now); + let mut query_builder = query(&sql).bind(now); if let Some(ref name) = input.name { query_builder = query_builder.bind(name); @@ -349,12 +347,9 @@ impl PromptRepo for SqlitePromptRepo { .bind(id.to_string()) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict("Prompt with this name already exists for this owner".into()) - } - _ => DbError::from(e), - })?; + .map_err(map_unique_violation( + "Prompt with this name already exists for this owner", + ))?; if result.rows_affected() == 0 { return Err(DbError::NotFound); @@ -366,7 +361,7 @@ impl PromptRepo for SqlitePromptRepo { async fn delete(&self, id: Uuid) -> DbResult<()> { let now = chrono::Utc::now(); - let result = sqlx::query( + let result = query( r#" UPDATE prompts SET deleted_at = ? 
@@ -388,6 +383,8 @@ impl PromptRepo for SqlitePromptRepo { #[cfg(test)] mod tests { + use sqlx::SqlitePool; + use super::*; use crate::models::PromptOwner; diff --git a/src/db/sqlite/providers.rs b/src/db/sqlite/providers.rs index 096203b..6530e79 100644 --- a/src/db/sqlite/providers.rs +++ b/src/db/sqlite/providers.rs @@ -1,7 +1,7 @@ use async_trait::async_trait; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; +use super::backend::{Pool, Row, RowExt, map_unique_violation, query}; use crate::{ db::{ error::{DbError, DbResult}, @@ -13,11 +13,11 @@ use crate::{ }; pub struct SqliteDynamicProviderRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteDynamicProviderRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } @@ -50,28 +50,31 @@ impl SqliteDynamicProviderRepo { } } - fn parse_provider(row: &sqlx::sqlite::SqliteRow) -> DbResult { - let owner = Self::parse_owner(row.get("owner_type"), row.get("owner_id"))?; - let models_json: String = row.get("models"); + fn parse_provider(row: &Row) -> DbResult { + let owner_type: String = row.col("owner_type"); + let owner_id: String = row.col("owner_id"); + let owner = Self::parse_owner(&owner_type, &owner_id)?; + let models_json: String = row.col("models"); let models: Vec = serde_json::from_str(&models_json).map_err(|e| DbError::Internal(e.to_string()))?; let config: Option = row - .get::, _>("config") + .col::>("config") .and_then(|s| serde_json::from_str(&s).ok()); + let id_str: String = row.col("id"); Ok(DynamicProvider { - id: Uuid::parse_str(row.get("id")).map_err(|e| DbError::Internal(e.to_string()))?, - name: row.get("name"), + id: Uuid::parse_str(&id_str).map_err(|e| DbError::Internal(e.to_string()))?, + name: row.col("name"), owner, - provider_type: row.get("provider_type"), - base_url: row.get("base_url"), - api_key_secret_ref: row.get("api_key_secret_ref"), + provider_type: row.col("provider_type"), + base_url: row.col("base_url"), + api_key_secret_ref: 
row.col("api_key_secret_ref"), config, models, - is_enabled: row.get::("is_enabled") != 0, - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + is_enabled: row.col::("is_enabled") != 0, + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) } @@ -87,7 +90,7 @@ impl SqliteDynamicProviderRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT id, owner_type, owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, is_enabled, created_at, updated_at @@ -100,7 +103,7 @@ impl SqliteDynamicProviderRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(org_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -139,7 +142,7 @@ impl SqliteDynamicProviderRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT id, owner_type, owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, is_enabled, created_at, updated_at @@ -152,7 +155,7 @@ impl SqliteDynamicProviderRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(project_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -191,7 +194,7 @@ impl SqliteDynamicProviderRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT id, owner_type, owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, is_enabled, created_at, updated_at @@ -204,7 +207,7 @@ impl SqliteDynamicProviderRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(team_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -243,7 +246,7 @@ impl 
SqliteDynamicProviderRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT id, owner_type, owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, is_enabled, created_at, updated_at @@ -256,7 +259,7 @@ impl SqliteDynamicProviderRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(user_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -284,7 +287,8 @@ impl SqliteDynamicProviderRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl DynamicProviderRepo for SqliteDynamicProviderRepo { async fn create(&self, id: Uuid, input: CreateDynamicProvider) -> DbResult { let now = chrono::Utc::now(); @@ -298,7 +302,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { .map(|c| serde_json::to_string(c).map_err(|e| DbError::Internal(e.to_string()))) .transpose()?; - sqlx::query( + query( r#" INSERT INTO dynamic_providers ( id, owner_type, owner_id, name, provider_type, base_url, @@ -320,12 +324,10 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => DbError::Conflict( - format!("Provider '{}' already exists for this owner", input.name), - ), - _ => DbError::from(e), - })?; + .map_err(map_unique_violation(format!( + "Provider '{}' already exists for this owner", + input.name + )))?; Ok(DynamicProvider { id, @@ -343,7 +345,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { } async fn get_by_id(&self, id: Uuid) -> DbResult> { - let row = sqlx::query( + let row = query( r#" SELECT id, owner_type, owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, is_enabled, created_at, updated_at @@ -365,7 +367,7 @@ impl DynamicProviderRepo 
for SqliteDynamicProviderRepo { ) -> DbResult> { let (owner_type, owner_id) = Self::owner_to_parts(owner); - let row = sqlx::query( + let row = query( r#" SELECT id, owner_type, owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, is_enabled, created_at, updated_at @@ -398,7 +400,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { } // First page (no cursor provided) - let rows = sqlx::query( + let rows = query( r#" SELECT id, owner_type, owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, is_enabled, created_at, updated_at @@ -428,7 +430,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { } async fn count_by_org(&self, org_id: Uuid) -> DbResult { - let row = sqlx::query( + let row = query( r#" SELECT COUNT(*) as count FROM dynamic_providers @@ -439,7 +441,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn list_by_team( @@ -458,7 +460,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { } // First page (no cursor provided) - let rows = sqlx::query( + let rows = query( r#" SELECT id, owner_type, owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, is_enabled, created_at, updated_at @@ -488,7 +490,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { } async fn count_by_team(&self, team_id: Uuid) -> DbResult { - let row = sqlx::query( + let row = query( r#" SELECT COUNT(*) as count FROM dynamic_providers @@ -499,7 +501,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn list_by_project( @@ -518,7 +520,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { } // First page (no cursor provided) - let rows = sqlx::query( + let rows = query( r#" SELECT id, owner_type, owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, 
is_enabled, created_at, updated_at @@ -548,7 +550,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { } async fn count_by_project(&self, project_id: Uuid) -> DbResult { - let row = sqlx::query( + let row = query( r#" SELECT COUNT(*) as count FROM dynamic_providers @@ -559,7 +561,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn list_by_user( @@ -578,7 +580,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { } // First page (no cursor provided) - let rows = sqlx::query( + let rows = query( r#" SELECT id, owner_type, owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, is_enabled, created_at, updated_at @@ -608,7 +610,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { } async fn count_by_user(&self, user_id: Uuid) -> DbResult { - let row = sqlx::query( + let row = query( r#" SELECT COUNT(*) as count FROM dynamic_providers @@ -619,7 +621,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn list_enabled_by_user( @@ -634,7 +636,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT id, owner_type, owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, is_enabled, created_at, updated_at @@ -647,7 +649,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(user_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -674,7 +676,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { return Ok(ListResult::new(items, has_more, cursors)); } - let rows = sqlx::query( + let rows = query( r#" SELECT id, owner_type, 
owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, is_enabled, created_at, updated_at @@ -715,7 +717,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT id, owner_type, owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, is_enabled, created_at, updated_at @@ -728,7 +730,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(org_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -755,7 +757,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { return Ok(ListResult::new(items, has_more, cursors)); } - let rows = sqlx::query( + let rows = query( r#" SELECT id, owner_type, owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, is_enabled, created_at, updated_at @@ -796,7 +798,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT id, owner_type, owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, is_enabled, created_at, updated_at @@ -809,7 +811,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(project_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -836,7 +838,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { return Ok(ListResult::new(items, has_more, cursors)); } - let rows = sqlx::query( + let rows = query( r#" SELECT id, owner_type, owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, is_enabled, created_at, updated_at @@ -877,7 +879,7 @@ impl 
DynamicProviderRepo for SqliteDynamicProviderRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT id, owner_type, owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, is_enabled, created_at, updated_at @@ -890,7 +892,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(team_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -917,7 +919,7 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { return Ok(ListResult::new(items, has_more, cursors)); } - let rows = sqlx::query( + let rows = query( r#" SELECT id, owner_type, owner_id, name, provider_type, base_url, api_key_secret_ref, config, models, is_enabled, created_at, updated_at @@ -983,13 +985,13 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { updates.join(", ") ); - let mut query = sqlx::query(&query_str); - query = query.bind(now); + let mut q = query(&query_str); + q = q.bind(now); if has_base_url { - query = query.bind(&input.base_url); + q = q.bind(&input.base_url); } if has_api_key { - query = query.bind(&input.api_key); + q = q.bind(&input.api_key); } if has_config { let config_json = input @@ -997,26 +999,26 @@ impl DynamicProviderRepo for SqliteDynamicProviderRepo { .as_ref() .map(|c| serde_json::to_string(c).map_err(|e| DbError::Internal(e.to_string()))) .transpose()?; - query = query.bind(config_json); + q = q.bind(config_json); } if has_models { let models_json = serde_json::to_string(&input.models.as_ref().unwrap()) .map_err(|e| DbError::Internal(e.to_string()))?; - query = query.bind(models_json); + q = q.bind(models_json); } if has_is_enabled { - query = query.bind(input.is_enabled.map(|b| if b { 1 } else { 0 })); + q = q.bind(input.is_enabled.map(|b| if b { 1 } else { 0 })); } - query = query.bind(id.to_string()); + q = 
q.bind(id.to_string()); - query.execute(&self.pool).await?; + q.execute(&self.pool).await?; // Fetch and return updated record self.get_by_id(id).await?.ok_or(DbError::NotFound) } async fn delete(&self, id: Uuid) -> DbResult<()> { - sqlx::query("DELETE FROM dynamic_providers WHERE id = ?") + query("DELETE FROM dynamic_providers WHERE id = ?") .bind(id.to_string()) .execute(&self.pool) .await?; diff --git a/src/db/sqlite/scim_configs.rs b/src/db/sqlite/scim_configs.rs index 4ec1fb5..5401cf1 100644 --- a/src/db/sqlite/scim_configs.rs +++ b/src/db/sqlite/scim_configs.rs @@ -1,10 +1,12 @@ //! SQLite implementation of the SCIM config repository. use async_trait::async_trait; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; +use super::{ + backend::{Pool, Row, RowExt, map_unique_violation, query}, + common::parse_uuid, +}; use crate::{ db::{ error::{DbError, DbResult}, @@ -14,46 +16,47 @@ use crate::{ }; pub struct SqliteOrgScimConfigRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteOrgScimConfigRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } /// Parse an OrgScimConfig from a database row. 
- fn parse_config(row: &sqlx::sqlite::SqliteRow) -> DbResult { - let default_team_id: Option = row.get("default_team_id"); + fn parse_config(row: &Row) -> DbResult { + let default_team_id: Option = row.col("default_team_id"); let default_team_id = default_team_id.map(|s| parse_uuid(&s)).transpose()?; Ok(OrgScimConfig { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, - enabled: row.get::("enabled") != 0, - token_prefix: row.get("token_prefix"), - token_last_used_at: row.get("token_last_used_at"), - create_users: row.get::("create_users") != 0, + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, + enabled: row.col::("enabled") != 0, + token_prefix: row.col("token_prefix"), + token_last_used_at: row.col("token_last_used_at"), + create_users: row.col::("create_users") != 0, default_team_id, - default_org_role: row.get("default_org_role"), - default_team_role: row.get("default_team_role"), - sync_display_name: row.get::("sync_display_name") != 0, - deactivate_deletes_user: row.get::("deactivate_deletes_user") != 0, - revoke_api_keys_on_deactivate: row.get::("revoke_api_keys_on_deactivate") != 0, - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + default_org_role: row.col("default_org_role"), + default_team_role: row.col("default_team_role"), + sync_display_name: row.col::("sync_display_name") != 0, + deactivate_deletes_user: row.col::("deactivate_deletes_user") != 0, + revoke_api_keys_on_deactivate: row.col::("revoke_api_keys_on_deactivate") != 0, + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) } /// Parse an OrgScimConfigWithHash from a database row. 
- fn parse_config_with_hash(row: &sqlx::sqlite::SqliteRow) -> DbResult { + fn parse_config_with_hash(row: &Row) -> DbResult { let config = Self::parse_config(row)?; - let token_hash: String = row.get("token_hash"); + let token_hash: String = row.col("token_hash"); Ok(OrgScimConfigWithHash { config, token_hash }) } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl OrgScimConfigRepo for SqliteOrgScimConfigRepo { async fn create( &self, @@ -65,7 +68,7 @@ impl OrgScimConfigRepo for SqliteOrgScimConfigRepo { let id = Uuid::new_v4(); let now = chrono::Utc::now(); - sqlx::query( + query( r#" INSERT INTO org_scim_configs ( id, org_id, enabled, token_hash, token_prefix, @@ -92,12 +95,9 @@ impl OrgScimConfigRepo for SqliteOrgScimConfigRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict("Organization already has a SCIM configuration".into()) - } - _ => DbError::from(e), - })?; + .map_err(map_unique_violation( + "Organization already has a SCIM configuration", + ))?; Ok(OrgScimConfig { id, @@ -118,7 +118,7 @@ impl OrgScimConfigRepo for SqliteOrgScimConfigRepo { } async fn get_by_id(&self, id: Uuid) -> DbResult> { - let row = sqlx::query("SELECT * FROM org_scim_configs WHERE id = ?") + let row = query("SELECT * FROM org_scim_configs WHERE id = ?") .bind(id.to_string()) .fetch_optional(&self.pool) .await?; @@ -127,7 +127,7 @@ impl OrgScimConfigRepo for SqliteOrgScimConfigRepo { } async fn get_by_org_id(&self, org_id: Uuid) -> DbResult> { - let row = sqlx::query("SELECT * FROM org_scim_configs WHERE org_id = ?") + let row = query("SELECT * FROM org_scim_configs WHERE org_id = ?") .bind(org_id.to_string()) .fetch_optional(&self.pool) .await?; @@ -139,7 +139,7 @@ impl OrgScimConfigRepo for SqliteOrgScimConfigRepo { &self, org_id: Uuid, ) -> DbResult> { - let row = sqlx::query("SELECT * FROM 
org_scim_configs WHERE org_id = ?") + let row = query("SELECT * FROM org_scim_configs WHERE org_id = ?") .bind(org_id.to_string()) .fetch_optional(&self.pool) .await?; @@ -148,7 +148,7 @@ impl OrgScimConfigRepo for SqliteOrgScimConfigRepo { } async fn get_by_token_hash(&self, token_hash: &str) -> DbResult> { - let row = sqlx::query("SELECT * FROM org_scim_configs WHERE token_hash = ?") + let row = query("SELECT * FROM org_scim_configs WHERE token_hash = ?") .bind(token_hash) .fetch_optional(&self.pool) .await?; @@ -182,7 +182,7 @@ impl OrgScimConfigRepo for SqliteOrgScimConfigRepo { let now = chrono::Utc::now(); - sqlx::query( + query( r#" UPDATE org_scim_configs SET enabled = ?, create_users = ?, default_team_id = ?, @@ -231,7 +231,7 @@ impl OrgScimConfigRepo for SqliteOrgScimConfigRepo { ) -> DbResult { let now = chrono::Utc::now(); - let result = sqlx::query( + let result = query( r#" UPDATE org_scim_configs SET token_hash = ?, token_prefix = ?, token_last_used_at = NULL, updated_at = ? @@ -255,7 +255,7 @@ impl OrgScimConfigRepo for SqliteOrgScimConfigRepo { async fn update_token_last_used(&self, id: Uuid) -> DbResult<()> { let now = chrono::Utc::now(); - sqlx::query("UPDATE org_scim_configs SET token_last_used_at = ? WHERE id = ?") + query("UPDATE org_scim_configs SET token_last_used_at = ? 
WHERE id = ?") .bind(now) .bind(id.to_string()) .execute(&self.pool) @@ -265,7 +265,7 @@ impl OrgScimConfigRepo for SqliteOrgScimConfigRepo { } async fn delete(&self, id: Uuid) -> DbResult<()> { - let result = sqlx::query("DELETE FROM org_scim_configs WHERE id = ?") + let result = query("DELETE FROM org_scim_configs WHERE id = ?") .bind(id.to_string()) .execute(&self.pool) .await?; @@ -278,7 +278,7 @@ impl OrgScimConfigRepo for SqliteOrgScimConfigRepo { } async fn list_enabled(&self) -> DbResult> { - let rows = sqlx::query("SELECT * FROM org_scim_configs WHERE enabled = 1") + let rows = query("SELECT * FROM org_scim_configs WHERE enabled = 1") .fetch_all(&self.pool) .await?; diff --git a/src/db/sqlite/scim_group_mappings.rs b/src/db/sqlite/scim_group_mappings.rs index 4c613db..18aa968 100644 --- a/src/db/sqlite/scim_group_mappings.rs +++ b/src/db/sqlite/scim_group_mappings.rs @@ -1,10 +1,12 @@ //! SQLite implementation of the SCIM group mapping repository. use async_trait::async_trait; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; +use super::{ + backend::{Pool, Row, RowExt, query}, + common::parse_uuid, +}; use crate::{ db::{ error::{DbError, DbResult}, @@ -19,52 +21,53 @@ use crate::{ }; pub struct SqliteScimGroupMappingRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteScimGroupMappingRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } /// Parse a ScimGroupMapping from a database row. 
- fn parse_mapping(row: &sqlx::sqlite::SqliteRow) -> DbResult { + fn parse_mapping(row: &Row) -> DbResult { Ok(ScimGroupMapping { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, - scim_group_id: row.get("scim_group_id"), - team_id: parse_uuid(&row.get::("team_id"))?, - display_name: row.get("display_name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, + scim_group_id: row.col("scim_group_id"), + team_id: parse_uuid(&row.col::("team_id"))?, + display_name: row.col("display_name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) } /// Parse a ScimGroupWithTeam from a joined row with aliased columns. - fn parse_mapping_with_team(row: &sqlx::sqlite::SqliteRow) -> DbResult { + fn parse_mapping_with_team(row: &Row) -> DbResult { Ok(ScimGroupWithTeam { mapping: ScimGroupMapping { - id: parse_uuid(&row.get::("m_id"))?, - org_id: parse_uuid(&row.get::("m_org_id"))?, - scim_group_id: row.get("m_scim_group_id"), - team_id: parse_uuid(&row.get::("m_team_id"))?, - display_name: row.get("m_display_name"), - created_at: row.get("m_created_at"), - updated_at: row.get("m_updated_at"), + id: parse_uuid(&row.col::("m_id"))?, + org_id: parse_uuid(&row.col::("m_org_id"))?, + scim_group_id: row.col("m_scim_group_id"), + team_id: parse_uuid(&row.col::("m_team_id"))?, + display_name: row.col("m_display_name"), + created_at: row.col("m_created_at"), + updated_at: row.col("m_updated_at"), }, team: Team { - id: parse_uuid(&row.get::("t_id"))?, - org_id: parse_uuid(&row.get::("t_org_id"))?, - slug: row.get("t_slug"), - name: row.get("t_name"), - created_at: row.get("t_created_at"), - updated_at: row.get("t_updated_at"), + id: parse_uuid(&row.col::("t_id"))?, + org_id: parse_uuid(&row.col::("t_org_id"))?, + slug: row.col("t_slug"), + name: row.col("t_name"), + created_at: row.col("t_created_at"), + updated_at: 
row.col("t_updated_at"), }, }) } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl ScimGroupMappingRepo for SqliteScimGroupMappingRepo { async fn create( &self, @@ -74,7 +77,7 @@ impl ScimGroupMappingRepo for SqliteScimGroupMappingRepo { let id = Uuid::new_v4(); let now = chrono::Utc::now(); - sqlx::query( + query( r#" INSERT INTO scim_group_mappings ( id, org_id, scim_group_id, team_id, display_name, created_at, updated_at @@ -91,14 +94,17 @@ impl ScimGroupMappingRepo for SqliteScimGroupMappingRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict("SCIM group ID already mapped in this organization".into()) + .map_err(|e| { + if super::backend::is_unique_violation(&e) { + return DbError::Conflict( + "SCIM group ID already mapped in this organization".into(), + ); } - sqlx::Error::Database(db_err) if db_err.is_foreign_key_violation() => { - DbError::Conflict("Referenced team or organization not found".into()) + #[cfg(feature = "database-sqlite")] + if matches!(&e, sqlx::Error::Database(db_err) if db_err.is_foreign_key_violation()) { + return DbError::Conflict("Referenced team or organization not found".into()); } - _ => DbError::from(e), + DbError::from(e) })?; Ok(ScimGroupMapping { @@ -113,7 +119,7 @@ impl ScimGroupMappingRepo for SqliteScimGroupMappingRepo { } async fn get_by_id(&self, id: Uuid) -> DbResult> { - let row = sqlx::query("SELECT * FROM scim_group_mappings WHERE id = ?") + let row = query("SELECT * FROM scim_group_mappings WHERE id = ?") .bind(id.to_string()) .fetch_optional(&self.pool) .await?; @@ -126,12 +132,11 @@ impl ScimGroupMappingRepo for SqliteScimGroupMappingRepo { org_id: Uuid, scim_group_id: &str, ) -> DbResult> { - let row = - sqlx::query("SELECT * FROM scim_group_mappings WHERE org_id = ? 
AND scim_group_id = ?") - .bind(org_id.to_string()) - .bind(scim_group_id) - .fetch_optional(&self.pool) - .await?; + let row = query("SELECT * FROM scim_group_mappings WHERE org_id = ? AND scim_group_id = ?") + .bind(org_id.to_string()) + .bind(scim_group_id) + .fetch_optional(&self.pool) + .await?; row.map(|r| Self::parse_mapping(&r)).transpose() } @@ -141,7 +146,7 @@ impl ScimGroupMappingRepo for SqliteScimGroupMappingRepo { org_id: Uuid, team_id: Uuid, ) -> DbResult> { - let row = sqlx::query("SELECT * FROM scim_group_mappings WHERE org_id = ? AND team_id = ?") + let row = query("SELECT * FROM scim_group_mappings WHERE org_id = ? AND team_id = ?") .bind(org_id.to_string()) .bind(team_id.to_string()) .fetch_optional(&self.pool) @@ -162,7 +167,7 @@ impl ScimGroupMappingRepo for SqliteScimGroupMappingRepo { params.sort_order.cursor_query_params(params.direction); let rows = if let Some(cursor) = ¶ms.cursor { - let query = format!( + let sql = format!( r#" SELECT * FROM scim_group_mappings WHERE org_id = ? @@ -172,7 +177,7 @@ impl ScimGroupMappingRepo for SqliteScimGroupMappingRepo { "#, comparison_op, order_dir, order_dir ); - sqlx::query(&query) + query(&sql) .bind(org_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -180,7 +185,7 @@ impl ScimGroupMappingRepo for SqliteScimGroupMappingRepo { .fetch_all(&self.pool) .await? } else { - let query = format!( + let sql = format!( r#" SELECT * FROM scim_group_mappings WHERE org_id = ? 
@@ -189,7 +194,7 @@ impl ScimGroupMappingRepo for SqliteScimGroupMappingRepo { "#, order_dir, order_dir ); - sqlx::query(&query) + query(&sql) .bind(org_id.to_string()) .bind(fetch_limit) .fetch_all(&self.pool) @@ -282,7 +287,7 @@ impl ScimGroupMappingRepo for SqliteScimGroupMappingRepo { ); // Execute count query - let mut count_query = sqlx::query(&count_sql).bind(org_id.to_string()); + let mut count_query = query(&count_sql).bind(org_id.to_string()); if let Some(f) = filter { for val in &f.bindings { count_query = match val { @@ -293,10 +298,10 @@ impl ScimGroupMappingRepo for SqliteScimGroupMappingRepo { } } let count_row = count_query.fetch_one(&self.pool).await?; - let total: i64 = count_row.get("cnt"); + let total: i64 = count_row.col("cnt"); // Execute data query - let mut data_query = sqlx::query(&data_sql).bind(org_id.to_string()); + let mut data_query = query(&data_sql).bind(org_id.to_string()); if let Some(f) = filter { for val in &f.bindings { data_query = match val { @@ -327,7 +332,7 @@ impl ScimGroupMappingRepo for SqliteScimGroupMappingRepo { }; let now = chrono::Utc::now(); - sqlx::query( + query( "UPDATE scim_group_mappings SET team_id = ?, display_name = ?, updated_at = ? 
WHERE id = ?", ) .bind(team_id.to_string()) @@ -336,11 +341,13 @@ impl ScimGroupMappingRepo for SqliteScimGroupMappingRepo { .bind(id.to_string()) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_foreign_key_violation() => { - DbError::Conflict("Referenced team not found".into()) + .map_err(|e| { + #[cfg(feature = "database-sqlite")] + if matches!(&e, sqlx::Error::Database(db_err) if db_err.is_foreign_key_violation()) + { + return DbError::Conflict("Referenced team not found".into()); } - _ => DbError::from(e), + DbError::from(e) })?; Ok(ScimGroupMapping { @@ -355,7 +362,7 @@ impl ScimGroupMappingRepo for SqliteScimGroupMappingRepo { } async fn delete(&self, id: Uuid) -> DbResult<()> { - let result = sqlx::query("DELETE FROM scim_group_mappings WHERE id = ?") + let result = query("DELETE FROM scim_group_mappings WHERE id = ?") .bind(id.to_string()) .execute(&self.pool) .await?; @@ -368,7 +375,7 @@ impl ScimGroupMappingRepo for SqliteScimGroupMappingRepo { } async fn delete_by_team(&self, team_id: Uuid) -> DbResult { - let result = sqlx::query("DELETE FROM scim_group_mappings WHERE team_id = ?") + let result = query("DELETE FROM scim_group_mappings WHERE team_id = ?") .bind(team_id.to_string()) .execute(&self.pool) .await?; @@ -377,11 +384,11 @@ impl ScimGroupMappingRepo for SqliteScimGroupMappingRepo { } async fn count_by_org(&self, org_id: Uuid) -> DbResult { - let row = sqlx::query("SELECT COUNT(*) as count FROM scim_group_mappings WHERE org_id = ?") + let row = query("SELECT COUNT(*) as count FROM scim_group_mappings WHERE org_id = ?") .bind(org_id.to_string()) .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } } diff --git a/src/db/sqlite/scim_user_mappings.rs b/src/db/sqlite/scim_user_mappings.rs index ad14b43..2c862de 100644 --- a/src/db/sqlite/scim_user_mappings.rs +++ b/src/db/sqlite/scim_user_mappings.rs @@ -1,10 +1,12 @@ //! 
SQLite implementation of the SCIM user mapping repository. use async_trait::async_trait; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; +use super::{ + backend::{Pool, Row, RowExt, query}, + common::parse_uuid, +}; use crate::{ db::{ error::{DbError, DbResult}, @@ -19,52 +21,53 @@ use crate::{ }; pub struct SqliteScimUserMappingRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteScimUserMappingRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } /// Parse a ScimUserMapping from a database row. - fn parse_mapping(row: &sqlx::sqlite::SqliteRow) -> DbResult { + fn parse_mapping(row: &Row) -> DbResult { Ok(ScimUserMapping { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, - scim_external_id: row.get("scim_external_id"), - user_id: parse_uuid(&row.get::("user_id"))?, - active: row.get::("active") != 0, - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, + scim_external_id: row.col("scim_external_id"), + user_id: parse_uuid(&row.col::("user_id"))?, + active: row.col::("active") != 0, + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) } /// Parse a ScimUserWithMapping from a joined row with aliased columns. 
- fn parse_mapping_with_user(row: &sqlx::sqlite::SqliteRow) -> DbResult { + fn parse_mapping_with_user(row: &Row) -> DbResult { Ok(ScimUserWithMapping { mapping: ScimUserMapping { - id: parse_uuid(&row.get::("m_id"))?, - org_id: parse_uuid(&row.get::("m_org_id"))?, - scim_external_id: row.get("m_scim_external_id"), - user_id: parse_uuid(&row.get::("m_user_id"))?, - active: row.get::("m_active") != 0, - created_at: row.get("m_created_at"), - updated_at: row.get("m_updated_at"), + id: parse_uuid(&row.col::("m_id"))?, + org_id: parse_uuid(&row.col::("m_org_id"))?, + scim_external_id: row.col("m_scim_external_id"), + user_id: parse_uuid(&row.col::("m_user_id"))?, + active: row.col::("m_active") != 0, + created_at: row.col("m_created_at"), + updated_at: row.col("m_updated_at"), }, user: User { - id: parse_uuid(&row.get::("u_id"))?, - external_id: row.get("u_external_id"), - email: row.get("u_email"), - name: row.get("u_name"), - created_at: row.get("u_created_at"), - updated_at: row.get("u_updated_at"), + id: parse_uuid(&row.col::("u_id"))?, + external_id: row.col("u_external_id"), + email: row.col("u_email"), + name: row.col("u_name"), + created_at: row.col("u_created_at"), + updated_at: row.col("u_updated_at"), }, }) } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl ScimUserMappingRepo for SqliteScimUserMappingRepo { async fn create( &self, @@ -74,7 +77,7 @@ impl ScimUserMappingRepo for SqliteScimUserMappingRepo { let id = Uuid::new_v4(); let now = chrono::Utc::now(); - sqlx::query( + query( r#" INSERT INTO scim_user_mappings ( id, org_id, scim_external_id, user_id, active, created_at, updated_at @@ -91,14 +94,17 @@ impl ScimUserMappingRepo for SqliteScimUserMappingRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict("SCIM external ID already mapped in this 
organization".into()) + .map_err(|e| { + if super::backend::is_unique_violation(&e) { + return DbError::Conflict( + "SCIM external ID already mapped in this organization".into(), + ); } - sqlx::Error::Database(db_err) if db_err.is_foreign_key_violation() => { - DbError::Conflict("Referenced user or organization not found".into()) + #[cfg(feature = "database-sqlite")] + if matches!(&e, sqlx::Error::Database(db_err) if db_err.is_foreign_key_violation()) { + return DbError::Conflict("Referenced user or organization not found".into()); } - _ => DbError::from(e), + DbError::from(e) })?; Ok(ScimUserMapping { @@ -113,7 +119,7 @@ impl ScimUserMappingRepo for SqliteScimUserMappingRepo { } async fn get_by_id(&self, id: Uuid) -> DbResult> { - let row = sqlx::query("SELECT * FROM scim_user_mappings WHERE id = ?") + let row = query("SELECT * FROM scim_user_mappings WHERE id = ?") .bind(id.to_string()) .fetch_optional(&self.pool) .await?; @@ -126,13 +132,12 @@ impl ScimUserMappingRepo for SqliteScimUserMappingRepo { org_id: Uuid, scim_external_id: &str, ) -> DbResult> { - let row = sqlx::query( - "SELECT * FROM scim_user_mappings WHERE org_id = ? AND scim_external_id = ?", - ) - .bind(org_id.to_string()) - .bind(scim_external_id) - .fetch_optional(&self.pool) - .await?; + let row = + query("SELECT * FROM scim_user_mappings WHERE org_id = ? AND scim_external_id = ?") + .bind(org_id.to_string()) + .bind(scim_external_id) + .fetch_optional(&self.pool) + .await?; row.map(|r| Self::parse_mapping(&r)).transpose() } @@ -142,7 +147,7 @@ impl ScimUserMappingRepo for SqliteScimUserMappingRepo { org_id: Uuid, user_id: Uuid, ) -> DbResult> { - let row = sqlx::query("SELECT * FROM scim_user_mappings WHERE org_id = ? AND user_id = ?") + let row = query("SELECT * FROM scim_user_mappings WHERE org_id = ? 
AND user_id = ?") .bind(org_id.to_string()) .bind(user_id.to_string()) .fetch_optional(&self.pool) @@ -163,7 +168,7 @@ impl ScimUserMappingRepo for SqliteScimUserMappingRepo { params.sort_order.cursor_query_params(params.direction); let rows = if let Some(cursor) = ¶ms.cursor { - let query = format!( + let sql = format!( r#" SELECT * FROM scim_user_mappings WHERE org_id = ? @@ -173,7 +178,7 @@ impl ScimUserMappingRepo for SqliteScimUserMappingRepo { "#, comparison_op, order_dir, order_dir ); - sqlx::query(&query) + query(&sql) .bind(org_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -181,7 +186,7 @@ impl ScimUserMappingRepo for SqliteScimUserMappingRepo { .fetch_all(&self.pool) .await? } else { - let query = format!( + let sql = format!( r#" SELECT * FROM scim_user_mappings WHERE org_id = ? @@ -190,7 +195,7 @@ impl ScimUserMappingRepo for SqliteScimUserMappingRepo { "#, order_dir, order_dir ); - sqlx::query(&query) + query(&sql) .bind(org_id.to_string()) .bind(fetch_limit) .fetch_all(&self.pool) @@ -283,7 +288,7 @@ impl ScimUserMappingRepo for SqliteScimUserMappingRepo { ); // Execute count query - let mut count_query = sqlx::query(&count_sql).bind(org_id.to_string()); + let mut count_query = query(&count_sql).bind(org_id.to_string()); if let Some(f) = filter { for val in &f.bindings { count_query = match val { @@ -294,10 +299,10 @@ impl ScimUserMappingRepo for SqliteScimUserMappingRepo { } } let count_row = count_query.fetch_one(&self.pool).await?; - let total: i64 = count_row.get("cnt"); + let total: i64 = count_row.col("cnt"); // Execute data query - let mut data_query = sqlx::query(&data_sql).bind(org_id.to_string()); + let mut data_query = query(&data_sql).bind(org_id.to_string()); if let Some(f) = filter { for val in &f.bindings { data_query = match val { @@ -319,7 +324,7 @@ impl ScimUserMappingRepo for SqliteScimUserMappingRepo { } async fn list_by_user(&self, user_id: Uuid) -> DbResult> { - let rows = sqlx::query("SELECT * FROM 
scim_user_mappings WHERE user_id = ?") + let rows = query("SELECT * FROM scim_user_mappings WHERE user_id = ?") .bind(user_id.to_string()) .fetch_all(&self.pool) .await?; @@ -333,7 +338,7 @@ impl ScimUserMappingRepo for SqliteScimUserMappingRepo { let active = input.active.unwrap_or(current.active); let now = chrono::Utc::now(); - sqlx::query("UPDATE scim_user_mappings SET active = ?, updated_at = ? WHERE id = ?") + query("UPDATE scim_user_mappings SET active = ?, updated_at = ? WHERE id = ?") .bind(active as i32) .bind(now) .bind(id.to_string()) @@ -362,7 +367,7 @@ impl ScimUserMappingRepo for SqliteScimUserMappingRepo { } async fn delete(&self, id: Uuid) -> DbResult<()> { - let result = sqlx::query("DELETE FROM scim_user_mappings WHERE id = ?") + let result = query("DELETE FROM scim_user_mappings WHERE id = ?") .bind(id.to_string()) .execute(&self.pool) .await?; @@ -375,7 +380,7 @@ impl ScimUserMappingRepo for SqliteScimUserMappingRepo { } async fn delete_by_user(&self, user_id: Uuid) -> DbResult { - let result = sqlx::query("DELETE FROM scim_user_mappings WHERE user_id = ?") + let result = query("DELETE FROM scim_user_mappings WHERE user_id = ?") .bind(user_id.to_string()) .execute(&self.pool) .await?; @@ -384,11 +389,11 @@ impl ScimUserMappingRepo for SqliteScimUserMappingRepo { } async fn count_by_org(&self, org_id: Uuid) -> DbResult { - let row = sqlx::query("SELECT COUNT(*) as count FROM scim_user_mappings WHERE org_id = ?") + let row = query("SELECT COUNT(*) as count FROM scim_user_mappings WHERE org_id = ?") .bind(org_id.to_string()) .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } } diff --git a/src/db/sqlite/service_accounts.rs b/src/db/sqlite/service_accounts.rs index 5943e68..4841ab1 100644 --- a/src/db/sqlite/service_accounts.rs +++ b/src/db/sqlite/service_accounts.rs @@ -1,9 +1,11 @@ use async_trait::async_trait; use chrono::SubsecRound; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; 
+use super::{ + backend::{Pool, RowExt, begin, map_unique_violation, query, query_scalar}, + common::parse_uuid, +}; use crate::{ db::{ error::{DbError, DbResult}, @@ -16,11 +18,11 @@ use crate::{ }; pub struct SqliteServiceAccountRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteServiceAccountRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } @@ -46,7 +48,7 @@ impl SqliteServiceAccountRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT id, org_id, slug, name, description, roles, created_at, updated_at FROM service_accounts @@ -57,7 +59,7 @@ impl SqliteServiceAccountRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(org_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -71,14 +73,14 @@ impl SqliteServiceAccountRepo { .take(limit as usize) .map(|row| { Ok(ServiceAccount { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, - slug: row.get("slug"), - name: row.get("name"), - description: row.get("description"), - roles: Self::parse_roles(&row.get::("roles")), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, + slug: row.col("slug"), + name: row.col("name"), + description: row.col("description"), + roles: Self::parse_roles(&row.col::("roles")), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) }) .collect::>>()?; @@ -96,14 +98,15 @@ impl SqliteServiceAccountRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl ServiceAccountRepo for SqliteServiceAccountRepo { async fn create(&self, org_id: Uuid, input: CreateServiceAccount) -> DbResult { let id = Uuid::new_v4(); let now = 
chrono::Utc::now().trunc_subsecs(3); let roles_json = Self::serialize_roles(&input.roles); - sqlx::query( + query( r#" INSERT INTO service_accounts (id, org_id, slug, name, description, roles, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?) @@ -119,15 +122,10 @@ impl ServiceAccountRepo for SqliteServiceAccountRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict(format!( - "Service account with slug '{}' already exists in this organization", - input.slug - )) - } - _ => DbError::from(e), - })?; + .map_err(map_unique_violation(format!( + "Service account with slug '{}' already exists in this organization", + input.slug + )))?; Ok(ServiceAccount { id, @@ -142,7 +140,7 @@ impl ServiceAccountRepo for SqliteServiceAccountRepo { } async fn get_by_id(&self, id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, org_id, slug, name, description, roles, created_at, updated_at FROM service_accounts @@ -155,21 +153,21 @@ impl ServiceAccountRepo for SqliteServiceAccountRepo { match result { Some(row) => Ok(Some(ServiceAccount { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, - slug: row.get("slug"), - name: row.get("name"), - description: row.get("description"), - roles: Self::parse_roles(&row.get::("roles")), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, + slug: row.col("slug"), + name: row.col("name"), + description: row.col("description"), + roles: Self::parse_roles(&row.col::("roles")), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), })), None => Ok(None), } } async fn get_by_slug(&self, org_id: Uuid, slug: &str) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, org_id, slug, name, description, roles, created_at, updated_at FROM 
service_accounts @@ -183,14 +181,14 @@ impl ServiceAccountRepo for SqliteServiceAccountRepo { match result { Some(row) => Ok(Some(ServiceAccount { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, - slug: row.get("slug"), - name: row.get("name"), - description: row.get("description"), - roles: Self::parse_roles(&row.get::("roles")), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, + slug: row.col("slug"), + name: row.col("name"), + description: row.col("description"), + roles: Self::parse_roles(&row.col::("roles")), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), })), None => Ok(None), } @@ -210,7 +208,7 @@ impl ServiceAccountRepo for SqliteServiceAccountRepo { .await; } - let rows = sqlx::query( + let rows = query( r#" SELECT id, org_id, slug, name, description, roles, created_at, updated_at FROM service_accounts @@ -230,14 +228,14 @@ impl ServiceAccountRepo for SqliteServiceAccountRepo { .take(limit as usize) .map(|row| { Ok(ServiceAccount { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, - slug: row.get("slug"), - name: row.get("name"), - description: row.get("description"), - roles: Self::parse_roles(&row.get::("roles")), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, + slug: row.col("slug"), + name: row.col("name"), + description: row.col("description"), + roles: Self::parse_roles(&row.col::("roles")), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) }) .collect::>>()?; @@ -251,13 +249,13 @@ impl ServiceAccountRepo for SqliteServiceAccountRepo { } async fn count_by_org(&self, org_id: Uuid) -> DbResult { - let row = sqlx::query( + let row = query( "SELECT COUNT(*) as count FROM service_accounts WHERE org_id = ? 
AND deleted_at IS NULL", ) .bind(org_id.to_string()) .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn update(&self, id: Uuid, input: UpdateServiceAccount) -> DbResult { @@ -270,7 +268,7 @@ impl ServiceAccountRepo for SqliteServiceAccountRepo { let new_roles = input.roles.unwrap_or(existing.roles); let roles_json = Self::serialize_roles(&new_roles); - let result = sqlx::query( + let result = query( r#" UPDATE service_accounts SET name = ?, description = ?, roles = ?, updated_at = ? @@ -305,7 +303,7 @@ impl ServiceAccountRepo for SqliteServiceAccountRepo { async fn delete(&self, id: Uuid) -> DbResult<()> { let now = chrono::Utc::now().trunc_subsecs(3); - let result = sqlx::query( + let result = query( r#" UPDATE service_accounts SET deleted_at = ? @@ -325,69 +323,46 @@ impl ServiceAccountRepo for SqliteServiceAccountRepo { } async fn delete_with_api_key_revocation(&self, id: Uuid) -> DbResult> { - // SQLite IMMEDIATE transactions provide write locking, preventing race conditions - // where new API keys could be created between deleting the SA and revoking its keys. - // Note: sqlx uses IMMEDIATE mode for write transactions on SQLite by default. - let mut tx = self.pool.begin().await?; let now = chrono::Utc::now().trunc_subsecs(3); - // 1. Check if the service account exists (locks the row in SQLite's transaction) - let exists = sqlx::query( - r#" - SELECT id FROM service_accounts - WHERE id = ? AND deleted_at IS NULL - "#, - ) - .bind(id.to_string()) - .fetch_optional(&mut *tx) - .await?; + let mut tx = begin(&self.pool).await?; + + let exists = + query(r#"SELECT id FROM service_accounts WHERE id = ? AND deleted_at IS NULL"#) + .bind(id.to_string()) + .fetch_optional(&mut *tx) + .await?; if exists.is_none() { return Err(DbError::NotFound); } - // 2. 
Get API key IDs before revoking (SQLite RETURNING is available since 3.35) - let revoked_ids: Vec = sqlx::query_scalar( - r#" - UPDATE api_keys - SET revoked_at = ?, updated_at = ? - WHERE owner_type = 'service_account' AND owner_id = ? AND revoked_at IS NULL - RETURNING id - "#, + let revoked_ids: Vec = query_scalar( + r#"UPDATE api_keys SET revoked_at = ?, updated_at = ? WHERE owner_type = 'service_account' AND owner_id = ? AND revoked_at IS NULL RETURNING id"#, ) - .bind(now) - .bind(now) - .bind(id.to_string()) - .fetch_all(&mut *tx) - .await?; + .bind(now).bind(now).bind(id.to_string()) + .fetch_all(&mut *tx).await?; - // 3. Soft-delete the service account - sqlx::query( - r#" - UPDATE service_accounts - SET deleted_at = ? - WHERE id = ? - "#, - ) - .bind(now) - .bind(id.to_string()) - .execute(&mut *tx) - .await?; + query(r#"UPDATE service_accounts SET deleted_at = ? WHERE id = ?"#) + .bind(now) + .bind(id.to_string()) + .execute(&mut *tx) + .await?; tx.commit().await?; - // Convert string IDs to UUIDs - let revoked_uuids: Vec = revoked_ids + let revoked_uuids = revoked_ids .into_iter() .filter_map(|s| parse_uuid(&s).ok()) .collect(); - Ok(revoked_uuids) } } #[cfg(test)] mod tests { + use sqlx::SqlitePool; + use super::*; use crate::db::repos::ServiceAccountRepo; diff --git a/src/db/sqlite/sso_group_mappings.rs b/src/db/sqlite/sso_group_mappings.rs index 52c8279..ef9195a 100644 --- a/src/db/sqlite/sso_group_mappings.rs +++ b/src/db/sqlite/sso_group_mappings.rs @@ -1,8 +1,10 @@ use async_trait::async_trait; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; +use super::{ + backend::{Pool, Row, RowExt, map_unique_violation, query}, + common::parse_uuid, +}; use crate::{ db::{ error::{DbError, DbResult}, @@ -15,29 +17,29 @@ use crate::{ }; pub struct SqliteSsoGroupMappingRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteSsoGroupMappingRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } 
/// Parse an SsoGroupMapping from a database row. - fn parse_mapping(row: &sqlx::sqlite::SqliteRow) -> DbResult { - let team_id: Option = row.get("team_id"); + fn parse_mapping(row: &Row) -> DbResult { + let team_id: Option = row.col("team_id"); let team_id = team_id.map(|s| parse_uuid(&s)).transpose()?; Ok(SsoGroupMapping { - id: parse_uuid(&row.get::("id"))?, - sso_connection_name: row.get("sso_connection_name"), - idp_group: row.get("idp_group"), - org_id: parse_uuid(&row.get::("org_id"))?, + id: parse_uuid(&row.col::("id"))?, + sso_connection_name: row.col("sso_connection_name"), + idp_group: row.col("idp_group"), + org_id: parse_uuid(&row.col::("org_id"))?, team_id, - role: row.get("role"), - priority: row.get("priority"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + role: row.col("role"), + priority: row.col("priority"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) } @@ -60,7 +62,7 @@ impl SqliteSsoGroupMappingRepo { "" }; - let query = format!( + let sql = format!( r#" SELECT id, sso_connection_name, idp_group, org_id, team_id, role, priority, created_at, updated_at FROM sso_group_mappings @@ -72,7 +74,7 @@ impl SqliteSsoGroupMappingRepo { comparison, connection_filter, order, order ); - let mut query_builder = sqlx::query(&query) + let mut query_builder = query(&sql) .bind(org_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()); @@ -106,7 +108,8 @@ impl SqliteSsoGroupMappingRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl SsoGroupMappingRepo for SqliteSsoGroupMappingRepo { async fn create( &self, @@ -116,7 +119,7 @@ impl SsoGroupMappingRepo for SqliteSsoGroupMappingRepo { let id = Uuid::new_v4(); let now = chrono::Utc::now(); - sqlx::query( + query( r#" INSERT INTO sso_group_mappings (id, sso_connection_name, idp_group, org_id, team_id, role, priority, created_at, updated_at) VALUES 
(?, ?, ?, ?, ?, ?, ?, ?, ?) @@ -133,15 +136,10 @@ impl SsoGroupMappingRepo for SqliteSsoGroupMappingRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict(format!( - "Mapping for IdP group '{}' already exists for this connection/org/team combination", - input.idp_group - )) - } - _ => DbError::from(e), - })?; + .map_err(map_unique_violation(format!( + "Mapping for IdP group '{}' already exists for this connection/org/team combination", + input.idp_group + )))?; Ok(SsoGroupMapping { id, @@ -157,7 +155,7 @@ impl SsoGroupMappingRepo for SqliteSsoGroupMappingRepo { } async fn get_by_id(&self, id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, sso_connection_name, idp_group, org_id, team_id, role, priority, created_at, updated_at FROM sso_group_mappings @@ -188,7 +186,7 @@ impl SsoGroupMappingRepo for SqliteSsoGroupMappingRepo { .await; } - let rows = sqlx::query( + let rows = query( r#" SELECT id, sso_connection_name, idp_group, org_id, team_id, role, priority, created_at, updated_at FROM sso_group_mappings @@ -239,7 +237,7 @@ impl SsoGroupMappingRepo for SqliteSsoGroupMappingRepo { .await; } - let rows = sqlx::query( + let rows = query( r#" SELECT id, sso_connection_name, idp_group, org_id, team_id, role, priority, created_at, updated_at FROM sso_group_mappings @@ -283,7 +281,7 @@ impl SsoGroupMappingRepo for SqliteSsoGroupMappingRepo { let placeholders: Vec<&str> = idp_groups.iter().map(|_| "?").collect(); let in_clause = placeholders.join(", "); - let query = format!( + let sql = format!( r#" SELECT id, sso_connection_name, idp_group, org_id, team_id, role, priority, created_at, updated_at FROM sso_group_mappings @@ -293,7 +291,7 @@ impl SsoGroupMappingRepo for SqliteSsoGroupMappingRepo { in_clause ); - let mut query_builder = sqlx::query(&query) + let mut query_builder = query(&sql) .bind(sso_connection_name) 
.bind(org_id.to_string()); @@ -307,12 +305,12 @@ impl SsoGroupMappingRepo for SqliteSsoGroupMappingRepo { } async fn count_by_org(&self, org_id: Uuid) -> DbResult { - let row = sqlx::query("SELECT COUNT(*) as count FROM sso_group_mappings WHERE org_id = ?") + let row = query("SELECT COUNT(*) as count FROM sso_group_mappings WHERE org_id = ?") .bind(org_id.to_string()) .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn update(&self, id: Uuid, input: UpdateSsoGroupMapping) -> DbResult { @@ -341,12 +339,12 @@ impl SsoGroupMappingRepo for SqliteSsoGroupMappingRepo { set_clauses.push("priority = ?"); } - let query = format!( + let sql = format!( "UPDATE sso_group_mappings SET {} WHERE id = ?", set_clauses.join(", ") ); - let mut query_builder = sqlx::query(&query).bind(now); + let mut query_builder = query(&sql).bind(now); if let Some(ref idp_group) = input.idp_group { query_builder = query_builder.bind(idp_group); @@ -365,12 +363,9 @@ impl SsoGroupMappingRepo for SqliteSsoGroupMappingRepo { .bind(id.to_string()) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict("Mapping with this combination already exists".into()) - } - _ => DbError::from(e), - })?; + .map_err(map_unique_violation( + "Mapping with this combination already exists", + ))?; if result.rows_affected() == 0 { return Err(DbError::NotFound); @@ -380,7 +375,7 @@ impl SsoGroupMappingRepo for SqliteSsoGroupMappingRepo { } async fn delete(&self, id: Uuid) -> DbResult<()> { - let result = sqlx::query("DELETE FROM sso_group_mappings WHERE id = ?") + let result = query("DELETE FROM sso_group_mappings WHERE id = ?") .bind(id.to_string()) .execute(&self.pool) .await?; @@ -398,7 +393,7 @@ impl SsoGroupMappingRepo for SqliteSsoGroupMappingRepo { org_id: Uuid, idp_group: &str, ) -> DbResult { - let result = sqlx::query( + let result = query( "DELETE FROM sso_group_mappings WHERE 
sso_connection_name = ? AND org_id = ? AND idp_group = ?", ) .bind(sso_connection_name) @@ -413,6 +408,8 @@ impl SsoGroupMappingRepo for SqliteSsoGroupMappingRepo { #[cfg(test)] mod tests { + use sqlx::SqlitePool; + use super::*; async fn create_test_pool() -> SqlitePool { diff --git a/src/db/sqlite/teams.rs b/src/db/sqlite/teams.rs index afdd555..61ddf91 100644 --- a/src/db/sqlite/teams.rs +++ b/src/db/sqlite/teams.rs @@ -1,8 +1,10 @@ use async_trait::async_trait; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; +use super::{ + backend::{Pool, RowExt, map_unique_violation, query}, + common::parse_uuid, +}; use crate::{ db::{ error::{DbError, DbResult}, @@ -16,11 +18,11 @@ use crate::{ }; pub struct SqliteTeamRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteTeamRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } @@ -42,7 +44,7 @@ impl SqliteTeamRepo { "AND deleted_at IS NULL" }; - let query = format!( + let sql = format!( r#" SELECT id, org_id, slug, name, created_at, updated_at FROM teams @@ -54,7 +56,7 @@ impl SqliteTeamRepo { comparison, deleted_filter, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(org_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -68,12 +70,12 @@ impl SqliteTeamRepo { .take(limit as usize) .map(|row| { Ok(Team { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, - slug: row.get("slug"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, + slug: row.col("slug"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) }) .collect::>>()?; @@ -102,7 +104,7 @@ impl SqliteTeamRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let 
sql = format!( r#" SELECT u.id as user_id, u.external_id, u.email, u.name, tm.role, tm.created_at as joined_at @@ -116,7 +118,7 @@ impl SqliteTeamRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(team_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -130,12 +132,12 @@ impl SqliteTeamRepo { .take(limit as usize) .map(|row| { Ok(TeamMember { - user_id: parse_uuid(&row.get::("user_id"))?, - external_id: row.get("external_id"), - email: row.get("email"), - name: row.get("name"), - role: row.get("role"), - joined_at: row.get("joined_at"), + user_id: parse_uuid(&row.col::("user_id"))?, + external_id: row.col("external_id"), + email: row.col("email"), + name: row.col("name"), + role: row.col("role"), + joined_at: row.col("joined_at"), }) }) .collect::>>()?; @@ -153,13 +155,14 @@ impl SqliteTeamRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl TeamRepo for SqliteTeamRepo { async fn create(&self, org_id: Uuid, input: CreateTeam) -> DbResult { let id = Uuid::new_v4(); let now = chrono::Utc::now(); - sqlx::query( + query( r#" INSERT INTO teams (id, org_id, slug, name, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?) 
@@ -173,15 +176,10 @@ impl TeamRepo for SqliteTeamRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict(format!( - "Team with slug '{}' already exists in this organization", - input.slug - )) - } - _ => DbError::from(e), - })?; + .map_err(map_unique_violation(format!( + "Team with slug '{}' already exists in this organization", + input.slug + )))?; Ok(Team { id, @@ -194,7 +192,7 @@ impl TeamRepo for SqliteTeamRepo { } async fn get_by_id(&self, id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, org_id, slug, name, created_at, updated_at FROM teams @@ -207,12 +205,12 @@ impl TeamRepo for SqliteTeamRepo { match result { Some(row) => Ok(Some(Team { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, - slug: row.get("slug"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, + slug: row.col("slug"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), })), None => Ok(None), } @@ -225,7 +223,7 @@ impl TeamRepo for SqliteTeamRepo { // Build query with placeholders for each ID let placeholders = ids.iter().map(|_| "?").collect::>().join(","); - let query = format!( + let sql = format!( r#" SELECT id, org_id, slug, name, created_at, updated_at FROM teams @@ -234,7 +232,7 @@ impl TeamRepo for SqliteTeamRepo { placeholders ); - let mut query_builder = sqlx::query(&query); + let mut query_builder = query(&sql); for id in ids { query_builder = query_builder.bind(id.to_string()); } @@ -244,19 +242,19 @@ impl TeamRepo for SqliteTeamRepo { rows.into_iter() .map(|row| { Ok(Team { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, - slug: row.get("slug"), - name: row.get("name"), - created_at: 
row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, + slug: row.col("slug"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) }) .collect() } async fn get_by_slug(&self, org_id: Uuid, slug: &str) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, org_id, slug, name, created_at, updated_at FROM teams @@ -270,12 +268,12 @@ impl TeamRepo for SqliteTeamRepo { match result { Some(row) => Ok(Some(Team { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, - slug: row.get("slug"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, + slug: row.col("slug"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), })), None => Ok(None), } @@ -291,7 +289,7 @@ impl TeamRepo for SqliteTeamRepo { .await; } - let query = if params.include_deleted { + let sql = if params.include_deleted { r#" SELECT id, org_id, slug, name, created_at, updated_at FROM teams @@ -309,7 +307,7 @@ impl TeamRepo for SqliteTeamRepo { "# }; - let rows = sqlx::query(query) + let rows = query(sql) .bind(org_id.to_string()) .bind(fetch_limit) .fetch_all(&self.pool) @@ -321,12 +319,12 @@ impl TeamRepo for SqliteTeamRepo { .take(limit as usize) .map(|row| { Ok(Team { - id: parse_uuid(&row.get::("id"))?, - org_id: parse_uuid(&row.get::("org_id"))?, - slug: row.get("slug"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + org_id: parse_uuid(&row.col::("org_id"))?, + slug: row.col("slug"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) }) .collect::>>()?; @@ -340,24 +338,24 @@ impl TeamRepo for 
SqliteTeamRepo { } async fn count_by_org(&self, org_id: Uuid, include_deleted: bool) -> DbResult { - let query = if include_deleted { + let sql = if include_deleted { "SELECT COUNT(*) as count FROM teams WHERE org_id = ?" } else { "SELECT COUNT(*) as count FROM teams WHERE org_id = ? AND deleted_at IS NULL" }; - let row = sqlx::query(query) + let row = query(sql) .bind(org_id.to_string()) .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn update(&self, id: Uuid, input: UpdateTeam) -> DbResult { if let Some(name) = input.name { let now = chrono::Utc::now(); - let result = sqlx::query( + let result = query( r#" UPDATE teams SET name = ?, updated_at = ? @@ -383,7 +381,7 @@ impl TeamRepo for SqliteTeamRepo { async fn delete(&self, id: Uuid) -> DbResult<()> { let now = chrono::Utc::now(); - let result = sqlx::query( + let result = query( r#" UPDATE teams SET deleted_at = ? @@ -409,7 +407,7 @@ impl TeamRepo for SqliteTeamRepo { async fn add_member(&self, team_id: Uuid, input: AddTeamMember) -> DbResult { let now = chrono::Utc::now(); - sqlx::query( + query( r#" INSERT INTO team_memberships (team_id, user_id, role, source, created_at) VALUES (?, ?, ?, ?, ?) @@ -422,12 +420,9 @@ impl TeamRepo for SqliteTeamRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict("User is already a member of this team".to_string()) - } - _ => DbError::from(e), - })?; + .map_err(map_unique_violation( + "User is already a member of this team".to_string(), + ))?; // Fetch the user details to return a complete TeamMember self.get_member(team_id, input.user_id) @@ -436,7 +431,7 @@ impl TeamRepo for SqliteTeamRepo { } async fn remove_member(&self, team_id: Uuid, user_id: Uuid) -> DbResult<()> { - let result = sqlx::query( + let result = query( r#" DELETE FROM team_memberships WHERE team_id = ? AND user_id = ? 
@@ -461,7 +456,7 @@ impl TeamRepo for SqliteTeamRepo { except_team_ids: &[Uuid], ) -> DbResult { let result = if except_team_ids.is_empty() { - sqlx::query( + query( r#" DELETE FROM team_memberships WHERE user_id = ? AND source = ? @@ -477,16 +472,14 @@ impl TeamRepo for SqliteTeamRepo { .map(|_| "?") .collect::>() .join(","); - let query = format!( + let sql = format!( r#" DELETE FROM team_memberships WHERE user_id = ? AND source = ? AND team_id NOT IN ({}) "#, placeholders ); - let mut q = sqlx::query(&query) - .bind(user_id.to_string()) - .bind(source.as_str()); + let mut q = query(&sql).bind(user_id.to_string()).bind(source.as_str()); for id in except_team_ids { q = q.bind(id.to_string()); } @@ -502,7 +495,7 @@ impl TeamRepo for SqliteTeamRepo { user_id: Uuid, input: UpdateTeamMember, ) -> DbResult { - let result = sqlx::query( + let result = query( r#" UPDATE team_memberships SET role = ? @@ -538,7 +531,7 @@ impl TeamRepo for SqliteTeamRepo { .await; } - let rows = sqlx::query( + let rows = query( r#" SELECT u.id as user_id, u.external_id, u.email, u.name, tm.role, tm.created_at as joined_at @@ -560,12 +553,12 @@ impl TeamRepo for SqliteTeamRepo { .take(limit as usize) .map(|row| { Ok(TeamMember { - user_id: parse_uuid(&row.get::("user_id"))?, - external_id: row.get("external_id"), - email: row.get("email"), - name: row.get("name"), - role: row.get("role"), - joined_at: row.get("joined_at"), + user_id: parse_uuid(&row.col::("user_id"))?, + external_id: row.col("external_id"), + email: row.col("email"), + name: row.col("name"), + role: row.col("role"), + joined_at: row.col("joined_at"), }) }) .collect::>>()?; @@ -579,7 +572,7 @@ impl TeamRepo for SqliteTeamRepo { } async fn get_member(&self, team_id: Uuid, user_id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT u.id as user_id, u.external_id, u.email, u.name, tm.role, tm.created_at as joined_at @@ -595,19 +588,19 @@ impl TeamRepo for SqliteTeamRepo { match result { Some(row) 
=> Ok(Some(TeamMember { - user_id: parse_uuid(&row.get::("user_id"))?, - external_id: row.get("external_id"), - email: row.get("email"), - name: row.get("name"), - role: row.get("role"), - joined_at: row.get("joined_at"), + user_id: parse_uuid(&row.col::("user_id"))?, + external_id: row.col("external_id"), + email: row.col("email"), + name: row.col("name"), + role: row.col("role"), + joined_at: row.col("joined_at"), })), None => Ok(None), } } async fn is_member(&self, team_id: Uuid, user_id: Uuid) -> DbResult { - let row = sqlx::query( + let row = query( r#" SELECT COUNT(*) as count FROM team_memberships @@ -619,20 +612,22 @@ impl TeamRepo for SqliteTeamRepo { .fetch_one(&self.pool) .await?; - Ok(row.get::("count") > 0) + Ok(row.col::("count") > 0) } async fn count_members(&self, team_id: Uuid) -> DbResult { - let row = sqlx::query("SELECT COUNT(*) as count FROM team_memberships WHERE team_id = ?") + let row = query("SELECT COUNT(*) as count FROM team_memberships WHERE team_id = ?") .bind(team_id.to_string()) .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } } #[cfg(test)] mod tests { + use sqlx::SqlitePool; + use super::*; use crate::db::repos::TeamRepo; diff --git a/src/db/sqlite/usage.rs b/src/db/sqlite/usage.rs index 6e5c04e..67597cf 100644 --- a/src/db/sqlite/usage.rs +++ b/src/db/sqlite/usage.rs @@ -1,8 +1,8 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; +use super::backend::{Pool, RowExt, begin, query}; use crate::{ db::{ error::DbResult, @@ -23,31 +23,32 @@ const MEDIA_AGGREGATE_COLS: &str = "\ COALESCE(SUM(character_count), 0) as character_count"; pub struct SqliteUsageRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteUsageRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } - fn media_fields(row: &sqlx::sqlite::SqliteRow) -> (i64, i64, i64) { + fn media_fields(row: &super::backend::Row) -> (i64, i64, i64) { ( - 
row.get("image_count"), - row.get("audio_seconds"), - row.get("character_count"), + row.col("image_count"), + row.col("audio_seconds"), + row.col("character_count"), ) } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl UsageRepo for SqliteUsageRepo { async fn log(&self, entry: UsageLogEntry) -> DbResult<()> { let id = Uuid::new_v4(); let total_tokens = entry.input_tokens + entry.output_tokens; // Use INSERT OR IGNORE for idempotency - duplicate request_ids are silently skipped - sqlx::query( + query( r#" INSERT OR IGNORE INTO usage_records ( id, request_id, api_key_id, user_id, org_id, project_id, team_id, @@ -106,20 +107,15 @@ impl UsageRepo for SqliteUsageRepo { let mut total_inserted = 0; - // Wrap all chunks in a single transaction for atomicity. - // On failure, the caller can safely retry the entire batch since - // INSERT OR IGNORE makes re-insertion idempotent. - let mut tx = self.pool.begin().await?; + let mut tx = begin(&self.pool).await?; - // Process in chunks to stay within SQLite's parameter limit for chunk in entries.chunks(MAX_ENTRIES_PER_BATCH) { - // Build dynamic multi-row INSERT query let placeholders: Vec<&str> = chunk .iter() .map(|_| "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)") .collect(); - let query = format!( + let sql = format!( r#" INSERT OR IGNORE INTO usage_records ( id, request_id, api_key_id, user_id, org_id, project_id, team_id, @@ -134,7 +130,7 @@ impl UsageRepo for SqliteUsageRepo { placeholders.join(", ") ); - let mut query_builder = sqlx::query(&query); + let mut query_builder = query(&sql); for entry in chunk { let id = Uuid::new_v4(); @@ -176,12 +172,13 @@ impl UsageRepo for SqliteUsageRepo { } tx.commit().await?; + Ok(total_inserted) } async fn get_summary(&self, api_key_id: Uuid, range: DateRange) -> DbResult { // Use range query instead of date casting to allow index usage on recorded_at - let row 
= sqlx::query(&format!( + let row = query(&format!( r#" SELECT COALESCE(SUM(cost_microcents), 0) as total_cost_microcents, @@ -206,13 +203,13 @@ impl UsageRepo for SqliteUsageRepo { let (image_count, audio_seconds, character_count) = Self::media_fields(&row); Ok(UsageSummary { - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), - first_request_at: row.get("first_request_at"), - last_request_at: row.get("last_request_at"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), + first_request_at: row.col("first_request_at"), + last_request_at: row.col("last_request_at"), image_count, audio_seconds, character_count, @@ -221,7 +218,7 @@ impl UsageRepo for SqliteUsageRepo { async fn get_by_date(&self, api_key_id: Uuid, range: DateRange) -> DbResult> { // Use range query in WHERE for index usage; date cast only needed in SELECT/GROUP BY - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -250,12 +247,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailySpend { - date: row.get("date"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, 
character_count, @@ -265,7 +262,7 @@ impl UsageRepo for SqliteUsageRepo { } async fn get_by_model(&self, api_key_id: Uuid, range: DateRange) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT model, @@ -294,12 +291,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); ModelSpend { - model: row.get("model"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + model: row.col("model"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -313,7 +310,7 @@ impl UsageRepo for SqliteUsageRepo { api_key_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT http_referer as referer, @@ -342,12 +339,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); RefererSpend { - referer: row.get("referer"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + referer: row.col("referer"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -359,7 +356,7 @@ impl UsageRepo for SqliteUsageRepo { async fn get_usage_stats(&self, api_key_id: Uuid, 
range: DateRange) -> DbResult { // Get daily totals first, then compute stats in Rust // This avoids SQLite's lack of native STDDEV function - let rows = sqlx::query( + let rows = query( r#" SELECT COALESCE(SUM(cost_microcents), 0) as daily_cost @@ -381,7 +378,7 @@ impl UsageRepo for SqliteUsageRepo { async fn get_current_period_spend(&self, api_key_id: Uuid, period: &str) -> DbResult { // Use range queries to allow index usage on recorded_at - let query = match period { + let sql = match period { "daily" => { r#" SELECT COALESCE(SUM(cost_microcents), 0) as total @@ -405,12 +402,12 @@ impl UsageRepo for SqliteUsageRepo { } }; - let row = sqlx::query(query) + let row = query(sql) .bind(api_key_id.to_string()) .fetch_one(&self.pool) .await?; - Ok(row.get("total")) + Ok(row.col("total")) } // ==================== Aggregated Usage Queries ==================== @@ -420,7 +417,7 @@ impl UsageRepo for SqliteUsageRepo { org_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -449,12 +446,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailySpend { - date: row.get("date"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -468,7 +465,7 @@ impl UsageRepo for SqliteUsageRepo { project_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, 
@@ -497,12 +494,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailySpend { - date: row.get("date"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -516,7 +513,7 @@ impl UsageRepo for SqliteUsageRepo { user_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -545,12 +542,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailySpend { - date: row.get("date"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -564,7 +561,7 @@ impl UsageRepo for SqliteUsageRepo { team_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -593,12 +590,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailySpend 
{ - date: row.get("date"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -612,7 +609,7 @@ impl UsageRepo for SqliteUsageRepo { provider: &str, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -641,12 +638,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailySpend { - date: row.get("date"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -660,7 +657,7 @@ impl UsageRepo for SqliteUsageRepo { provider: &str, range: DateRange, ) -> DbResult { - let row = sqlx::query(&format!( + let row = query(&format!( r#" SELECT COALESCE(SUM(cost_microcents), 0) as total_cost_microcents, @@ -685,13 +682,13 @@ impl UsageRepo for SqliteUsageRepo { let (image_count, audio_seconds, character_count) = Self::media_fields(&row); Ok(UsageSummary { - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: 
row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), - first_request_at: row.get("first_request_at"), - last_request_at: row.get("last_request_at"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), + first_request_at: row.col("first_request_at"), + last_request_at: row.col("last_request_at"), image_count, audio_seconds, character_count, @@ -703,7 +700,7 @@ impl UsageRepo for SqliteUsageRepo { provider: &str, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT model, @@ -732,12 +729,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); ModelSpend { - model: row.get("model"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + model: row.col("model"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -751,7 +748,7 @@ impl UsageRepo for SqliteUsageRepo { provider: &str, range: DateRange, ) -> DbResult { - let rows = sqlx::query( + let rows = query( r#" SELECT COALESCE(SUM(cost_microcents), 0) as daily_cost @@ -776,7 +773,7 @@ impl UsageRepo for SqliteUsageRepo { org_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT model, @@ -805,12 +802,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, 
character_count) = Self::media_fields(row); ModelSpend { - model: row.get("model"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + model: row.col("model"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -824,7 +821,7 @@ impl UsageRepo for SqliteUsageRepo { project_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT model, @@ -853,12 +850,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); ModelSpend { - model: row.get("model"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + model: row.col("model"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -872,7 +869,7 @@ impl UsageRepo for SqliteUsageRepo { user_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT model, @@ -901,12 +898,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); ModelSpend { - model: row.get("model"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - 
output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + model: row.col("model"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -920,7 +917,7 @@ impl UsageRepo for SqliteUsageRepo { team_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT model, @@ -949,12 +946,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); ModelSpend { - model: row.get("model"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + model: row.col("model"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -968,7 +965,7 @@ impl UsageRepo for SqliteUsageRepo { org_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT provider, @@ -997,12 +994,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); ProviderSpend { - provider: row.get("provider"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + provider: row.col("provider"), + 
total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -1016,7 +1013,7 @@ impl UsageRepo for SqliteUsageRepo { team_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT provider, @@ -1045,12 +1042,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); ProviderSpend { - provider: row.get("provider"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + provider: row.col("provider"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -1060,7 +1057,7 @@ impl UsageRepo for SqliteUsageRepo { } async fn get_summary_by_org(&self, org_id: Uuid, range: DateRange) -> DbResult { - let row = sqlx::query(&format!( + let row = query(&format!( r#" SELECT COALESCE(SUM(cost_microcents), 0) as total_cost_microcents, @@ -1085,13 +1082,13 @@ impl UsageRepo for SqliteUsageRepo { let (image_count, audio_seconds, character_count) = Self::media_fields(&row); Ok(UsageSummary { - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), - first_request_at: row.get("first_request_at"), - last_request_at: row.get("last_request_at"), + total_cost_microcents: 
row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), + first_request_at: row.col("first_request_at"), + last_request_at: row.col("last_request_at"), image_count, audio_seconds, character_count, @@ -1103,7 +1100,7 @@ impl UsageRepo for SqliteUsageRepo { project_id: Uuid, range: DateRange, ) -> DbResult { - let row = sqlx::query(&format!( + let row = query(&format!( r#" SELECT COALESCE(SUM(cost_microcents), 0) as total_cost_microcents, @@ -1128,13 +1125,13 @@ impl UsageRepo for SqliteUsageRepo { let (image_count, audio_seconds, character_count) = Self::media_fields(&row); Ok(UsageSummary { - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), - first_request_at: row.get("first_request_at"), - last_request_at: row.get("last_request_at"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), + first_request_at: row.col("first_request_at"), + last_request_at: row.col("last_request_at"), image_count, audio_seconds, character_count, @@ -1142,7 +1139,7 @@ impl UsageRepo for SqliteUsageRepo { } async fn get_summary_by_user(&self, user_id: Uuid, range: DateRange) -> DbResult { - let row = sqlx::query(&format!( + let row = query(&format!( r#" SELECT COALESCE(SUM(cost_microcents), 0) as total_cost_microcents, @@ -1167,13 +1164,13 @@ impl UsageRepo for SqliteUsageRepo { let (image_count, audio_seconds, character_count) = Self::media_fields(&row); Ok(UsageSummary { - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: 
row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), - first_request_at: row.get("first_request_at"), - last_request_at: row.get("last_request_at"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), + first_request_at: row.col("first_request_at"), + last_request_at: row.col("last_request_at"), image_count, audio_seconds, character_count, @@ -1181,7 +1178,7 @@ impl UsageRepo for SqliteUsageRepo { } async fn get_summary_by_team(&self, team_id: Uuid, range: DateRange) -> DbResult { - let row = sqlx::query(&format!( + let row = query(&format!( r#" SELECT COALESCE(SUM(cost_microcents), 0) as total_cost_microcents, @@ -1206,13 +1203,13 @@ impl UsageRepo for SqliteUsageRepo { let (image_count, audio_seconds, character_count) = Self::media_fields(&row); Ok(UsageSummary { - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), - first_request_at: row.get("first_request_at"), - last_request_at: row.get("last_request_at"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), + first_request_at: row.col("first_request_at"), + last_request_at: row.col("last_request_at"), image_count, audio_seconds, character_count, @@ -1220,7 +1217,7 @@ impl UsageRepo for SqliteUsageRepo { } async fn get_usage_stats_by_org(&self, org_id: Uuid, range: DateRange) -> DbResult { - let rows = sqlx::query( + let rows = query( r#" SELECT COALESCE(SUM(cost_microcents), 0) as daily_cost @@ -1245,7 +1242,7 @@ impl UsageRepo for 
SqliteUsageRepo { project_id: Uuid, range: DateRange, ) -> DbResult { - let rows = sqlx::query( + let rows = query( r#" SELECT COALESCE(SUM(cost_microcents), 0) as daily_cost @@ -1270,7 +1267,7 @@ impl UsageRepo for SqliteUsageRepo { user_id: Uuid, range: DateRange, ) -> DbResult { - let rows = sqlx::query( + let rows = query( r#" SELECT COALESCE(SUM(cost_microcents), 0) as daily_cost @@ -1295,7 +1292,7 @@ impl UsageRepo for SqliteUsageRepo { team_id: Uuid, range: DateRange, ) -> DbResult { - let rows = sqlx::query( + let rows = query( r#" SELECT COALESCE(SUM(cost_microcents), 0) as daily_cost @@ -1320,7 +1317,7 @@ impl UsageRepo for SqliteUsageRepo { api_key_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT provider, @@ -1349,12 +1346,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); ProviderSpend { - provider: row.get("provider"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + provider: row.col("provider"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -1368,7 +1365,7 @@ impl UsageRepo for SqliteUsageRepo { project_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT provider, @@ -1397,12 +1394,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); ProviderSpend { - provider: row.get("provider"), - total_cost_microcents: row.get("total_cost_microcents"), - 
input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + provider: row.col("provider"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -1416,7 +1413,7 @@ impl UsageRepo for SqliteUsageRepo { user_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT provider, @@ -1445,12 +1442,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); ProviderSpend { - provider: row.get("provider"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + provider: row.col("provider"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -1464,7 +1461,7 @@ impl UsageRepo for SqliteUsageRepo { api_key_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -1494,13 +1491,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyModelSpend { - date: row.get("date"), - model: row.get("model"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: 
row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + model: row.col("model"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -1514,7 +1511,7 @@ impl UsageRepo for SqliteUsageRepo { org_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -1544,13 +1541,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyModelSpend { - date: row.get("date"), - model: row.get("model"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + model: row.col("model"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -1564,7 +1561,7 @@ impl UsageRepo for SqliteUsageRepo { project_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -1594,13 +1591,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyModelSpend { - date: row.get("date"), - model: row.get("model"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - 
request_count: row.get("request_count"), + date: row.col("date"), + model: row.col("model"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -1614,7 +1611,7 @@ impl UsageRepo for SqliteUsageRepo { user_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -1644,13 +1641,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyModelSpend { - date: row.get("date"), - model: row.get("model"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + model: row.col("model"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -1664,7 +1661,7 @@ impl UsageRepo for SqliteUsageRepo { team_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -1694,13 +1691,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyModelSpend { - date: row.get("date"), - model: row.get("model"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: 
row.get("request_count"), + date: row.col("date"), + model: row.col("model"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -1714,7 +1711,7 @@ impl UsageRepo for SqliteUsageRepo { api_key_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -1744,13 +1741,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyProviderSpend { - date: row.get("date"), - provider: row.get("provider"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + provider: row.col("provider"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -1764,7 +1761,7 @@ impl UsageRepo for SqliteUsageRepo { org_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -1794,13 +1791,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyProviderSpend { - date: row.get("date"), - provider: row.get("provider"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: 
row.get("request_count"), + date: row.col("date"), + provider: row.col("provider"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -1814,7 +1811,7 @@ impl UsageRepo for SqliteUsageRepo { project_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -1844,13 +1841,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyProviderSpend { - date: row.get("date"), - provider: row.get("provider"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + provider: row.col("provider"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -1864,7 +1861,7 @@ impl UsageRepo for SqliteUsageRepo { user_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -1894,13 +1891,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyProviderSpend { - date: row.get("date"), - provider: row.get("provider"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: 
row.get("request_count"), + date: row.col("date"), + provider: row.col("provider"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -1914,7 +1911,7 @@ impl UsageRepo for SqliteUsageRepo { team_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -1944,13 +1941,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyProviderSpend { - date: row.get("date"), - provider: row.get("provider"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + provider: row.col("provider"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -1966,7 +1963,7 @@ impl UsageRepo for SqliteUsageRepo { api_key_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT pricing_source, @@ -1995,12 +1992,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); PricingSourceSpend { - pricing_source: row.get("pricing_source"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: 
row.get("request_count"), + pricing_source: row.col("pricing_source"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2014,7 +2011,7 @@ impl UsageRepo for SqliteUsageRepo { org_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT pricing_source, @@ -2043,12 +2040,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); PricingSourceSpend { - pricing_source: row.get("pricing_source"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + pricing_source: row.col("pricing_source"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2062,7 +2059,7 @@ impl UsageRepo for SqliteUsageRepo { project_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT pricing_source, @@ -2091,12 +2088,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); PricingSourceSpend { - pricing_source: row.get("pricing_source"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + pricing_source: row.col("pricing_source"), + 
total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2110,7 +2107,7 @@ impl UsageRepo for SqliteUsageRepo { user_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT pricing_source, @@ -2139,12 +2136,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); PricingSourceSpend { - pricing_source: row.get("pricing_source"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + pricing_source: row.col("pricing_source"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2158,7 +2155,7 @@ impl UsageRepo for SqliteUsageRepo { team_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT pricing_source, @@ -2187,12 +2184,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); PricingSourceSpend { - pricing_source: row.get("pricing_source"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + pricing_source: row.col("pricing_source"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: 
row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2206,7 +2203,7 @@ impl UsageRepo for SqliteUsageRepo { api_key_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -2236,13 +2233,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyPricingSourceSpend { - date: row.get("date"), - pricing_source: row.get("pricing_source"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + pricing_source: row.col("pricing_source"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2256,7 +2253,7 @@ impl UsageRepo for SqliteUsageRepo { org_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -2286,13 +2283,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyPricingSourceSpend { - date: row.get("date"), - pricing_source: row.get("pricing_source"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + pricing_source: row.col("pricing_source"), + 
total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2306,7 +2303,7 @@ impl UsageRepo for SqliteUsageRepo { project_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -2336,13 +2333,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyPricingSourceSpend { - date: row.get("date"), - pricing_source: row.get("pricing_source"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + pricing_source: row.col("pricing_source"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2356,7 +2353,7 @@ impl UsageRepo for SqliteUsageRepo { user_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -2386,13 +2383,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyPricingSourceSpend { - date: row.get("date"), - pricing_source: row.get("pricing_source"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: 
row.col("date"), + pricing_source: row.col("pricing_source"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2406,7 +2403,7 @@ impl UsageRepo for SqliteUsageRepo { team_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -2436,13 +2433,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyPricingSourceSpend { - date: row.get("date"), - pricing_source: row.get("pricing_source"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + pricing_source: row.col("pricing_source"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2460,7 +2457,7 @@ impl UsageRepo for SqliteUsageRepo { project_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT u.user_id, users.name as user_name, users.email as user_email, COALESCE(SUM(u.cost_microcents), 0) as total_cost_microcents, @@ -2492,15 +2489,15 @@ impl UsageRepo for SqliteUsageRepo { let (image_count, audio_seconds, character_count) = Self::media_fields(row); UserSpend { user_id: row - .get::, _>("user_id") + .col::>("user_id") .and_then(|s| s.parse().ok()), - user_name: row.get("user_name"), - user_email: row.get("user_email"), - total_cost_microcents: 
row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + user_name: row.col("user_name"), + user_email: row.col("user_email"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2514,7 +2511,7 @@ impl UsageRepo for SqliteUsageRepo { project_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT date(u.recorded_at) as date, u.user_id, users.name as user_name, users.email as user_email, @@ -2546,17 +2543,17 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyUserSpend { - date: row.get("date"), + date: row.col("date"), user_id: row - .get::, _>("user_id") + .col::>("user_id") .and_then(|s| s.parse().ok()), - user_name: row.get("user_name"), - user_email: row.get("user_email"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + user_name: row.col("user_name"), + user_email: row.col("user_email"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2572,7 +2569,7 @@ impl UsageRepo for SqliteUsageRepo { team_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT u.user_id, users.name as user_name, users.email as user_email, 
COALESCE(SUM(u.cost_microcents), 0) as total_cost_microcents, @@ -2604,15 +2601,15 @@ impl UsageRepo for SqliteUsageRepo { let (image_count, audio_seconds, character_count) = Self::media_fields(row); UserSpend { user_id: row - .get::, _>("user_id") + .col::>("user_id") .and_then(|s| s.parse().ok()), - user_name: row.get("user_name"), - user_email: row.get("user_email"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + user_name: row.col("user_name"), + user_email: row.col("user_email"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2626,7 +2623,7 @@ impl UsageRepo for SqliteUsageRepo { team_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT date(u.recorded_at) as date, u.user_id, users.name as user_name, users.email as user_email, @@ -2658,17 +2655,17 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyUserSpend { - date: row.get("date"), + date: row.col("date"), user_id: row - .get::, _>("user_id") + .col::>("user_id") .and_then(|s| s.parse().ok()), - user_name: row.get("user_name"), - user_email: row.get("user_email"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + user_name: row.col("user_name"), + user_email: row.col("user_email"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: 
row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2682,7 +2679,7 @@ impl UsageRepo for SqliteUsageRepo { team_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT u.project_id, projects.name as project_name, COALESCE(SUM(u.cost_microcents), 0) as total_cost_microcents, @@ -2714,14 +2711,14 @@ impl UsageRepo for SqliteUsageRepo { let (image_count, audio_seconds, character_count) = Self::media_fields(row); ProjectSpend { project_id: row - .get::, _>("project_id") + .col::>("project_id") .and_then(|s| s.parse().ok()), - project_name: row.get("project_name"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + project_name: row.col("project_name"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2735,7 +2732,7 @@ impl UsageRepo for SqliteUsageRepo { team_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT date(u.recorded_at) as date, u.project_id, projects.name as project_name, @@ -2767,16 +2764,16 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyProjectSpend { - date: row.get("date"), + date: row.col("date"), project_id: row - .get::, _>("project_id") + .col::>("project_id") .and_then(|s| s.parse().ok()), - project_name: row.get("project_name"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: 
row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + project_name: row.col("project_name"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2792,7 +2789,7 @@ impl UsageRepo for SqliteUsageRepo { org_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT u.user_id, users.name as user_name, users.email as user_email, COALESCE(SUM(u.cost_microcents), 0) as total_cost_microcents, @@ -2824,15 +2821,15 @@ impl UsageRepo for SqliteUsageRepo { let (image_count, audio_seconds, character_count) = Self::media_fields(row); UserSpend { user_id: row - .get::, _>("user_id") + .col::>("user_id") .and_then(|s| s.parse().ok()), - user_name: row.get("user_name"), - user_email: row.get("user_email"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + user_name: row.col("user_name"), + user_email: row.col("user_email"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2846,7 +2843,7 @@ impl UsageRepo for SqliteUsageRepo { org_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT date(u.recorded_at) as date, u.user_id, users.name as user_name, users.email as user_email, @@ -2878,17 +2875,17 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = 
Self::media_fields(row); DailyUserSpend { - date: row.get("date"), + date: row.col("date"), user_id: row - .get::, _>("user_id") + .col::>("user_id") .and_then(|s| s.parse().ok()), - user_name: row.get("user_name"), - user_email: row.get("user_email"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + user_name: row.col("user_name"), + user_email: row.col("user_email"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2902,7 +2899,7 @@ impl UsageRepo for SqliteUsageRepo { org_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT u.project_id, projects.name as project_name, COALESCE(SUM(u.cost_microcents), 0) as total_cost_microcents, @@ -2934,14 +2931,14 @@ impl UsageRepo for SqliteUsageRepo { let (image_count, audio_seconds, character_count) = Self::media_fields(row); ProjectSpend { project_id: row - .get::, _>("project_id") + .col::>("project_id") .and_then(|s| s.parse().ok()), - project_name: row.get("project_name"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + project_name: row.col("project_name"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -2955,7 +2952,7 @@ impl UsageRepo for SqliteUsageRepo { org_id: 
Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT date(u.recorded_at) as date, u.project_id, projects.name as project_name, @@ -2987,16 +2984,16 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyProjectSpend { - date: row.get("date"), + date: row.col("date"), project_id: row - .get::, _>("project_id") + .col::>("project_id") .and_then(|s| s.parse().ok()), - project_name: row.get("project_name"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + project_name: row.col("project_name"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3010,7 +3007,7 @@ impl UsageRepo for SqliteUsageRepo { org_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT u.team_id, teams.name as team_name, COALESCE(SUM(u.cost_microcents), 0) as total_cost_microcents, @@ -3042,14 +3039,14 @@ impl UsageRepo for SqliteUsageRepo { let (image_count, audio_seconds, character_count) = Self::media_fields(row); TeamSpend { team_id: row - .get::, _>("team_id") + .col::>("team_id") .and_then(|s| s.parse().ok()), - team_name: row.get("team_name"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + team_name: row.col("team_name"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: 
row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3063,7 +3060,7 @@ impl UsageRepo for SqliteUsageRepo { org_id: Uuid, range: DateRange, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT date(u.recorded_at) as date, u.team_id, teams.name as team_name, @@ -3095,16 +3092,16 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyTeamSpend { - date: row.get("date"), + date: row.col("date"), team_id: row - .get::, _>("team_id") + .col::>("team_id") .and_then(|s| s.parse().ok()), - team_name: row.get("team_name"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + team_name: row.col("team_name"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3116,7 +3113,7 @@ impl UsageRepo for SqliteUsageRepo { // --- Global scope: base queries --- async fn get_summary_global(&self, range: DateRange) -> DbResult { - let row = sqlx::query(&format!( + let row = query(&format!( r#" SELECT COALESCE(SUM(cost_microcents), 0) as total_cost_microcents, @@ -3139,13 +3136,13 @@ impl UsageRepo for SqliteUsageRepo { let (image_count, audio_seconds, character_count) = Self::media_fields(&row); Ok(UsageSummary { - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), - first_request_at: row.get("first_request_at"), - 
last_request_at: row.get("last_request_at"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), + first_request_at: row.col("first_request_at"), + last_request_at: row.col("last_request_at"), image_count, audio_seconds, character_count, @@ -3153,7 +3150,7 @@ impl UsageRepo for SqliteUsageRepo { } async fn get_daily_usage_global(&self, range: DateRange) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -3180,12 +3177,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailySpend { - date: row.get("date"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3195,7 +3192,7 @@ impl UsageRepo for SqliteUsageRepo { } async fn get_model_usage_global(&self, range: DateRange) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT model, @@ -3222,12 +3219,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); ModelSpend { - model: row.get("model"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + 
model: row.col("model"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3237,7 +3234,7 @@ impl UsageRepo for SqliteUsageRepo { } async fn get_provider_usage_global(&self, range: DateRange) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT provider, @@ -3264,12 +3261,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); ProviderSpend { - provider: row.get("provider"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + provider: row.col("provider"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3282,7 +3279,7 @@ impl UsageRepo for SqliteUsageRepo { &self, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT pricing_source, @@ -3309,12 +3306,12 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); PricingSourceSpend { - pricing_source: row.get("pricing_source"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + pricing_source: row.col("pricing_source"), + total_cost_microcents: row.col("total_cost_microcents"), + 
input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3327,7 +3324,7 @@ impl UsageRepo for SqliteUsageRepo { &self, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -3355,13 +3352,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyModelSpend { - date: row.get("date"), - model: row.get("model"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + model: row.col("model"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3374,7 +3371,7 @@ impl UsageRepo for SqliteUsageRepo { &self, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -3402,13 +3399,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyProviderSpend { - date: row.get("date"), - provider: row.get("provider"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + provider: row.col("provider"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: 
row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3421,7 +3418,7 @@ impl UsageRepo for SqliteUsageRepo { &self, range: DateRange, ) -> DbResult> { - let rows = sqlx::query(&format!( + let rows = query(&format!( r#" SELECT date(recorded_at) as date, @@ -3449,13 +3446,13 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyPricingSourceSpend { - date: row.get("date"), - pricing_source: row.get("pricing_source"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + date: row.col("date"), + pricing_source: row.col("pricing_source"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3465,7 +3462,7 @@ impl UsageRepo for SqliteUsageRepo { } async fn get_usage_stats_global(&self, range: DateRange) -> DbResult { - let rows = sqlx::query( + let rows = query( r#" SELECT COALESCE(SUM(cost_microcents), 0) as daily_cost @@ -3486,7 +3483,7 @@ impl UsageRepo for SqliteUsageRepo { // --- Global scope: entity breakdowns --- async fn get_user_usage_global(&self, range: DateRange) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT u.user_id, users.name as user_name, users.email as user_email, COALESCE(SUM(u.cost_microcents), 0) as total_cost_microcents, @@ -3516,15 +3513,15 @@ impl UsageRepo for SqliteUsageRepo { let (image_count, audio_seconds, character_count) = Self::media_fields(row); UserSpend { user_id: row - .get::, _>("user_id") + 
.col::>("user_id") .and_then(|s| s.parse().ok()), - user_name: row.get("user_name"), - user_email: row.get("user_email"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + user_name: row.col("user_name"), + user_email: row.col("user_email"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3534,7 +3531,7 @@ impl UsageRepo for SqliteUsageRepo { } async fn get_daily_user_usage_global(&self, range: DateRange) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT date(u.recorded_at) as date, u.user_id, users.name as user_name, users.email as user_email, @@ -3564,17 +3561,17 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyUserSpend { - date: row.get("date"), + date: row.col("date"), user_id: row - .get::, _>("user_id") + .col::>("user_id") .and_then(|s| s.parse().ok()), - user_name: row.get("user_name"), - user_email: row.get("user_email"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + user_name: row.col("user_name"), + user_email: row.col("user_email"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3584,7 +3581,7 @@ impl UsageRepo for SqliteUsageRepo { } async fn 
get_project_usage_global(&self, range: DateRange) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT u.project_id, projects.name as project_name, COALESCE(SUM(u.cost_microcents), 0) as total_cost_microcents, @@ -3614,14 +3611,14 @@ impl UsageRepo for SqliteUsageRepo { let (image_count, audio_seconds, character_count) = Self::media_fields(row); ProjectSpend { project_id: row - .get::, _>("project_id") + .col::>("project_id") .and_then(|s| s.parse().ok()), - project_name: row.get("project_name"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + project_name: row.col("project_name"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3634,7 +3631,7 @@ impl UsageRepo for SqliteUsageRepo { &self, range: DateRange, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT date(u.recorded_at) as date, u.project_id, projects.name as project_name, @@ -3664,16 +3661,16 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyProjectSpend { - date: row.get("date"), + date: row.col("date"), project_id: row - .get::, _>("project_id") + .col::>("project_id") .and_then(|s| s.parse().ok()), - project_name: row.get("project_name"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + project_name: row.col("project_name"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: 
row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3683,7 +3680,7 @@ impl UsageRepo for SqliteUsageRepo { } async fn get_team_usage_global(&self, range: DateRange) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT u.team_id, teams.name as team_name, COALESCE(SUM(u.cost_microcents), 0) as total_cost_microcents, @@ -3713,14 +3710,14 @@ impl UsageRepo for SqliteUsageRepo { let (image_count, audio_seconds, character_count) = Self::media_fields(row); TeamSpend { team_id: row - .get::, _>("team_id") + .col::>("team_id") .and_then(|s| s.parse().ok()), - team_name: row.get("team_name"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + team_name: row.col("team_name"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3730,7 +3727,7 @@ impl UsageRepo for SqliteUsageRepo { } async fn get_daily_team_usage_global(&self, range: DateRange) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT date(u.recorded_at) as date, u.team_id, teams.name as team_name, @@ -3760,16 +3757,16 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyTeamSpend { - date: row.get("date"), + date: row.col("date"), team_id: row - .get::, _>("team_id") + .col::>("team_id") .and_then(|s| s.parse().ok()), - team_name: row.get("team_name"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - 
output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + team_name: row.col("team_name"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3779,7 +3776,7 @@ impl UsageRepo for SqliteUsageRepo { } async fn get_org_usage_global(&self, range: DateRange) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT u.org_id, organizations.name as org_name, COALESCE(SUM(u.cost_microcents), 0) as total_cost_microcents, @@ -3809,14 +3806,14 @@ impl UsageRepo for SqliteUsageRepo { let (image_count, audio_seconds, character_count) = Self::media_fields(row); OrgSpend { org_id: row - .get::, _>("org_id") + .col::>("org_id") .and_then(|s| s.parse().ok()), - org_name: row.get("org_name"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + org_name: row.col("org_name"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3826,7 +3823,7 @@ impl UsageRepo for SqliteUsageRepo { } async fn get_daily_org_usage_global(&self, range: DateRange) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT date(u.recorded_at) as date, u.org_id, organizations.name as org_name, @@ -3856,16 +3853,16 @@ impl UsageRepo for SqliteUsageRepo { .map(|row| { let (image_count, audio_seconds, character_count) = Self::media_fields(row); DailyOrgSpend { - date: row.get("date"), + date: 
row.col("date"), org_id: row - .get::, _>("org_id") + .col::>("org_id") .and_then(|s| s.parse().ok()), - org_name: row.get("org_name"), - total_cost_microcents: row.get("total_cost_microcents"), - input_tokens: row.get("input_tokens"), - output_tokens: row.get("output_tokens"), - total_tokens: row.get("total_tokens"), - request_count: row.get("request_count"), + org_name: row.col("org_name"), + total_cost_microcents: row.col("total_cost_microcents"), + input_tokens: row.col("input_tokens"), + output_tokens: row.col("output_tokens"), + total_tokens: row.col("total_tokens"), + request_count: row.col("request_count"), image_count, audio_seconds, character_count, @@ -3895,7 +3892,7 @@ impl UsageRepo for SqliteUsageRepo { let limit = std::cmp::min(batch_size as u64, remaining) as i64; // Delete a batch using subquery to select IDs (SQLite doesn't support LIMIT in DELETE directly) - let result = sqlx::query( + let result = query( r#" DELETE FROM usage_records WHERE id IN ( @@ -3941,7 +3938,7 @@ impl UsageRepo for SqliteUsageRepo { let limit = std::cmp::min(batch_size as u64, remaining) as i64; // daily_spend uses composite primary key (api_key_id, date, model), use rowid for deletion - let result = sqlx::query( + let result = query( r#" DELETE FROM daily_spend WHERE rowid IN ( @@ -3970,8 +3967,8 @@ impl UsageRepo for SqliteUsageRepo { /// Helper function to compute usage stats from daily cost rows. /// This avoids duplicating the statistics calculation logic. 
-fn compute_stats_from_daily_costs(rows: &[sqlx::sqlite::SqliteRow]) -> DbResult { - let daily_costs: Vec = rows.iter().map(|row| row.get("daily_cost")).collect(); +fn compute_stats_from_daily_costs(rows: &[super::backend::Row]) -> DbResult { + let daily_costs: Vec = rows.iter().map(|row| row.col("daily_cost")).collect(); let sample_days = daily_costs.len() as i32; if sample_days == 0 { diff --git a/src/db/sqlite/users.rs b/src/db/sqlite/users.rs index e39569c..fee084e 100644 --- a/src/db/sqlite/users.rs +++ b/src/db/sqlite/users.rs @@ -1,8 +1,10 @@ use async_trait::async_trait; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; +use super::{ + backend::{Pool, RowExt, map_unique_violation, query, unique_violation_message}, + common::parse_uuid, +}; use crate::{ db::{ error::{DbError, DbResult}, @@ -18,11 +20,11 @@ use crate::{ }; pub struct SqliteUserRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteUserRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } @@ -37,7 +39,7 @@ impl SqliteUserRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT id, external_id, email, name, created_at, updated_at FROM users @@ -48,7 +50,7 @@ impl SqliteUserRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(cursor.created_at) .bind(cursor.id.to_string()) .bind(fetch_limit) @@ -61,12 +63,12 @@ impl SqliteUserRepo { .take(limit as usize) .map(|row| { Ok(User { - id: parse_uuid(&row.get::("id"))?, - external_id: row.get("external_id"), - email: row.get("email"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + external_id: row.col("external_id"), + email: row.col("email"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) 
}) .collect::>>()?; @@ -95,7 +97,7 @@ impl SqliteUserRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT u.id, u.external_id, u.email, u.name, u.created_at, u.updated_at FROM users u @@ -108,7 +110,7 @@ impl SqliteUserRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(org_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -122,12 +124,12 @@ impl SqliteUserRepo { .take(limit as usize) .map(|row| { Ok(User { - id: parse_uuid(&row.get::("id"))?, - external_id: row.get("external_id"), - email: row.get("email"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + external_id: row.col("external_id"), + email: row.col("email"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) }) .collect::>>()?; @@ -156,7 +158,7 @@ impl SqliteUserRepo { let (comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT u.id, u.external_id, u.email, u.name, u.created_at, u.updated_at FROM users u @@ -169,7 +171,7 @@ impl SqliteUserRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(project_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -183,12 +185,12 @@ impl SqliteUserRepo { .take(limit as usize) .map(|row| { Ok(User { - id: parse_uuid(&row.get::("id"))?, - external_id: row.get("external_id"), - email: row.get("email"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + external_id: row.col("external_id"), + email: row.col("email"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) 
}) .collect::>>()?; @@ -206,13 +208,14 @@ impl SqliteUserRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl UserRepo for SqliteUserRepo { async fn create(&self, input: CreateUser) -> DbResult { let id = Uuid::new_v4(); let now = chrono::Utc::now(); - sqlx::query( + query( r#" INSERT INTO users (id, external_id, email, name, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?) @@ -226,15 +229,10 @@ impl UserRepo for SqliteUserRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict(format!( - "User with external_id '{}' already exists", - input.external_id - )) - } - _ => DbError::from(e), - })?; + .map_err(map_unique_violation(format!( + "User with external_id '{}' already exists", + input.external_id + )))?; Ok(User { id, @@ -247,7 +245,7 @@ impl UserRepo for SqliteUserRepo { } async fn get_by_id(&self, id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, external_id, email, name, created_at, updated_at FROM users @@ -260,19 +258,19 @@ impl UserRepo for SqliteUserRepo { match result { Some(row) => Ok(Some(User { - id: parse_uuid(&row.get::("id"))?, - external_id: row.get("external_id"), - email: row.get("email"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + external_id: row.col("external_id"), + email: row.col("email"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), })), None => Ok(None), } } async fn get_by_external_id(&self, external_id: &str) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, external_id, email, name, created_at, updated_at FROM users @@ -285,12 +283,12 @@ impl UserRepo for SqliteUserRepo { match result { Some(row) => Ok(Some(User { - id: 
parse_uuid(&row.get::("id"))?, - external_id: row.get("external_id"), - email: row.get("email"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + external_id: row.col("external_id"), + email: row.col("email"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), })), None => Ok(None), } @@ -308,7 +306,7 @@ impl UserRepo for SqliteUserRepo { } // First page (no cursor provided) - let rows = sqlx::query( + let rows = query( r#" SELECT id, external_id, email, name, created_at, updated_at FROM users @@ -326,12 +324,12 @@ impl UserRepo for SqliteUserRepo { .take(limit as usize) .map(|row| { Ok(User { - id: parse_uuid(&row.get::("id"))?, - external_id: row.get("external_id"), - email: row.get("email"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + external_id: row.col("external_id"), + email: row.col("email"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) }) .collect::>>()?; @@ -346,16 +344,16 @@ impl UserRepo for SqliteUserRepo { async fn count(&self, _include_deleted: bool) -> DbResult { // Users table doesn't have soft delete, so include_deleted is ignored - let row = sqlx::query("SELECT COUNT(*) as count FROM users") + let row = query("SELECT COUNT(*) as count FROM users") .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn update(&self, id: Uuid, input: UpdateUser) -> DbResult { let now = chrono::Utc::now(); - let result = sqlx::query( + let result = query( r#" UPDATE users SET email = COALESCE(?, email), @@ -387,7 +385,7 @@ impl UserRepo for SqliteUserRepo { ) -> DbResult<()> { let now = chrono::Utc::now(); - sqlx::query( + query( r#" INSERT INTO org_memberships (org_id, user_id, role, source, created_at) VALUES (?, ?, ?, ?, ?) 
@@ -400,12 +398,11 @@ impl UserRepo for SqliteUserRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match &e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - // SQLite error format: "UNIQUE constraint failed: table.column1, table.column2" - // Primary key (org_id, user_id) violation includes both columns - // Single-org unique index (user_id only) violation includes just user_id - let msg = db_err.message(); + .map_err(|e| { + // SQLite error format: "UNIQUE constraint failed: table.column1, table.column2" + // Primary key (org_id, user_id) violation includes both columns + // Single-org unique index (user_id only) violation includes just user_id + if let Some(msg) = unique_violation_message(&e) { // If message contains user_id but NOT org_id, it's the single-org constraint if msg.contains("user_id") && !msg.contains("org_id") { DbError::Conflict( @@ -416,8 +413,9 @@ impl UserRepo for SqliteUserRepo { } else { DbError::Conflict("User is already a member of this organization".to_string()) } + } else { + DbError::from(e) } - _ => DbError::from(e), })?; Ok(()) @@ -430,7 +428,7 @@ impl UserRepo for SqliteUserRepo { except_org_ids: &[Uuid], ) -> DbResult { let result = if except_org_ids.is_empty() { - sqlx::query( + query( r#" DELETE FROM org_memberships WHERE user_id = ? AND source = ? @@ -447,16 +445,14 @@ impl UserRepo for SqliteUserRepo { .map(|_| "?") .collect::>() .join(","); - let query = format!( + let sql = format!( r#" DELETE FROM org_memberships WHERE user_id = ? AND source = ? AND org_id NOT IN ({}) "#, placeholders ); - let mut q = sqlx::query(&query) - .bind(user_id.to_string()) - .bind(source.as_str()); + let mut q = query(&sql).bind(user_id.to_string()).bind(source.as_str()); for id in except_org_ids { q = q.bind(id.to_string()); } @@ -472,7 +468,7 @@ impl UserRepo for SqliteUserRepo { org_id: Uuid, role: &str, ) -> DbResult<()> { - let result = sqlx::query( + let result = query( r#" UPDATE org_memberships SET role = ? 
@@ -493,7 +489,7 @@ impl UserRepo for SqliteUserRepo { } async fn remove_from_org(&self, user_id: Uuid, org_id: Uuid) -> DbResult<()> { - let result = sqlx::query( + let result = query( r#" DELETE FROM org_memberships WHERE org_id = ? AND user_id = ? @@ -527,7 +523,7 @@ impl UserRepo for SqliteUserRepo { } // First page (no cursor provided) - let rows = sqlx::query( + let rows = query( r#" SELECT u.id, u.external_id, u.email, u.name, u.created_at, u.updated_at FROM users u @@ -548,12 +544,12 @@ impl UserRepo for SqliteUserRepo { .take(limit as usize) .map(|row| { Ok(User { - id: parse_uuid(&row.get::("id"))?, - external_id: row.get("external_id"), - email: row.get("email"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + external_id: row.col("external_id"), + email: row.col("email"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) }) .collect::>>()?; @@ -567,11 +563,11 @@ impl UserRepo for SqliteUserRepo { } async fn count_org_members(&self, org_id: Uuid, _include_deleted: bool) -> DbResult { - let row = sqlx::query("SELECT COUNT(*) as count FROM org_memberships WHERE org_id = ?") + let row = query("SELECT COUNT(*) as count FROM org_memberships WHERE org_id = ?") .bind(org_id.to_string()) .fetch_one(&self.pool) .await?; - Ok(row.get::("count")) + Ok(row.col::("count")) } async fn add_to_project( @@ -583,7 +579,7 @@ impl UserRepo for SqliteUserRepo { ) -> DbResult<()> { let now = chrono::Utc::now(); - sqlx::query( + query( r#" INSERT INTO project_memberships (project_id, user_id, role, source, created_at) VALUES (?, ?, ?, ?, ?) 
@@ -596,12 +592,9 @@ impl UserRepo for SqliteUserRepo { .bind(now) .execute(&self.pool) .await - .map_err(|e| match e { - sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - DbError::Conflict("User is already a member of this project".to_string()) - } - _ => DbError::from(e), - })?; + .map_err(map_unique_violation( + "User is already a member of this project".to_string(), + ))?; Ok(()) } @@ -613,7 +606,7 @@ impl UserRepo for SqliteUserRepo { except_project_ids: &[Uuid], ) -> DbResult { let result = if except_project_ids.is_empty() { - sqlx::query( + query( r#" DELETE FROM project_memberships WHERE user_id = ? AND source = ? @@ -629,16 +622,14 @@ impl UserRepo for SqliteUserRepo { .map(|_| "?") .collect::>() .join(","); - let query = format!( + let sql = format!( r#" DELETE FROM project_memberships WHERE user_id = ? AND source = ? AND project_id NOT IN ({}) "#, placeholders ); - let mut q = sqlx::query(&query) - .bind(user_id.to_string()) - .bind(source.as_str()); + let mut q = query(&sql).bind(user_id.to_string()).bind(source.as_str()); for id in except_project_ids { q = q.bind(id.to_string()); } @@ -654,7 +645,7 @@ impl UserRepo for SqliteUserRepo { project_id: Uuid, role: &str, ) -> DbResult<()> { - let result = sqlx::query( + let result = query( r#" UPDATE project_memberships SET role = ? @@ -675,7 +666,7 @@ impl UserRepo for SqliteUserRepo { } async fn remove_from_project(&self, user_id: Uuid, project_id: Uuid) -> DbResult<()> { - let result = sqlx::query( + let result = query( r#" DELETE FROM project_memberships WHERE project_id = ? AND user_id = ? 
@@ -709,7 +700,7 @@ impl UserRepo for SqliteUserRepo { } // First page (no cursor provided) - let rows = sqlx::query( + let rows = query( r#" SELECT u.id, u.external_id, u.email, u.name, u.created_at, u.updated_at FROM users u @@ -730,12 +721,12 @@ impl UserRepo for SqliteUserRepo { .take(limit as usize) .map(|row| { Ok(User { - id: parse_uuid(&row.get::("id"))?, - external_id: row.get("external_id"), - email: row.get("email"), - name: row.get("name"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + id: parse_uuid(&row.col::("id"))?, + external_id: row.col("external_id"), + email: row.col("email"), + name: row.col("name"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) }) .collect::>>()?; @@ -753,12 +744,11 @@ impl UserRepo for SqliteUserRepo { project_id: Uuid, _include_deleted: bool, ) -> DbResult { - let row = - sqlx::query("SELECT COUNT(*) as count FROM project_memberships WHERE project_id = ?") - .bind(project_id.to_string()) - .fetch_one(&self.pool) - .await?; - Ok(row.get::("count")) + let row = query("SELECT COUNT(*) as count FROM project_memberships WHERE project_id = ?") + .bind(project_id.to_string()) + .fetch_one(&self.pool) + .await?; + Ok(row.col::("count")) } // ==================== GDPR Export Methods ==================== @@ -767,7 +757,7 @@ impl UserRepo for SqliteUserRepo { &self, user_id: Uuid, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT o.id as org_id, @@ -788,14 +778,14 @@ impl UserRepo for SqliteUserRepo { rows.into_iter() .map(|row| { - let source_str: String = row.get("source"); + let source_str: String = row.col("source"); Ok(UserOrgMembership { - org_id: parse_uuid(&row.get::("org_id"))?, - org_slug: row.get("org_slug"), - org_name: row.get("org_name"), - role: row.get("role"), - source: MembershipSource::from_str(&source_str).unwrap_or_default(), - joined_at: row.get("joined_at"), + org_id: parse_uuid(&row.col::("org_id"))?, + org_slug: 
row.col("org_slug"), + org_name: row.col("org_name"), + role: row.col("role"), + source: MembershipSource::parse(&source_str).unwrap_or_default(), + joined_at: row.col("joined_at"), }) }) .collect() @@ -805,7 +795,7 @@ impl UserRepo for SqliteUserRepo { &self, user_id: Uuid, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT p.id as project_id, @@ -827,22 +817,22 @@ impl UserRepo for SqliteUserRepo { rows.into_iter() .map(|row| { - let source_str: String = row.get("source"); + let source_str: String = row.col("source"); Ok(UserProjectMembership { - project_id: parse_uuid(&row.get::("project_id"))?, - project_slug: row.get("project_slug"), - project_name: row.get("project_name"), - org_id: parse_uuid(&row.get::("org_id"))?, - role: row.get("role"), - source: MembershipSource::from_str(&source_str).unwrap_or_default(), - joined_at: row.get("joined_at"), + project_id: parse_uuid(&row.col::("project_id"))?, + project_slug: row.col("project_slug"), + project_name: row.col("project_name"), + org_id: parse_uuid(&row.col::("org_id"))?, + role: row.col("role"), + source: MembershipSource::parse(&source_str).unwrap_or_default(), + joined_at: row.col("joined_at"), }) }) .collect() } async fn get_team_memberships_for_user(&self, user_id: Uuid) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT t.id as team_id, @@ -864,15 +854,15 @@ impl UserRepo for SqliteUserRepo { rows.into_iter() .map(|row| { - let source_str: String = row.get("source"); + let source_str: String = row.col("source"); Ok(TeamMembership { - team_id: parse_uuid(&row.get::("team_id"))?, - team_slug: row.get("team_slug"), - team_name: row.get("team_name"), - org_id: parse_uuid(&row.get::("org_id"))?, - role: row.get("role"), - source: MembershipSource::from_str(&source_str).unwrap_or_default(), - joined_at: row.get("joined_at"), + team_id: parse_uuid(&row.col::("team_id"))?, + team_slug: row.col("team_slug"), + team_name: row.col("team_name"), + org_id: 
parse_uuid(&row.col::("org_id"))?, + role: row.col("role"), + source: MembershipSource::parse(&source_str).unwrap_or_default(), + joined_at: row.col("joined_at"), }) }) .collect() @@ -885,7 +875,7 @@ impl UserRepo for SqliteUserRepo { let mut result = UserDeletionResult::default(); // Delete usage records for user's API keys first (they reference api_keys) - let usage_result = sqlx::query( + let usage_result = query( r#" DELETE FROM usage_records WHERE api_key_id IN ( @@ -900,7 +890,7 @@ impl UserRepo for SqliteUserRepo { result.usage_records_deleted = usage_result.rows_affected(); // Delete API keys owned by user - let api_keys_result = sqlx::query( + let api_keys_result = query( r#" DELETE FROM api_keys WHERE owner_type = 'user' AND owner_id = ? @@ -912,7 +902,7 @@ impl UserRepo for SqliteUserRepo { result.api_keys_deleted = api_keys_result.rows_affected(); // Delete conversations owned by user - let conversations_result = sqlx::query( + let conversations_result = query( r#" DELETE FROM conversations WHERE owner_type = 'user' AND owner_id = ? @@ -924,7 +914,7 @@ impl UserRepo for SqliteUserRepo { result.conversations_deleted = conversations_result.rows_affected(); // Delete dynamic providers owned by user - let providers_result = sqlx::query( + let providers_result = query( r#" DELETE FROM dynamic_providers WHERE owner_type = 'user' AND owner_id = ? @@ -936,7 +926,7 @@ impl UserRepo for SqliteUserRepo { result.dynamic_providers_deleted = providers_result.rows_affected(); // Delete user (org_memberships and project_memberships cascade automatically) - let user_result = sqlx::query( + let user_result = query( r#" DELETE FROM users WHERE id = ? 
"#, @@ -956,6 +946,8 @@ impl UserRepo for SqliteUserRepo { #[cfg(test)] mod tests { + use sqlx::SqlitePool; + use super::*; use crate::db::repos::UserRepo; diff --git a/src/db/sqlite/vector_stores.rs b/src/db/sqlite/vector_stores.rs index 9482f08..a9a171c 100644 --- a/src/db/sqlite/vector_stores.rs +++ b/src/db/sqlite/vector_stores.rs @@ -2,10 +2,12 @@ use std::collections::HashMap; use async_trait::async_trait; use chrono::{DateTime, Utc}; -use sqlx::{Row, SqlitePool}; use uuid::Uuid; -use super::common::parse_uuid; +use super::{ + backend::{Pool, Row, RowExt, query}, + common::parse_uuid, +}; use crate::{ db::{ error::{DbError, DbResult}, @@ -20,11 +22,11 @@ use crate::{ }; pub struct SqliteVectorStoresRepo { - pool: SqlitePool, + pool: Pool, } impl SqliteVectorStoresRepo { - pub fn new(pool: SqlitePool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } @@ -72,56 +74,56 @@ impl SqliteVectorStoresRepo { /// Expects columns: id, owner_type, owner_id, name, description, status, embedding_model, /// embedding_dimensions, usage_bytes, file_counts, metadata, expires_after, expires_at, /// last_active_at, created_at, updated_at - fn vector_store_from_row(row: &sqlx::sqlite::SqliteRow) -> DbResult { - let owner_type_str: String = row.get("owner_type"); - let status_str: String = row.get("status"); - let file_counts_str: String = row.get("file_counts"); + fn vector_store_from_row(row: &Row) -> DbResult { + let owner_type_str: String = row.col("owner_type"); + let status_str: String = row.col("status"); + let file_counts_str: String = row.col("file_counts"); Ok(VectorStore { - id: parse_uuid(&row.get::("id"))?, + id: parse_uuid(&row.col::("id"))?, object: OBJECT_TYPE_VECTOR_STORE.to_string(), owner_type: owner_type_str .parse() .map_err(|e: String| DbError::Internal(e))?, - owner_id: parse_uuid(&row.get::("owner_id"))?, - name: row.get("name"), - description: row.get("description"), + owner_id: parse_uuid(&row.col::("owner_id"))?, + name: row.col("name"), + 
description: row.col("description"), status: status_str .parse() .map_err(|e: String| DbError::Internal(e))?, - embedding_model: row.get("embedding_model"), - embedding_dimensions: row.get("embedding_dimensions"), - usage_bytes: row.get("usage_bytes"), + embedding_model: row.col("embedding_model"), + embedding_dimensions: row.col("embedding_dimensions"), + usage_bytes: row.col("usage_bytes"), file_counts: Self::parse_file_counts(&file_counts_str)?, - metadata: Self::parse_metadata(row.get("metadata"))?, - expires_after: Self::parse_expires_after(row.get("expires_after"))?, - expires_at: row.get("expires_at"), - last_active_at: row.get("last_active_at"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + metadata: Self::parse_metadata(row.col("metadata"))?, + expires_after: Self::parse_expires_after(row.col("expires_after"))?, + expires_at: row.col("expires_at"), + last_active_at: row.col("last_active_at"), + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) } /// Parse a VectorStoreFile from a database row. 
/// Expects columns: id, vector_store_id, file_id, status, usage_bytes, last_error, /// chunking_strategy, attributes, created_at, updated_at - fn vector_store_file_from_row(row: &sqlx::sqlite::SqliteRow) -> DbResult { - let status_str: String = row.get("status"); + fn vector_store_file_from_row(row: &Row) -> DbResult { + let status_str: String = row.col("status"); Ok(VectorStoreFile { - internal_id: parse_uuid(&row.get::("id"))?, - file_id: parse_uuid(&row.get::("file_id"))?, + internal_id: parse_uuid(&row.col::("id"))?, + file_id: parse_uuid(&row.col::("file_id"))?, object: OBJECT_TYPE_VECTOR_STORE_FILE.to_string(), - vector_store_id: parse_uuid(&row.get::("vector_store_id"))?, + vector_store_id: parse_uuid(&row.col::("vector_store_id"))?, status: status_str .parse() .map_err(|e: String| DbError::Internal(e))?, - usage_bytes: row.get("usage_bytes"), - last_error: Self::parse_file_error(row.get("last_error"))?, - chunking_strategy: Self::parse_chunking_strategy(row.get("chunking_strategy"))?, - attributes: Self::parse_attributes(row.get("attributes"))?, - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), + usage_bytes: row.col("usage_bytes"), + last_error: Self::parse_file_error(row.col("last_error"))?, + chunking_strategy: Self::parse_chunking_strategy(row.col("chunking_strategy"))?, + attributes: Self::parse_attributes(row.col("attributes"))?, + created_at: row.col("created_at"), + updated_at: row.col("updated_at"), }) } @@ -134,7 +136,8 @@ impl SqliteVectorStoresRepo { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl VectorStoresRepo for SqliteVectorStoresRepo { // ==================== Vector Stores CRUD ==================== @@ -162,7 +165,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { let default_file_counts = r#"{"cancelled":0,"completed":0,"failed":0,"in_progress":0,"total":0}"#; - sqlx::query( + query( r#" INSERT INTO vector_stores (id, 
owner_type, owner_id, name, description, embedding_model, embedding_dimensions, metadata, expires_after, file_counts, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) @@ -205,7 +208,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { } async fn get_vector_store(&self, id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, owner_type, owner_id, name, description, status, embedding_model, embedding_dimensions, usage_bytes, file_counts, metadata, expires_after, expires_at, last_active_at, created_at, updated_at @@ -224,7 +227,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { } async fn get_by_id_and_org(&self, id: Uuid, org_id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT vs.id, vs.owner_type, vs.owner_id, vs.name, vs.description, vs.status, vs.embedding_model, vs.embedding_dimensions, vs.usage_bytes, vs.file_counts, vs.metadata, vs.expires_after, vs.expires_at, vs.last_active_at, vs.created_at, vs.updated_at @@ -267,7 +270,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { owner_id: Uuid, name: &str, ) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, owner_type, owner_id, name, description, status, embedding_model, embedding_dimensions, usage_bytes, file_counts, metadata, expires_after, expires_at, last_active_at, created_at, updated_at @@ -307,7 +310,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { "AND deleted_at IS NULL" }; - let query = format!( + let sql = format!( r#" SELECT id, owner_type, owner_id, name, description, status, embedding_model, embedding_dimensions, usage_bytes, file_counts, metadata, expires_after, expires_at, last_active_at, created_at, updated_at @@ -321,7 +324,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { comparison, deleted_filter, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(owner_type.as_str()) .bind(owner_id.to_string()) .bind(cursor.created_at) @@ 
-354,7 +357,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { // First page (no cursor) let order = params.sort_order.as_sql(); - let query = if params.include_deleted { + let sql = if params.include_deleted { format!( r#" SELECT id, owner_type, owner_id, name, description, status, embedding_model, embedding_dimensions, @@ -380,7 +383,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { ) }; - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(owner_type.as_str()) .bind(owner_id.to_string()) .bind(fetch_limit) @@ -463,7 +466,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { "AND deleted_at IS NULL" }; - let query = format!( + let sql = format!( r#" SELECT id, owner_type, owner_id, name, description, status, embedding_model, embedding_dimensions, usage_bytes, file_counts, metadata, expires_after, expires_at, last_active_at, created_at, updated_at @@ -478,16 +481,16 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { ); // Build the query dynamically - let mut query_builder = sqlx::query(&query); + let mut q = query(&sql); for binding in &bindings { - query_builder = query_builder.bind(binding); + q = q.bind(binding); } - query_builder = query_builder + q = q .bind(cursor.created_at) .bind(cursor.id.to_string()) .bind(fetch_limit); - let rows = query_builder.fetch_all(&self.pool).await?; + let rows = q.fetch_all(&self.pool).await?; let has_more = rows.len() as i64 > limit; let mut items: Vec = rows @@ -512,7 +515,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { } // First page (no cursor) - let query = if params.include_deleted { + let sql = if params.include_deleted { format!( r#" SELECT id, owner_type, owner_id, name, description, status, embedding_model, embedding_dimensions, @@ -539,13 +542,13 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { }; // Build the query dynamically - let mut query_builder = sqlx::query(&query); + let mut q = query(&sql); for binding in &bindings { - query_builder = query_builder.bind(binding); + q 
= q.bind(binding); } - query_builder = query_builder.bind(fetch_limit); + q = q.bind(fetch_limit); - let rows = query_builder.fetch_all(&self.pool).await?; + let rows = q.fetch_all(&self.pool).await?; let has_more = rows.len() as i64 > limit; let items: Vec = rows @@ -584,7 +587,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { "AND deleted_at IS NULL" }; - let query = format!( + let sql = format!( r#" SELECT id, owner_type, owner_id, name, description, status, embedding_model, embedding_dimensions, usage_bytes, file_counts, metadata, expires_after, expires_at, last_active_at, created_at, updated_at @@ -597,7 +600,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { comparison, deleted_filter, order_dir, order_dir ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(cursor.created_at) .bind(cursor.id.to_string()) .bind(fetch_limit) @@ -627,7 +630,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { } // First page (no cursor) - let query = if params.include_deleted { + let sql = if params.include_deleted { format!( r#" SELECT id, owner_type, owner_id, name, description, status, embedding_model, embedding_dimensions, @@ -652,10 +655,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { ) }; - let rows = sqlx::query(&query) - .bind(fetch_limit) - .fetch_all(&self.pool) - .await?; + let rows = query(&sql).bind(fetch_limit).fetch_all(&self.pool).await?; let has_more = rows.len() as i64 > limit; let items: Vec = rows @@ -684,11 +684,11 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { // Use IMMEDIATE transaction mode to acquire write lock before reading let mut conn = self.pool.acquire().await?; - sqlx::query("BEGIN IMMEDIATE").execute(&mut *conn).await?; + query("BEGIN IMMEDIATE").execute(&mut *conn).await?; let result = async { // Read current state within transaction - let current = sqlx::query( + let current = query( r#" SELECT id, owner_type, owner_id, name, description, status, embedding_model, embedding_dimensions, usage_bytes, 
file_counts, metadata, expires_after, expires_at, last_active_at, created_at, updated_at @@ -701,20 +701,20 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { .await? .ok_or(DbError::NotFound)?; - let owner_type_str: String = current.get("owner_type"); + let owner_type_str: String = current.col("owner_type"); let owner_type: VectorStoreOwnerType = owner_type_str .parse() .map_err(|e: String| DbError::Internal(e))?; - let owner_id = parse_uuid(¤t.get::("owner_id"))?; - let status_str: String = current.get("status"); - let file_counts_str: String = current.get("file_counts"); - let embedding_model: String = current.get("embedding_model"); - let embedding_dimensions: i32 = current.get("embedding_dimensions"); + let owner_id = parse_uuid(¤t.col::("owner_id"))?; + let status_str: String = current.col("status"); + let file_counts_str: String = current.col("file_counts"); + let embedding_model: String = current.col("embedding_model"); + let embedding_dimensions: i32 = current.col("embedding_dimensions"); - let current_name: String = current.get("name"); - let current_description: Option = current.get("description"); - let current_metadata: Option = current.get("metadata"); - let current_expires_after: Option = current.get("expires_after"); + let current_name: String = current.col("name"); + let current_description: Option = current.col("description"); + let current_metadata: Option = current.col("metadata"); + let current_expires_after: Option = current.col("expires_after"); let new_name = input.name.unwrap_or(current_name); let new_description = input.description.or(current_description); @@ -731,7 +731,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { .map_err(|e| DbError::Internal(e.to_string()))? .or(current_expires_after); - let update_result = sqlx::query( + let update_result = query( r#" UPDATE vector_stores SET name = ?, description = ?, metadata = ?, expires_after = ?, updated_at = ? 
@@ -763,13 +763,13 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { .map_err(|e: String| DbError::Internal(e))?, embedding_model, embedding_dimensions, - usage_bytes: current.get("usage_bytes"), + usage_bytes: current.col("usage_bytes"), file_counts: Self::parse_file_counts(&file_counts_str)?, metadata: Self::parse_metadata(new_metadata)?, expires_after: Self::parse_expires_after(new_expires_after)?, - expires_at: current.get("expires_at"), - last_active_at: current.get("last_active_at"), - created_at: current.get("created_at"), + expires_at: current.col("expires_at"), + last_active_at: current.col("last_active_at"), + created_at: current.col("created_at"), updated_at: now, }) } @@ -778,10 +778,10 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { // Commit or rollback based on result match &result { Ok(_) => { - sqlx::query("COMMIT").execute(&mut *conn).await?; + query("COMMIT").execute(&mut *conn).await?; } Err(_) => { - let _ = sqlx::query("ROLLBACK").execute(&mut *conn).await; + let _ = query("ROLLBACK").execute(&mut *conn).await; } } @@ -791,7 +791,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { async fn delete_vector_store(&self, id: Uuid) -> DbResult<()> { let now = chrono::Utc::now(); - let result = sqlx::query( + let result = query( r#" UPDATE vector_stores SET deleted_at = ? @@ -812,7 +812,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { async fn hard_delete_vector_store(&self, id: Uuid) -> DbResult<()> { // First delete all vector_store_files links - sqlx::query( + query( r#" DELETE FROM vector_store_files WHERE vector_store_id = ? @@ -823,7 +823,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { .await?; // Then delete the vector store - let result = sqlx::query( + let result = query( r#" DELETE FROM vector_stores WHERE id = ? 
@@ -844,7 +844,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { &self, older_than: DateTime, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT id, owner_type, owner_id, name, description, status, embedding_model, embedding_dimensions, usage_bytes, file_counts, metadata, expires_after, expires_at, last_active_at, created_at, updated_at @@ -864,7 +864,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { async fn touch_vector_store(&self, id: Uuid) -> DbResult<()> { let now = chrono::Utc::now(); - let result = sqlx::query( + let result = query( r#" UPDATE vector_stores SET last_active_at = ?, updated_at = ? @@ -905,7 +905,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { .transpose() .map_err(|e| DbError::Internal(e.to_string()))?; - sqlx::query( + query( r#" INSERT INTO vector_store_files (id, vector_store_id, file_id, chunking_strategy, attributes, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?) @@ -937,7 +937,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { } async fn get_vector_store_file(&self, id: Uuid) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, vector_store_id, file_id, status, usage_bytes, last_error, chunking_strategy, attributes, created_at, updated_at @@ -960,7 +960,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { vector_store_id: Uuid, file_id: Uuid, ) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT id, vector_store_id, file_id, status, usage_bytes, last_error, chunking_strategy, attributes, created_at, updated_at @@ -989,7 +989,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { owner_type: VectorStoreOwnerType, owner_id: Uuid, ) -> DbResult> { - let result = sqlx::query( + let result = query( r#" SELECT cf.id, cf.vector_store_id, cf.file_id, cf.status, cf.usage_bytes, cf.last_error, cf.chunking_strategy, cf.attributes, cf.created_at, cf.updated_at @@ -1029,7 +1029,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { let 
(comparison, order, should_reverse) = params.sort_order.cursor_query_params(params.direction); - let query = format!( + let sql = format!( r#" SELECT id, vector_store_id, file_id, status, usage_bytes, last_error, chunking_strategy, attributes, created_at, updated_at @@ -1042,7 +1042,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { comparison, order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(vector_store_id.to_string()) .bind(cursor.created_at) .bind(cursor.id.to_string()) @@ -1074,7 +1074,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { // First page (no cursor) let order = params.sort_order.as_sql(); - let query = format!( + let sql = format!( r#" SELECT id, vector_store_id, file_id, status, usage_bytes, last_error, chunking_strategy, attributes, created_at, updated_at @@ -1086,7 +1086,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { order, order ); - let rows = sqlx::query(&query) + let rows = query(&sql) .bind(vector_store_id.to_string()) .bind(fetch_limit) .fetch_all(&self.pool) @@ -1122,7 +1122,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { .transpose() .map_err(|e| DbError::Internal(e.to_string()))?; - let result = sqlx::query( + let result = query( r#" UPDATE vector_store_files SET status = ?, last_error = ?, updated_at = ? @@ -1146,7 +1146,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { async fn update_vector_store_file_usage(&self, id: Uuid, usage_bytes: i64) -> DbResult<()> { let now = chrono::Utc::now(); - let result = sqlx::query( + let result = query( r#" UPDATE vector_store_files SET usage_bytes = ?, updated_at = ? @@ -1169,7 +1169,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { async fn remove_file_from_vector_store(&self, id: Uuid) -> DbResult<()> { let now = chrono::Utc::now(); - let result = sqlx::query( + let result = query( r#" UPDATE vector_store_files SET deleted_at = ?, updated_at = ? 
@@ -1193,7 +1193,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { &self, older_than: DateTime, ) -> DbResult> { - let rows = sqlx::query( + let rows = query( r#" SELECT id, vector_store_id, file_id, status, usage_bytes, last_error, chunking_strategy, attributes, created_at, updated_at @@ -1211,7 +1211,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { } async fn hard_delete_vector_store_file(&self, id: Uuid) -> DbResult<()> { - let result = sqlx::query( + let result = query( r#" DELETE FROM vector_store_files WHERE id = ? @@ -1229,7 +1229,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { } async fn hard_delete_soft_deleted_references(&self, file_id: Uuid) -> DbResult { - let result = sqlx::query( + let result = query( r#" DELETE FROM vector_store_files WHERE file_id = ? AND deleted_at IS NOT NULL @@ -1251,7 +1251,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { // Calculate aggregate stats from files (excluding soft-deleted) // SQLite doesn't have jsonb_build_object, so we build the JSON string manually - let stats = sqlx::query( + let stats = query( r#" SELECT COALESCE(SUM(usage_bytes), 0) as total_usage, @@ -1268,12 +1268,12 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { .fetch_one(&self.pool) .await?; - let total_usage: i64 = stats.get("total_usage"); - let cancelled: i32 = stats.get("cancelled"); - let completed: i32 = stats.get("completed"); - let failed: i32 = stats.get("failed"); - let in_progress: i32 = stats.get("in_progress"); - let total: i32 = stats.get("total"); + let total_usage: i64 = stats.col("total_usage"); + let cancelled: i32 = stats.col("cancelled"); + let completed: i32 = stats.col("completed"); + let failed: i32 = stats.col("failed"); + let in_progress: i32 = stats.col("in_progress"); + let total: i32 = stats.col("total"); let file_counts_json = serde_json::json!({ "cancelled": cancelled, @@ -1284,7 +1284,7 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { }) .to_string(); - sqlx::query( + query( r#" UPDATE 
vector_stores SET usage_bytes = ?, file_counts = ?, updated_at = ? @@ -1304,6 +1304,8 @@ impl VectorStoresRepo for SqliteVectorStoresRepo { #[cfg(test)] mod tests { + use sqlx::SqlitePool; + use super::*; use crate::models::VectorStoreOwner; diff --git a/src/db/wasm_sqlite/bridge.rs b/src/db/wasm_sqlite/bridge.rs new file mode 100644 index 0000000..68b6604 --- /dev/null +++ b/src/db/wasm_sqlite/bridge.rs @@ -0,0 +1,245 @@ +//! JavaScript FFI bridge to wa-sqlite running in the browser. +//! +//! `WasmSqlitePool` holds no Rust state — the actual SQLite database lives in +//! JavaScript (wa-sqlite + OPFS). All queries are dispatched via `wasm_bindgen` +//! extern functions and results are deserialized back into Rust types. + +use wasm_bindgen::prelude::*; + +use super::types::{WasmDbError, WasmParam, WasmQueryResult, WasmRow, WasmValue}; + +// ───────────────────────────────────────────────────────────────────────────── +// JS FFI declarations +// ───────────────────────────────────────────────────────────────────────────── + +#[wasm_bindgen] +extern "C" { + /// Execute a SELECT query. Returns a JSON-encoded array of row objects. + /// Each row is an object mapping column names to values. + #[wasm_bindgen(js_namespace = ["globalThis", "__hadrian_sqlite"], catch)] + async fn query(sql: &str, params: JsValue) -> Result; + + /// Execute a write statement (INSERT/UPDATE/DELETE). + /// Returns a JSON-encoded object with `{ changes: number, last_insert_rowid: number }`. + #[wasm_bindgen(js_namespace = ["globalThis", "__hadrian_sqlite"], catch)] + async fn execute(sql: &str, params: JsValue) -> Result; + + /// Initialize the database (create tables, run migrations). + #[wasm_bindgen(js_namespace = ["globalThis", "__hadrian_sqlite"], catch)] + async fn init_database() -> Result<(), JsValue>; + + /// Execute a multi-statement SQL script (e.g. migrations). + /// Uses sql.js `db.exec()` which handles multiple statements natively. 
+ #[wasm_bindgen(js_namespace = ["globalThis", "__hadrian_sqlite"], catch)] + async fn execute_script(sql: &str) -> Result<(), JsValue>; +} + +// ───────────────────────────────────────────────────────────────────────────── +// Pool +// ───────────────────────────────────────────────────────────────────────────── + +/// Handle to the wa-sqlite database running in JavaScript. +/// +/// This is a zero-size type — the actual database is managed by JS. Multiple +/// clones share the same underlying JS database instance. +#[derive(Debug, Clone)] +pub struct WasmSqlitePool; + +/// A no-op "connection" handle for WASM SQLite. +/// +/// In native SQLite, `pool.acquire()` returns a real connection handle used +/// for transactions (`BEGIN IMMEDIATE` … `COMMIT`). In WASM (single-threaded), +/// all queries go through the same JS bridge, so this is just a thin wrapper +/// that derefs back to `WasmSqlitePool`. +pub struct WasmPoolConnection(WasmSqlitePool); + +impl std::ops::Deref for WasmPoolConnection { + type Target = WasmSqlitePool; + fn deref(&self) -> &WasmSqlitePool { + &self.0 + } +} + +impl std::ops::DerefMut for WasmPoolConnection { + fn deref_mut(&mut self) -> &mut WasmSqlitePool { + &mut self.0 + } +} + +impl WasmSqlitePool { + pub fn new() -> Self { + Self + } + + /// Acquire a "connection" (no-op in WASM — mirrors `sqlx::SqlitePool::acquire()`). + pub async fn acquire(&self) -> Result { + Ok(WasmPoolConnection(self.clone())) + } + + /// Initialize the database (called once at startup). + pub async fn init(&self) -> Result<(), WasmDbError> { + init_database() + .await + .map_err(|e| WasmDbError::Query(js_error_to_string(&e))) + } + + /// Execute a SELECT query and return all matching rows. 
+ pub async fn execute_query( + &self, + sql: &str, + params: &[WasmParam], + ) -> Result, WasmDbError> { + let js_params = + serde_wasm_bindgen::to_value(params).map_err(|e| WasmDbError::Query(e.to_string()))?; + + let result = query(sql, js_params) + .await + .map_err(|e| classify_js_error(&e))?; + + // The JS bridge returns an array of objects: [{ col1: val1, col2: val2, ... }, ...] + // We deserialize into Vec> via an intermediate representation. + let rows: Vec> = + serde_wasm_bindgen::from_value(result) + .map_err(|e| WasmDbError::Query(format!("Failed to deserialize rows: {e}")))?; + + Ok(rows + .into_iter() + .map(|obj| WasmRow { + columns: obj + .into_iter() + .map(|(k, v)| (k, json_to_wasm_value(v))) + .collect(), + }) + .collect()) + } + + /// Execute a write statement (INSERT/UPDATE/DELETE) and return the result. + pub async fn execute_statement( + &self, + sql: &str, + params: &[WasmParam], + ) -> Result { + let js_params = + serde_wasm_bindgen::to_value(params).map_err(|e| WasmDbError::Query(e.to_string()))?; + + let result = execute(sql, js_params) + .await + .map_err(|e| classify_js_error(&e))?; + + #[derive(serde::Deserialize)] + struct ExecResult { + changes: u64, + #[serde(default)] + last_insert_rowid: i64, + } + + let exec: ExecResult = serde_wasm_bindgen::from_value(result) + .map_err(|e| WasmDbError::Query(format!("Failed to deserialize exec result: {e}")))?; + + Ok(WasmQueryResult { + rows_affected: exec.changes, + last_insert_rowid: exec.last_insert_rowid, + }) + } + + /// Run the embedded SQLite migration SQL. 
+ pub async fn run_migrations(&self) -> Result<(), WasmDbError> { + let migration_sql = + include_str!("../../../migrations_sqlx/sqlite/20250101000000_initial.sql"); + + // Create migrations tracking table + self.execute_statement( + "CREATE TABLE IF NOT EXISTS _wasm_migrations ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + applied_at TEXT NOT NULL DEFAULT (datetime('now')) + )", + &[], + ) + .await?; + + // Check if migration already applied + let rows = self + .execute_query( + "SELECT id FROM _wasm_migrations WHERE name = ?", + &[WasmParam::Text("20250101000000_initial".to_string())], + ) + .await?; + + if !rows.is_empty() { + tracing::debug!("WASM SQLite migrations already applied"); + return Ok(()); + } + + // Use execute_script to run the entire migration as one batch. + // sql.js's db.exec() handles multiple statements natively, avoiding + // issues with semicolons inside SQL comments. + execute_script(migration_sql) + .await + .map_err(|e| WasmDbError::Query(js_error_to_string(&e)))?; + + // Record migration + self.execute_statement( + "INSERT INTO _wasm_migrations (name) VALUES (?)", + &[WasmParam::Text("20250101000000_initial".to_string())], + ) + .await?; + + tracing::info!("WASM SQLite migrations applied successfully"); + Ok(()) + } +} + +impl Default for WasmSqlitePool { + fn default() -> Self { + Self::new() + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Helpers +// ───────────────────────────────────────────────────────────────────────────── + +/// Convert a serde_json::Value to a WasmValue. 
+fn json_to_wasm_value(v: serde_json::Value) -> WasmValue { + match v { + serde_json::Value::Null => WasmValue::Null, + serde_json::Value::Bool(b) => WasmValue::Integer(b as i64), + serde_json::Value::Number(n) => { + if let Some(i) = n.as_i64() { + WasmValue::Integer(i) + } else if let Some(f) = n.as_f64() { + WasmValue::Real(f) + } else { + WasmValue::Text(n.to_string()) + } + } + serde_json::Value::String(s) => WasmValue::Text(s), + other => WasmValue::Text(other.to_string()), + } +} + +/// Extract a human-readable error message from a JsValue. +fn js_error_to_string(e: &JsValue) -> String { + if let Some(s) = e.as_string() { + return s; + } + if let Some(err) = e.dyn_ref::() { + return err.message().into(); + } + format!("{e:?}") +} + +/// Classify a JS error into the appropriate WasmDbError variant. +fn classify_js_error(e: &JsValue) -> WasmDbError { + let msg = js_error_to_string(e); + + // SQLite constraint error messages + if msg.contains("UNIQUE constraint failed") { + WasmDbError::UniqueViolation(msg) + } else if msg.contains("FOREIGN KEY constraint failed") { + WasmDbError::ForeignKeyViolation(msg) + } else { + WasmDbError::Query(msg) + } +} diff --git a/src/db/wasm_sqlite/mod.rs b/src/db/wasm_sqlite/mod.rs new file mode 100644 index 0000000..6ff009e --- /dev/null +++ b/src/db/wasm_sqlite/mod.rs @@ -0,0 +1,33 @@ +//! WASM SQLite database layer. +//! +//! Provides a sqlx-compatible API surface backed by wa-sqlite running in the +//! browser via JavaScript FFI. The repository implementations live in +//! `src/db/sqlite/` and are shared with native SQLite via the backend +//! abstraction layer (`src/db/sqlite/backend.rs`). +//! +//! This module only exports the FFI bridge and core types. +//! +//! # Architecture +//! +//! ```text +//! Rust repo code → WasmSqlitePool::query() → JS bridge → wa-sqlite → OPFS +//! ``` +//! +//! The API mirrors sqlx's runtime query builder: +//! - `query(sql)` → `WasmQuery` with `.bind()` chaining +//! 
- `.fetch_all(pool)` / `.fetch_optional(pool)` / `.execute(pool)` +//! - `WasmRow` with `.get::(column)` for type-safe column access + +pub(crate) mod bridge; +pub(crate) mod types; + +pub use bridge::WasmSqlitePool; +pub use types::{ + WasmDbError, WasmDecode, WasmParam, WasmQuery, WasmQueryResult, WasmQueryScalar, WasmRow, + WasmValue, +}; + +/// Create a query builder (analogous to `sqlx::query()`). +pub fn query(sql: &str) -> WasmQuery { + WasmQuery::new(sql) +} diff --git a/src/db/wasm_sqlite/types.rs b/src/db/wasm_sqlite/types.rs new file mode 100644 index 0000000..21e580a --- /dev/null +++ b/src/db/wasm_sqlite/types.rs @@ -0,0 +1,736 @@ +//! Core types for the WASM SQLite database layer. + +use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc}; +use rust_decimal::Decimal; +use thiserror::Error; +use uuid::Uuid; + +use super::WasmSqlitePool; + +// ───────────────────────────────────────────────────────────────────────────── +// Error +// ───────────────────────────────────────────────────────────────────────────── + +#[derive(Debug, Error)] +pub enum WasmDbError { + #[error("WASM SQLite error: {0}")] + Query(String), + #[error("Row not found")] + RowNotFound, + #[error("Column not found: {0}")] + ColumnNotFound(String), + #[error("Type mismatch for column {column}: expected {expected}, got {actual}")] + TypeMismatch { + column: String, + expected: &'static str, + actual: String, + }, + #[error("Unique constraint violation: {0}")] + UniqueViolation(String), + #[error("Foreign key constraint violation: {0}")] + ForeignKeyViolation(String), + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), +} + +impl WasmDbError { + pub fn is_unique_violation(&self) -> bool { + matches!(self, Self::UniqueViolation(_)) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Bind parameters +// ───────────────────────────────────────────────────────────────────────────── + +/// A bindable parameter value for SQLite queries. 
+#[derive(Debug, Clone, serde::Serialize)] +#[serde(untagged)] +pub enum WasmParam { + Null, + Text(String), + Integer(i64), + Real(f64), + Bool(bool), +} + +// Conversions into WasmParam +impl From for WasmParam { + fn from(v: String) -> Self { + Self::Text(v) + } +} + +impl From<&str> for WasmParam { + fn from(v: &str) -> Self { + Self::Text(v.to_string()) + } +} + +impl From<&String> for WasmParam { + fn from(v: &String) -> Self { + Self::Text(v.clone()) + } +} + +impl From for WasmParam { + fn from(v: i64) -> Self { + Self::Integer(v) + } +} + +impl From for WasmParam { + fn from(v: i32) -> Self { + Self::Integer(v as i64) + } +} + +impl From for WasmParam { + fn from(v: u32) -> Self { + Self::Integer(v as i64) + } +} + +impl From for WasmParam { + fn from(v: u16) -> Self { + Self::Integer(v as i64) + } +} + +impl From for WasmParam { + fn from(v: i16) -> Self { + Self::Integer(v as i64) + } +} + +impl From<&Option> for WasmParam { + fn from(v: &Option) -> Self { + match v { + Some(i) => Self::Integer(*i as i64), + None => Self::Null, + } + } +} + +impl From for WasmParam { + fn from(v: f64) -> Self { + Self::Real(v) + } +} + +impl From for WasmParam { + fn from(v: bool) -> Self { + Self::Bool(v) + } +} + +impl From for WasmParam { + fn from(v: Uuid) -> Self { + Self::Text(v.to_string()) + } +} + +impl From> for WasmParam { + fn from(v: DateTime) -> Self { + Self::Text(v.to_rfc3339()) + } +} + +impl From for WasmParam { + fn from(v: NaiveDateTime) -> Self { + Self::Text(v.format("%Y-%m-%dT%H:%M:%S%.f").to_string()) + } +} + +impl From for WasmParam { + fn from(v: NaiveDate) -> Self { + Self::Text(v.format("%Y-%m-%d").to_string()) + } +} + +impl From for WasmParam { + fn from(v: Decimal) -> Self { + Self::Text(v.to_string()) + } +} + +impl> From> for WasmParam { + fn from(v: Option) -> Self { + match v { + Some(inner) => inner.into(), + None => Self::Null, + } + } +} + +impl From<&Option> for WasmParam { + fn from(v: &Option) -> Self { + match v { + Some(s) => 
Self::Text(s.clone()), + None => Self::Null, + } + } +} + +impl From<&Uuid> for WasmParam { + fn from(v: &Uuid) -> Self { + Self::Text(v.to_string()) + } +} + +impl From<&bool> for WasmParam { + fn from(v: &bool) -> Self { + Self::Bool(*v) + } +} + +impl From<&i64> for WasmParam { + fn from(v: &i64) -> Self { + Self::Integer(*v) + } +} + +impl From<&i32> for WasmParam { + fn from(v: &i32) -> Self { + Self::Integer(*v as i64) + } +} + +impl From<&f64> for WasmParam { + fn from(v: &f64) -> Self { + Self::Real(*v) + } +} + +impl From<&Decimal> for WasmParam { + fn from(v: &Decimal) -> Self { + Self::Text(v.to_string()) + } +} + +impl From<&Option> for WasmParam { + fn from(v: &Option) -> Self { + match v { + Some(i) => Self::Integer(*i as i64), + None => Self::Null, + } + } +} + +impl From<&Option> for WasmParam { + fn from(v: &Option) -> Self { + match v { + Some(i) => Self::Integer(*i), + None => Self::Null, + } + } +} + +impl From<&Option> for WasmParam { + fn from(v: &Option) -> Self { + match v { + Some(f) => Self::Real(*f), + None => Self::Null, + } + } +} + +impl From<&Option> for WasmParam { + fn from(v: &Option) -> Self { + match v { + Some(b) => Self::Bool(*b), + None => Self::Null, + } + } +} + +impl From<&Option> for WasmParam { + fn from(v: &Option) -> Self { + match v { + Some(u) => Self::Text(u.to_string()), + None => Self::Null, + } + } +} + +impl From<&Option>> for WasmParam { + fn from(v: &Option>) -> Self { + match v { + Some(dt) => Self::Text(dt.to_rfc3339()), + None => Self::Null, + } + } +} + +impl From<&Option> for WasmParam { + fn from(v: &Option) -> Self { + match v { + Some(d) => Self::Text(d.to_string()), + None => Self::Null, + } + } +} + +impl From<&Option> for WasmParam { + fn from(v: &Option) -> Self { + match v { + Some(j) => Self::Text(j.to_string()), + None => Self::Null, + } + } +} + +impl From<&serde_json::Value> for WasmParam { + fn from(v: &serde_json::Value) -> Self { + Self::Text(v.to_string()) + } +} + +impl From> for WasmParam 
{ + fn from(v: Vec) -> Self { + use base64::Engine; + Self::Text(base64::engine::general_purpose::STANDARD.encode(&v)) + } +} + +impl From<&Option>> for WasmParam { + fn from(v: &Option>) -> Self { + use base64::Engine; + match v { + Some(bytes) => Self::Text(base64::engine::general_purpose::STANDARD.encode(bytes)), + None => Self::Null, + } + } +} + +impl From<&DateTime> for WasmParam { + fn from(v: &DateTime) -> Self { + Self::Text(v.to_rfc3339()) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Query builder +// ───────────────────────────────────────────────────────────────────────────── + +/// A scalar query builder that returns a single column decoded as `T`. +/// +/// Mimics `sqlx::QueryScalar` — use `.bind()` to add parameters, then +/// `.fetch_all()` to run and extract the first column from each row. +pub struct WasmQueryScalar { + inner: WasmQuery, + _marker: std::marker::PhantomData, +} + +impl WasmQueryScalar { + pub fn new(sql: &str) -> Self { + Self { + inner: WasmQuery::new(sql), + _marker: std::marker::PhantomData, + } + } + + pub fn bind>(mut self, value: V) -> Self { + self.inner = self.inner.bind(value); + self + } + + pub async fn fetch_all(self, pool: &WasmSqlitePool) -> Result, WasmDbError> { + let rows = self.inner.fetch_all(pool).await?; + rows.into_iter() + .map(|row| { + let (_, value) = + row.columns.into_iter().next().ok_or_else(|| { + WasmDbError::Query("No columns in scalar result".to_string()) + })?; + T::decode(&value, "scalar") + }) + .collect() + } +} + +/// A query builder that accumulates bind parameters. +/// +/// Mimics `sqlx::Query` — use `.bind()` to add parameters, then +/// `.fetch_all()`, `.fetch_optional()`, or `.execute()` to run. +pub struct WasmQuery { + sql: String, + params: Vec, +} + +impl WasmQuery { + pub fn new(sql: &str) -> Self { + Self { + sql: sql.to_string(), + params: Vec::new(), + } + } + + /// Bind a parameter value. 
+ pub fn bind>(mut self, value: T) -> Self { + self.params.push(value.into()); + self + } + + /// Execute and return all rows. + pub async fn fetch_all(self, pool: &WasmSqlitePool) -> Result, WasmDbError> { + pool.execute_query(&self.sql, &self.params).await + } + + /// Execute and return the first row, or None. + pub async fn fetch_optional( + self, + pool: &WasmSqlitePool, + ) -> Result, WasmDbError> { + let rows = pool.execute_query(&self.sql, &self.params).await?; + Ok(rows.into_iter().next()) + } + + /// Execute and return the first row, or error if not found. + pub async fn fetch_one(self, pool: &WasmSqlitePool) -> Result { + self.fetch_optional(pool) + .await? + .ok_or(WasmDbError::RowNotFound) + } + + /// Execute a statement (INSERT/UPDATE/DELETE) and return affected row count. + pub async fn execute(self, pool: &WasmSqlitePool) -> Result { + pool.execute_statement(&self.sql, &self.params).await + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Query result (for INSERT/UPDATE/DELETE) +// ───────────────────────────────────────────────────────────────────────────── + +/// Result of an execute() call. +#[derive(Debug, Default)] +pub struct WasmQueryResult { + pub rows_affected: u64, + pub last_insert_rowid: i64, +} + +impl WasmQueryResult { + pub fn rows_affected(&self) -> u64 { + self.rows_affected + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Row and value types +// ───────────────────────────────────────────────────────────────────────────── + +/// A database value from a SQLite column. +#[derive(Debug, Clone, serde::Deserialize)] +#[serde(untagged)] +pub enum WasmValue { + Null, + Integer(i64), + Real(f64), + Text(String), +} + +/// A single row returned from a query. +#[derive(Debug, Clone)] +pub struct WasmRow { + pub(crate) columns: Vec<(String, WasmValue)>, +} + +impl WasmRow { + /// Get a typed value by column name. + /// + /// Mirrors `sqlx::Row::get()`. 
Panics if column not found or type mismatch + /// (matching sqlx behavior). + pub fn get(&self, col: &str) -> T { + self.try_get(col) + .unwrap_or_else(|e| panic!("Row::get({col}): {e}")) + } + + /// Try to get a typed value by column name. + pub fn try_get(&self, col: &str) -> Result { + let value = self + .columns + .iter() + .find(|(name, _)| name == col) + .map(|(_, v)| v) + .ok_or_else(|| WasmDbError::ColumnNotFound(col.to_string()))?; + T::decode(value, col) + } +} + +/// Trait for decoding a `WasmValue` into a Rust type. +pub trait WasmDecode: Sized { + fn decode(value: &WasmValue, col: &str) -> Result; +} + +impl WasmDecode for String { + fn decode(value: &WasmValue, col: &str) -> Result { + match value { + WasmValue::Text(s) => Ok(s.clone()), + WasmValue::Integer(i) => Ok(i.to_string()), + WasmValue::Real(f) => Ok(f.to_string()), + WasmValue::Null => Err(WasmDbError::TypeMismatch { + column: col.to_string(), + expected: "String", + actual: "NULL".to_string(), + }), + } + } +} + +impl WasmDecode for Option { + fn decode(value: &WasmValue, _col: &str) -> Result { + match value { + WasmValue::Text(s) => Ok(Some(s.clone())), + WasmValue::Null => Ok(None), + WasmValue::Integer(i) => Ok(Some(i.to_string())), + WasmValue::Real(f) => Ok(Some(f.to_string())), + } + } +} + +impl WasmDecode for i64 { + fn decode(value: &WasmValue, col: &str) -> Result { + match value { + WasmValue::Integer(i) => Ok(*i), + WasmValue::Real(f) => Ok(*f as i64), + _ => Err(WasmDbError::TypeMismatch { + column: col.to_string(), + expected: "i64", + actual: format!("{value:?}"), + }), + } + } +} + +impl WasmDecode for Option { + fn decode(value: &WasmValue, _col: &str) -> Result { + match value { + WasmValue::Integer(i) => Ok(Some(*i)), + WasmValue::Real(f) => Ok(Some(*f as i64)), + WasmValue::Null => Ok(None), + _ => Ok(None), + } + } +} + +impl WasmDecode for i32 { + fn decode(value: &WasmValue, col: &str) -> Result { + i64::decode(value, col).map(|v| v as i32) + } +} + +impl WasmDecode 
for Option { + fn decode(value: &WasmValue, col: &str) -> Result { + match value { + WasmValue::Null => Ok(None), + _ => i32::decode(value, col).map(Some), + } + } +} + +impl WasmDecode for bool { + fn decode(value: &WasmValue, col: &str) -> Result { + match value { + WasmValue::Integer(i) => Ok(*i != 0), + WasmValue::Null => Ok(false), + _ => Err(WasmDbError::TypeMismatch { + column: col.to_string(), + expected: "bool", + actual: format!("{value:?}"), + }), + } + } +} + +impl WasmDecode for Option { + fn decode(value: &WasmValue, _col: &str) -> Result { + match value { + WasmValue::Integer(i) => Ok(Some(*i != 0)), + WasmValue::Null => Ok(None), + _ => Ok(None), + } + } +} + +impl WasmDecode for f64 { + fn decode(value: &WasmValue, col: &str) -> Result { + match value { + WasmValue::Real(f) => Ok(*f), + WasmValue::Integer(i) => Ok(*i as f64), + _ => Err(WasmDbError::TypeMismatch { + column: col.to_string(), + expected: "f64", + actual: format!("{value:?}"), + }), + } + } +} + +impl WasmDecode for NaiveDate { + fn decode(value: &WasmValue, col: &str) -> Result { + match value { + WasmValue::Text(s) => { + NaiveDate::parse_from_str(s, "%Y-%m-%d").map_err(|_| WasmDbError::TypeMismatch { + column: col.to_string(), + expected: "NaiveDate", + actual: s.clone(), + }) + } + _ => Err(WasmDbError::TypeMismatch { + column: col.to_string(), + expected: "NaiveDate", + actual: format!("{value:?}"), + }), + } + } +} + +impl WasmDecode for DateTime { + fn decode(value: &WasmValue, col: &str) -> Result { + match value { + WasmValue::Text(s) => { + // Try RFC 3339 first, then common SQLite formats + DateTime::parse_from_rfc3339(s) + .map(|dt| dt.with_timezone(&Utc)) + .or_else(|_| { + NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") + .or_else(|_| NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f")) + .map(|ndt| ndt.and_utc()) + }) + .map_err(|_| WasmDbError::TypeMismatch { + column: col.to_string(), + expected: "DateTime", + actual: s.clone(), + }) + } + _ => 
Err(WasmDbError::TypeMismatch { + column: col.to_string(), + expected: "DateTime", + actual: format!("{value:?}"), + }), + } + } +} + +impl WasmDecode for Option> { + fn decode(value: &WasmValue, col: &str) -> Result { + match value { + WasmValue::Null => Ok(None), + _ => DateTime::::decode(value, col).map(Some), + } + } +} + +impl WasmDecode for Uuid { + fn decode(value: &WasmValue, col: &str) -> Result { + match value { + WasmValue::Text(s) => Uuid::parse_str(s).map_err(|_| WasmDbError::TypeMismatch { + column: col.to_string(), + expected: "UUID", + actual: s.clone(), + }), + _ => Err(WasmDbError::TypeMismatch { + column: col.to_string(), + expected: "UUID", + actual: format!("{value:?}"), + }), + } + } +} + +impl WasmDecode for Option { + fn decode(value: &WasmValue, col: &str) -> Result { + match value { + WasmValue::Null => Ok(None), + _ => Uuid::decode(value, col).map(Some), + } + } +} + +impl WasmDecode for Decimal { + fn decode(value: &WasmValue, col: &str) -> Result { + match value { + WasmValue::Text(s) => s.parse::().map_err(|_| WasmDbError::TypeMismatch { + column: col.to_string(), + expected: "Decimal", + actual: s.clone(), + }), + WasmValue::Integer(i) => Ok(Decimal::from(*i)), + WasmValue::Real(f) => Decimal::try_from(*f).map_err(|_| WasmDbError::TypeMismatch { + column: col.to_string(), + expected: "Decimal", + actual: f.to_string(), + }), + WasmValue::Null => Err(WasmDbError::TypeMismatch { + column: col.to_string(), + expected: "Decimal", + actual: "NULL".to_string(), + }), + } + } +} + +impl WasmDecode for Option { + fn decode(value: &WasmValue, col: &str) -> Result { + match value { + WasmValue::Null => Ok(None), + _ => Decimal::decode(value, col).map(Some), + } + } +} + +impl WasmDecode for Vec { + fn decode(value: &WasmValue, col: &str) -> Result { + use base64::Engine; + match value { + WasmValue::Text(s) => { + base64::engine::general_purpose::STANDARD + .decode(s) + .map_err(|_| WasmDbError::TypeMismatch { + column: col.to_string(), + 
expected: "Vec (base64)", + actual: s.clone(), + }) + } + _ => Err(WasmDbError::TypeMismatch { + column: col.to_string(), + expected: "Vec (base64)", + actual: format!("{value:?}"), + }), + } + } +} + +impl WasmDecode for Option> { + fn decode(value: &WasmValue, col: &str) -> Result { + match value { + WasmValue::Null => Ok(None), + _ => Vec::::decode(value, col).map(Some), + } + } +} + +impl WasmDecode for serde_json::Value { + fn decode(value: &WasmValue, col: &str) -> Result { + match value { + WasmValue::Text(s) => serde_json::from_str(s).map_err(|_| WasmDbError::TypeMismatch { + column: col.to_string(), + expected: "JSON", + actual: s.clone(), + }), + WasmValue::Null => Ok(serde_json::Value::Null), + WasmValue::Integer(i) => Ok(serde_json::json!(*i)), + WasmValue::Real(f) => Ok(serde_json::json!(*f)), + } + } +} + +impl WasmDecode for Option { + fn decode(value: &WasmValue, col: &str) -> Result { + match value { + WasmValue::Null => Ok(None), + _ => serde_json::Value::decode(value, col).map(Some), + } + } +} diff --git a/src/dlq/database.rs b/src/dlq/database.rs index a2334e9..f0e9899 100644 --- a/src/dlq/database.rs +++ b/src/dlq/database.rs @@ -34,7 +34,8 @@ impl DatabaseDlq { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl DeadLetterQueue for DatabaseDlq { async fn push(&self, entry: DlqEntry) -> DlqResult<()> { let metadata = serde_json::to_string(&entry.metadata) diff --git a/src/dlq/file.rs b/src/dlq/file.rs index ca9d4d1..e8c3906 100644 --- a/src/dlq/file.rs +++ b/src/dlq/file.rs @@ -139,7 +139,8 @@ impl FileDlq { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl DeadLetterQueue for FileDlq { async fn push(&self, entry: DlqEntry) -> DlqResult<()> { // Write to disk first diff --git a/src/dlq/mod.rs b/src/dlq/mod.rs index d4e4337..5bbd80f 100644 --- a/src/dlq/mod.rs +++ 
b/src/dlq/mod.rs @@ -1,6 +1,7 @@ #[cfg(any(feature = "database-sqlite", feature = "database-postgres"))] mod database; mod error; +#[cfg(feature = "server")] mod file; #[cfg(feature = "redis")] mod redis; @@ -12,6 +13,7 @@ use std::sync::Arc; #[cfg(any(feature = "database-sqlite", feature = "database-postgres"))] pub use database::DatabaseDlq; pub use error::{DlqError, DlqResult}; +#[cfg(feature = "server")] pub use file::FileDlq; #[cfg(feature = "redis")] pub use redis::RedisDlq; @@ -34,12 +36,19 @@ pub async fn create_dlq( }; let dlq: Arc = match config { + #[cfg(feature = "server")] DeadLetterQueueConfig::File { path, max_file_size_mb, max_files, .. } => Arc::new(FileDlq::new(path, *max_file_size_mb, *max_files).await?), + #[cfg(not(feature = "server"))] + DeadLetterQueueConfig::File { .. } => { + return Err(DlqError::Internal( + "File DLQ configured but the 'server' feature is not enabled.".to_string(), + )); + } #[cfg(feature = "redis")] DeadLetterQueueConfig::Redis { diff --git a/src/dlq/redis.rs b/src/dlq/redis.rs index 73c5709..58bd8ed 100644 --- a/src/dlq/redis.rs +++ b/src/dlq/redis.rs @@ -163,7 +163,8 @@ impl RedisDlq { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl DeadLetterQueue for RedisDlq { async fn push(&self, entry: DlqEntry) -> DlqResult<()> { let mut conn = self.conn().await?; diff --git a/src/dlq/traits.rs b/src/dlq/traits.rs index 347209d..2e8cde6 100644 --- a/src/dlq/traits.rs +++ b/src/dlq/traits.rs @@ -210,7 +210,8 @@ impl DlqListParams { /// Dead-letter queue trait for storing failed operations. /// /// Implementations must be thread-safe and support concurrent access. -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait DeadLetterQueue: Send + Sync { /// Push an entry to the dead-letter queue. 
async fn push(&self, entry: DlqEntry) -> DlqResult<()>; diff --git a/src/guardrails/azure.rs b/src/guardrails/azure.rs index 6ce5afa..a8511a7 100644 --- a/src/guardrails/azure.rs +++ b/src/guardrails/azure.rs @@ -193,7 +193,8 @@ impl AzureContentSafetyProvider { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl GuardrailsProvider for AzureContentSafetyProvider { fn name(&self) -> &str { "azure_content_safety" diff --git a/src/guardrails/bedrock.rs b/src/guardrails/bedrock.rs index f497206..98d4a09 100644 --- a/src/guardrails/bedrock.rs +++ b/src/guardrails/bedrock.rs @@ -286,7 +286,8 @@ impl BedrockGuardrailsProvider { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl GuardrailsProvider for BedrockGuardrailsProvider { fn name(&self) -> &str { "bedrock" diff --git a/src/guardrails/blocklist.rs b/src/guardrails/blocklist.rs index cc0742b..40a81cf 100644 --- a/src/guardrails/blocklist.rs +++ b/src/guardrails/blocklist.rs @@ -131,7 +131,8 @@ impl BlocklistProvider { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl GuardrailsProvider for BlocklistProvider { fn name(&self) -> &str { "blocklist" diff --git a/src/guardrails/content_limits.rs b/src/guardrails/content_limits.rs index e2eefdb..7df3f64 100644 --- a/src/guardrails/content_limits.rs +++ b/src/guardrails/content_limits.rs @@ -63,7 +63,8 @@ impl ContentLimitsProvider { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl GuardrailsProvider for ContentLimitsProvider { fn name(&self) -> &str { "content_limits" diff --git a/src/guardrails/custom.rs b/src/guardrails/custom.rs index 05426ff..46d4d45 100644 --- a/src/guardrails/custom.rs +++ b/src/guardrails/custom.rs @@ -137,7 
+137,8 @@ impl CustomHttpProvider { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl GuardrailsProvider for CustomHttpProvider { fn name(&self) -> &str { "custom" diff --git a/src/guardrails/error.rs b/src/guardrails/error.rs index 993cb9b..a5ba4a0 100644 --- a/src/guardrails/error.rs +++ b/src/guardrails/error.rs @@ -149,7 +149,11 @@ impl GuardrailsError { /// Creates a provider error from a reqwest error. pub fn from_reqwest(provider: impl Into, err: reqwest::Error) -> Self { - let retryable = err.is_timeout() || err.is_connect(); + let mut retryable = err.is_timeout(); + #[cfg(not(target_arch = "wasm32"))] + { + retryable = retryable || err.is_connect(); + } Self::ProviderError { message: err.to_string(), provider: provider.into(), diff --git a/src/guardrails/mod.rs b/src/guardrails/mod.rs index fa061c7..e6992f0 100644 --- a/src/guardrails/mod.rs +++ b/src/guardrails/mod.rs @@ -176,7 +176,8 @@ pub fn inject_trace_context(headers: &mut HashMap) { /// } /// } /// ``` -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait GuardrailsProvider: Send + Sync { /// Returns the name of this provider (e.g., "openai", "bedrock", "azure"). 
fn name(&self) -> &str; diff --git a/src/guardrails/openai.rs b/src/guardrails/openai.rs index 0e6a053..f192a92 100644 --- a/src/guardrails/openai.rs +++ b/src/guardrails/openai.rs @@ -108,7 +108,8 @@ impl OpenAIModerationProvider { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl GuardrailsProvider for OpenAIModerationProvider { fn name(&self) -> &str { "openai_moderation" diff --git a/src/guardrails/pii_regex.rs b/src/guardrails/pii_regex.rs index b0649fa..c281d35 100644 --- a/src/guardrails/pii_regex.rs +++ b/src/guardrails/pii_regex.rs @@ -256,7 +256,8 @@ impl PiiRegexProvider { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl GuardrailsProvider for PiiRegexProvider { fn name(&self) -> &str { "pii_regex" diff --git a/src/guardrails/streaming.rs b/src/guardrails/streaming.rs index 9053f09..04914d2 100644 --- a/src/guardrails/streaming.rs +++ b/src/guardrails/streaming.rs @@ -253,7 +253,7 @@ where let user_id = self.config.user_id.clone(); let on_error = self.config.on_error.clone(); - tokio::spawn(async move { + crate::compat::spawn_detached(async move { evaluate_buffered_content( state, provider, @@ -299,7 +299,7 @@ where let user_id = self.config.user_id.clone(); let on_error = self.config.on_error.clone(); - tokio::spawn(async move { + crate::compat::spawn_detached(async move { evaluate_chunk_content( state, provider, @@ -336,7 +336,7 @@ where let on_error = self.config.on_error.clone(); let start_time = self.start_time; - tokio::spawn(async move { + crate::compat::spawn_detached(async move { evaluate_final_content( state, provider, diff --git a/src/jobs/provider_health_check.rs b/src/jobs/provider_health_check.rs index ba116e0..25f0cbb 100644 --- a/src/jobs/provider_health_check.rs +++ b/src/jobs/provider_health_check.rs @@ -376,6 +376,7 @@ impl ProviderHealthChecker { /// until all 
tasks complete (which is never under normal operation). /// /// If no providers are registered, this returns immediately. + #[cfg(not(target_arch = "wasm32"))] pub async fn start(self) { if self.providers.is_empty() { tracing::info!("No providers registered for health checks"); diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..1162691 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,49 @@ +mod api_types; +pub mod app; +pub mod auth; +pub mod authz; +pub mod cache; +pub mod catalog; +#[cfg(feature = "cli")] +pub mod cli; +pub mod compat; +pub mod config; +pub mod db; +pub mod dlq; +pub mod events; +pub mod guardrails; +pub mod init; +pub mod jobs; +pub mod middleware; +pub mod models; +pub mod observability; +pub mod ontology; +pub mod openapi; +pub mod pricing; +pub mod providers; +pub mod retention; +pub mod routes; +pub mod routing; +#[cfg(feature = "sso")] +pub mod scim; +pub mod secrets; +pub mod services; +pub mod streaming; +pub mod usage_buffer; +pub mod usage_sink; +pub mod validation; +#[cfg(feature = "wizard")] +pub mod wizard; + +#[cfg(feature = "wasm")] +pub mod wasm; + +#[cfg(test)] +mod tests; + +// Re-export items that other modules reference via `crate::`. 
+pub use app::AppState; +#[cfg(feature = "server")] +pub use app::build_app; +#[cfg(feature = "wizard")] +pub(crate) use cli::{default_config_path, default_data_dir}; diff --git a/src/main.rs b/src/main.rs index 2004649..93cc2d9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,49 +1,17 @@ -mod api_types; -mod app; -mod auth; -pub mod authz; -mod cache; -mod catalog; -mod cli; -mod config; -mod db; -mod dlq; -pub mod events; -mod guardrails; -mod init; -mod jobs; -mod middleware; -mod models; -pub mod observability; -mod ontology; -pub mod openapi; -mod pricing; -mod providers; -mod retention; -mod routes; -mod routing; -#[cfg(feature = "sso")] -pub mod scim; -mod secrets; -pub mod services; -mod streaming; -mod usage_buffer; -mod usage_sink; -mod validation; -#[cfg(feature = "wizard")] -mod wizard; - -#[cfg(test)] -mod tests; - -// Re-export items that other modules reference via `crate::`. -pub use app::{AppState, build_app}; +#[cfg(feature = "cli")] use clap::Parser; -#[cfg(feature = "wizard")] -pub(crate) use cli::{default_config_path, default_data_dir}; +#[cfg(feature = "cli")] #[tokio::main] async fn main() { - let args = cli::Args::parse(); - cli::dispatch(args).await; + let args = hadrian::cli::Args::parse(); + hadrian::cli::dispatch(args).await; +} + +#[cfg(not(feature = "cli"))] +fn main() { + eprintln!( + "The CLI feature is not enabled. Build with --features cli (or server/tiny/minimal/standard/full)." 
+ ); + std::process::exit(1); } diff --git a/src/middleware/layers/admin.rs b/src/middleware/layers/admin.rs index 6c95f66..bfe89d7 100644 --- a/src/middleware/layers/admin.rs +++ b/src/middleware/layers/admin.rs @@ -16,8 +16,10 @@ use std::net::IpAddr; +#[cfg(feature = "server")] +use axum::extract::ConnectInfo; use axum::{ - extract::{ConnectInfo, Request, State}, + extract::{Request, State}, middleware::Next, response::Response, }; @@ -27,18 +29,11 @@ use uuid::Uuid; use crate::{ AppState, auth::{AuthError, AuthenticatedRequest, Identity, IdentityKind}, - middleware::{ClientInfo, RequestId}, + middleware::{AdminAuth, ClientInfo, RequestId}, observability::metrics, services::audit_logs::{AuthEventParams, auth_events}, }; -/// Admin authentication result. -#[derive(Debug, Clone)] -pub struct AdminAuth { - /// The authenticated identity - pub identity: Identity, -} - /// Middleware that requires admin authentication. /// This will reject requests without valid Proxy auth headers or OIDC session. 
pub async fn admin_auth_middleware( @@ -59,10 +54,13 @@ pub async fn admin_auth_middleware( let cookies = req.extensions().get::().cloned(); // Extract connecting IP for trusted proxy validation + #[cfg(feature = "server")] let connecting_ip = req .extensions() .get::>() .map(|ci| ci.0.ip()); + #[cfg(not(feature = "server"))] + let connecting_ip: Option = None; let client_info = ClientInfo { ip_address: connecting_ip.map(|ip| ip.to_string()), @@ -2328,7 +2326,7 @@ mod tests { /// Create a minimal AppState for testing with ProxyAuth config fn create_test_state(identity_header: &str, trusted_proxies: TrustedProxiesConfig) -> AppState { // Create minimal config from empty TOML - let mut config = GatewayConfig::from_str("").unwrap(); + let mut config = GatewayConfig::parse("").unwrap(); config.auth.mode = AuthMode::Iap(Box::new(IapConfig { identity_header: identity_header.to_string(), email_header: Some("X-Email".to_string()), @@ -2637,7 +2635,7 @@ mod tests { /// Create a minimal AppState for testing with Emergency config fn create_emergency_test_state(emergency_config: Option) -> AppState { - let mut config = GatewayConfig::from_str("").unwrap(); + let mut config = GatewayConfig::parse("").unwrap(); config.auth.emergency = emergency_config; AppState { diff --git a/src/middleware/layers/api.rs b/src/middleware/layers/api.rs index e722c92..a802c36 100644 --- a/src/middleware/layers/api.rs +++ b/src/middleware/layers/api.rs @@ -1,28 +1,30 @@ use std::{net::IpAddr, sync::Arc, time::Duration}; +#[cfg(feature = "server")] +use axum::extract::ConnectInfo; use axum::{ - extract::{ConnectInfo, Request, State}, + extract::{Request, State}, middleware::Next, response::{IntoResponse, Response}, }; use chrono::Utc; -use super::{ - rate_limit::{ - RateLimitError, TokenRateLimitCheckResult, TokenRateLimitResult, TokenReservation, - add_rate_limit_headers, add_token_rate_limit_headers, adjust_token_reservation, - }, - request_id::RequestId, +use super::rate_limit::{ + 
RateLimitError, TokenRateLimitCheckResult, TokenRateLimitResult, TokenReservation, + add_rate_limit_headers, add_token_rate_limit_headers, adjust_token_reservation, }; use crate::{ AppState, auth::{ApiKeyAuth, AuthError, AuthenticatedRequest, Identity, IdentityKind}, cache::{BudgetCheckParams, Cache, CacheKeys, RateLimitCheckParams, RateLimitResult}, events::{BudgetType, ServerEvent}, - middleware::util::{ - budget::{BudgetCheckResult, BudgetError, adjust_budget_reservation}, - scope::required_scope_for_path, - usage::{UsageTracker, extract_full_usage_from_response, tracker_from_headers}, + middleware::{ + RequestId, + util::{ + budget::{BudgetCheckResult, BudgetError, adjust_budget_reservation}, + scope::required_scope_for_path, + usage::{UsageTracker, extract_full_usage_from_response, tracker_from_headers}, + }, }, models::{AuditActorType, BudgetPeriod, CreateAuditLog, has_valid_prefix, hash_api_key}, observability::metrics, @@ -580,10 +582,13 @@ pub async fn api_middleware( req.extensions_mut().insert(tracker.clone()); // Extract connecting IP for trusted proxy validation + #[cfg(feature = "server")] let connecting_ip = req .extensions() .get::>() .map(|ci| ci.0.ip()); + #[cfg(not(feature = "server"))] + let connecting_ip: Option = None; // Insert client info for audit logging let client_info = crate::middleware::ClientInfo { @@ -877,94 +882,97 @@ pub async fn api_middleware( token_reservation, header_project_id, }); - } else if let Some(buffer) = &state.usage_buffer { - // Track anonymous usage when auth is disabled (local dev / no-auth mode). - // Attribute to the default anonymous user/org created on startup. 
- let has_model = response.headers().contains_key("X-Model"); - let is_streaming = response - .headers() - .get(http::header::CONTENT_TYPE) - .and_then(|v| v.to_str().ok()) - .is_some_and(|s| s.contains("text/event-stream")) - || response - .headers() - .get("Transfer-Encoding") - .and_then(|v| v.to_str().ok()) - .is_some_and(|s| s.contains("chunked")); - - // Only track LLM requests (those with X-Model header). - // Skip streaming responses here — UsageTrackingStream handles them - // with actual token counts after the stream completes. - if has_model && !is_streaming { - let usage = extract_full_usage_from_response(&response); - - let model = response - .headers() - .get("X-Model") - .and_then(|v| v.to_str().ok()) - .map(String::from) - .or(tracker.model) - .unwrap_or_else(|| "unknown".to_string()); - let provider = response + } else { + #[cfg(feature = "concurrency")] + if let Some(buffer) = &state.usage_buffer { + // Track anonymous usage when auth is disabled (local dev / no-auth mode). + // Attribute to the default anonymous user/org created on startup. 
+ let has_model = response.headers().contains_key("X-Model"); + let is_streaming = response .headers() - .get("X-Provider") + .get(http::header::CONTENT_TYPE) .and_then(|v| v.to_str().ok()) - .map(String::from) - .or(tracker.provider) - .unwrap_or_else(|| "unknown".to_string()); - - let elapsed = tracker.start_time.elapsed(); - let latency_ms = elapsed.as_millis().min(i32::MAX as u128) as i32; - - let status = if response.status().is_success() { - "success" - } else { - "error" - }; - metrics::record_llm_request(metrics::LlmRequestMetrics { - provider: &provider, - model: &model, - status, - status_code: Some(response.status().as_u16()), - duration_secs: elapsed.as_secs_f64(), - input_tokens: usage.input_tokens, - output_tokens: usage.output_tokens, - cost_microcents: usage.cost_microcents, - }); - - let header_project_id = headers - .get("X-Hadrian-Project") - .and_then(|v| v.to_str().ok()) - .and_then(|v| uuid::Uuid::parse_str(v).ok()); - - buffer.push(crate::models::UsageLogEntry { - request_id: request_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), - api_key_id: None, - user_id: state.default_user_id, - org_id: state.default_org_id, - project_id: header_project_id, - team_id: None, - service_account_id: None, - model, - provider, - input_tokens: saturate_i64_to_i32(usage.input_tokens.unwrap_or(0)), - output_tokens: saturate_i64_to_i32(usage.output_tokens.unwrap_or(0)), - cost_microcents: usage.cost_microcents, - http_referer: tracker.referer.clone(), - request_at: chrono::Utc::now(), - streamed: tracker.streamed, - cached_tokens: 0, - reasoning_tokens: 0, - finish_reason: None, - latency_ms: Some(latency_ms), - cancelled: false, - status_code: Some(response.status().as_u16() as i16), - pricing_source: usage.pricing_source, - image_count: usage.image_count, - audio_seconds: usage.audio_seconds, - character_count: usage.character_count, - provider_source: tracker.provider_source.clone(), - }); + .is_some_and(|s| s.contains("text/event-stream")) + || 
response + .headers() + .get("Transfer-Encoding") + .and_then(|v| v.to_str().ok()) + .is_some_and(|s| s.contains("chunked")); + + // Only track LLM requests (those with X-Model header). + // Skip streaming responses here — UsageTrackingStream handles them + // with actual token counts after the stream completes. + if has_model && !is_streaming { + let usage = extract_full_usage_from_response(&response); + + let model = response + .headers() + .get("X-Model") + .and_then(|v| v.to_str().ok()) + .map(String::from) + .or(tracker.model) + .unwrap_or_else(|| "unknown".to_string()); + let provider = response + .headers() + .get("X-Provider") + .and_then(|v| v.to_str().ok()) + .map(String::from) + .or(tracker.provider) + .unwrap_or_else(|| "unknown".to_string()); + + let elapsed = tracker.start_time.elapsed(); + let latency_ms = elapsed.as_millis().min(i32::MAX as u128) as i32; + + let status = if response.status().is_success() { + "success" + } else { + "error" + }; + metrics::record_llm_request(metrics::LlmRequestMetrics { + provider: &provider, + model: &model, + status, + status_code: Some(response.status().as_u16()), + duration_secs: elapsed.as_secs_f64(), + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + cost_microcents: usage.cost_microcents, + }); + + let header_project_id = headers + .get("X-Hadrian-Project") + .and_then(|v| v.to_str().ok()) + .and_then(|v| uuid::Uuid::parse_str(v).ok()); + + buffer.push(crate::models::UsageLogEntry { + request_id: request_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), + api_key_id: None, + user_id: state.default_user_id, + org_id: state.default_org_id, + project_id: header_project_id, + team_id: None, + service_account_id: None, + model, + provider, + input_tokens: saturate_i64_to_i32(usage.input_tokens.unwrap_or(0)), + output_tokens: saturate_i64_to_i32(usage.output_tokens.unwrap_or(0)), + cost_microcents: usage.cost_microcents, + http_referer: tracker.referer.clone(), + request_at: 
chrono::Utc::now(), + streamed: tracker.streamed, + cached_tokens: 0, + reasoning_tokens: 0, + finish_reason: None, + latency_ms: Some(latency_ms), + cancelled: false, + status_code: Some(response.status().as_u16() as i16), + pricing_source: usage.pricing_source, + image_count: usage.image_count, + audio_seconds: usage.audio_seconds, + character_count: usage.character_count, + provider_source: tracker.provider_source.clone(), + }); + } } } @@ -1125,6 +1133,7 @@ fn track_usage_async(ctx: UsageTrackingContext<'_>) { // Push to usage buffer for batched writes (if available). // Skip for streaming responses (UsageTrackingStream writes correct values) // and non-LLM requests (no X-Model header means this isn't an LLM call). + #[cfg(feature = "concurrency")] if has_model && !is_streaming { if let Some(buffer) = &state.usage_buffer { tracing::debug!( @@ -1148,6 +1157,7 @@ fn track_usage_async(ctx: UsageTrackingContext<'_>) { if api_key.is_some() { if let Some(cache) = state.cache { // Use task_tracker to ensure this task completes during graceful shutdown + #[cfg(feature = "server")] state.task_tracker.spawn(async move { // Adjust budget reservation with actual cost (for successful responses) // This replaces the estimated cost that was reserved before the request @@ -1997,6 +2007,7 @@ fn log_budget_exceeded(event: BudgetExceededEvent<'_>) { // Fire-and-forget: spawn a task to log the audit event // This ensures we don't block the response on audit logging + #[cfg(feature = "server")] state.task_tracker.spawn(async move { let result = db .audit_logs() @@ -2083,6 +2094,7 @@ fn log_budget_warning(event: BudgetWarningEvent<'_>) { let req_id = request_id.map(String::from); // Fire-and-forget: spawn a task to log the audit event + #[cfg(feature = "server")] state.task_tracker.spawn(async move { // Check if we've already logged a warning for this API key in this budget period // Cache key format: budget_warning_logged:{api_key_id}:{period} @@ -2208,7 +2220,7 @@ mod tests { /// 
Create AppState with Idp configuration fn create_multi_auth_state(header_name: &str, key_prefix: &str) -> AppState { - let mut config = GatewayConfig::from_str("").unwrap(); + let mut config = GatewayConfig::parse("").unwrap(); #[cfg(feature = "sso")] { config.auth.mode = AuthMode::Idp; @@ -2266,7 +2278,7 @@ mod tests { /// Create AppState with API key only authentication fn create_api_key_only_state(header_name: &str, key_prefix: &str) -> AppState { - let mut config = GatewayConfig::from_str("").unwrap(); + let mut config = GatewayConfig::parse("").unwrap(); config.auth.mode = AuthMode::ApiKey; config.auth.api_key = Some(ApiKeyAuthConfig { header_name: header_name.to_string(), diff --git a/src/middleware/layers/authz.rs b/src/middleware/layers/authz.rs index e5cbb48..2b44b30 100644 --- a/src/middleware/layers/authz.rs +++ b/src/middleware/layers/authz.rs @@ -8,473 +8,16 @@ use axum::{ middleware::Next, response::{IntoResponse, Response}, }; -use serde_json::json; -use tokio_util::task::TaskTracker; -use uuid::Uuid; use crate::{ AppState, auth::AuthenticatedRequest, - authz::{ - AuthzEngine, AuthzError, AuthzResult, PolicyContext, PolicyRegistry, RequestContext, - Subject, - }, - config::{AuthzAuditConfig, PolicyEffect}, - middleware::AdminAuth, - models::{AuditActorType, CreateAuditLog}, - services::AuditLogService, + authz::{AuthzEngine, Subject}, + middleware::{AdminAuth, AuthzContext}, }; -/// Authorization context extracted from request. -#[derive(Clone)] -pub struct AuthzContext { - pub subject: Subject, - pub engine: Arc, - /// Per-organization policy registry for org-scoped authorization. - /// When available, allows evaluation of org-specific RBAC policies. 
- pub registry: Option>, - /// Audit log service for logging authorization decisions (optional) - audit_service: Option, - /// Task tracker for async logging - task_tracker: Option, - /// Request metadata for audit logs - request_ip: Option, - request_user_agent: Option, - /// Audit logging configuration - audit_config: AuthzAuditConfig, - /// Default effect for API authorization when no policy matches. - /// This allows API endpoints to have a different default (e.g., "allow") - /// than admin endpoints (e.g., "deny"). - api_default_effect: PolicyEffect, -} - -impl AuthzContext { - /// Check if the subject is authorized for an action on a resource. - /// - /// Parameters: - /// - `resource`: The type of resource being accessed (e.g., "team", "project") - /// - `action`: The action being performed (e.g., "read", "create", "delete") - /// - `resource_id`: The specific resource ID being accessed - /// - `org_id`: The organization scope - /// - `team_id`: The team scope (if applicable) - /// - `project_id`: The project scope (if applicable) - pub fn authorize( - &self, - resource: &str, - action: &str, - resource_id: Option<&str>, - org_id: Option<&str>, - team_id: Option<&str>, - project_id: Option<&str>, - ) -> AuthzResult { - let mut context = PolicyContext::new(resource, action); - if let Some(id) = resource_id { - context = context.with_resource_id(id); - } - if let Some(id) = org_id { - context = context.with_org_id(id); - } - if let Some(id) = team_id { - context = context.with_team_id(id); - } - if let Some(id) = project_id { - context = context.with_project_id(id); - } - self.engine.authorize(&self.subject, &context) - } - - /// Check authorization and return an error if denied. - /// Logs authorization decisions based on audit configuration. - /// - /// This method evaluates **system policies only** (from config file). It does NOT - /// evaluate per-organization policies from the database. 
This is by design: - /// - /// - **Admin endpoints** use `require()` - controlled by platform operators via system policies - /// - **API endpoints** use `require_api()` - also evaluates org policies for customer-specific rules - /// - /// This separation ensures that: - /// 1. Admin operations are governed by platform-wide rules (simpler, more predictable) - /// 2. Org admins can customize API access (model usage, rate limits) without affecting admin operations - /// 3. Synchronous evaluation avoids async complexity in admin handlers - /// - /// Parameters: - /// - `resource`: The type of resource being accessed (e.g., "team", "project") - /// - `action`: The action being performed (e.g., "read", "create", "delete") - /// - `resource_id`: The specific resource ID being accessed - /// - `org_id`: The organization scope - /// - `team_id`: The team scope (if applicable) - /// - `project_id`: The project scope (if applicable) - pub fn require( - &self, - resource: &str, - action: &str, - resource_id: Option<&str>, - org_id: Option<&str>, - team_id: Option<&str>, - project_id: Option<&str>, - ) -> Result<(), AuthzError> { - let result = self.authorize(resource, action, resource_id, org_id, team_id, project_id); - if result.allowed { - // Log allowed decisions if configured - if self.audit_config.log_allowed { - self.log_authorization_decision( - resource, - action, - resource_id, - org_id, - team_id, - project_id, - &result, - ); - } - Ok(()) - } else { - // Log denied decisions if configured - if self.audit_config.log_denied { - self.log_authorization_decision( - resource, - action, - resource_id, - org_id, - team_id, - project_id, - &result, - ); - } - Err(AuthzError::AccessDenied( - result.reason.unwrap_or_else(|| "Access denied".to_string()), - )) - } - } - - /// Log an authorization decision asynchronously. - /// Logs to the audit log with full context for security monitoring. 
- #[allow(clippy::too_many_arguments)] - fn log_authorization_decision( - &self, - resource: &str, - action: &str, - resource_id: Option<&str>, - org_id: Option<&str>, - team_id: Option<&str>, - project_id: Option<&str>, - result: &AuthzResult, - ) { - // Only log if audit service is available - let (Some(audit_service), Some(task_tracker)) = - (self.audit_service.clone(), self.task_tracker.clone()) - else { - return; - }; - - // Build audit log entry - let actor_type = if self.subject.user_id.is_some() { - AuditActorType::User - } else { - AuditActorType::System - }; - - let actor_id = self - .subject - .user_id - .as_ref() - .and_then(|id| Uuid::parse_str(id).ok()); - - // Use provided resource_id or generate a nil UUID for the audit log - let audit_resource_id = resource_id - .and_then(|id| Uuid::parse_str(id).ok()) - .unwrap_or_else(Uuid::nil); - - let parsed_org_id = org_id.and_then(|id| Uuid::parse_str(id).ok()); - let parsed_project_id = project_id.and_then(|id| Uuid::parse_str(id).ok()); - - // Build details JSON with authorization context - let details = json!({ - "decision": if result.allowed { "allow" } else { "deny" }, - "policy_name": result.policy_name, - "reason": result.reason, - "resource": resource, - "action": action, - "org_id": org_id, - "team_id": team_id, - "project_id": project_id, - "resource_id": resource_id, - "subject": { - "user_id": self.subject.user_id, - "external_id": self.subject.external_id, - "email": self.subject.email, - "roles": self.subject.roles, - "team_ids": self.subject.team_ids, - } - }); - - let audit_action = format!("authz.{}", if result.allowed { "allow" } else { "deny" }); - let ip_address = self.request_ip.clone(); - let user_agent = self.request_user_agent.clone(); - let resource_type = resource.to_string(); - - // Spawn async task to write audit log (non-blocking) - task_tracker.spawn(async move { - let entry = CreateAuditLog { - actor_type, - actor_id, - action: audit_action, - resource_type, - resource_id: 
audit_resource_id, - org_id: parsed_org_id, - project_id: parsed_project_id, - details, - ip_address, - user_agent, - }; - - if let Err(e) = audit_service.create(entry).await { - tracing::warn!( - error = %e, - "Failed to log authorization decision to audit log" - ); - } - }); - } - - /// Check if the subject has a specific role. - #[allow(dead_code)] // Public API for CEL policy evaluation - pub fn has_role(&self, role: &str) -> bool { - self.subject.has_role(role) - } - - /// Check if the subject is a member of an organization. - #[allow(dead_code)] // Public API for CEL policy evaluation - pub fn is_org_member(&self, org_id: &str) -> bool { - self.subject.is_org_member(org_id) - } - - /// Check if the subject is a member of a team. - #[allow(dead_code)] // Public API for CEL policy evaluation - pub fn is_team_member(&self, team_id: &str) -> bool { - self.subject.is_team_member(team_id) - } - - /// Check if the subject is a member of a project. - #[allow(dead_code)] // Public API for CEL policy evaluation - pub fn is_project_member(&self, project_id: &str) -> bool { - self.subject.is_project_member(project_id) - } - - /// Authorize an API request with model and request-specific context. - /// - /// This is used for `/v1/*` API endpoints where authorization depends on - /// the specific model, request parameters, and time of day. - /// - /// Parameters: - /// - `resource`: The resource type (e.g., "model", "chat", "embeddings") - /// - `action`: The action being performed (e.g., "use", "complete") - /// - `model`: The model being requested (e.g., "gpt-4o", "claude-3-opus") - /// - `request`: Request-specific context (tokens, tools, etc.) 
- /// - `org_id`: Organization scope (from API key or identity) - /// - `project_id`: Project scope (from API key) - #[allow(dead_code)] // Public API for CEL policy evaluation on API endpoints - pub fn authorize_api( - &self, - resource: &str, - action: &str, - model: Option<&str>, - request: Option, - org_id: Option<&str>, - project_id: Option<&str>, - ) -> AuthzResult { - let mut context = PolicyContext::new(resource, action).with_current_time(); - - if let Some(m) = model { - context = context.with_model(m); - } - if let Some(req) = request { - context = context.with_request(req); - } - if let Some(id) = org_id { - context = context.with_org_id(id); - } - if let Some(id) = project_id { - context = context.with_project_id(id); - } - - self.engine.authorize(&self.subject, &context) - } - - /// Check API authorization and return an error if denied. - /// Logs authorization decisions based on audit configuration. - /// - /// This method evaluates both system policies (from config) and org policies - /// (from database) when a PolicyRegistry is available and org_id is provided. - /// - /// Parameters: - /// - `resource`: The resource type (e.g., "model", "chat", "embeddings") - /// - `action`: The action being performed (e.g., "use", "complete") - /// - `model`: The model being requested - /// - `request`: Request-specific context (tokens, tools, etc.) 
- /// - `org_id`: Organization scope - /// - `project_id`: Project scope - pub async fn require_api( - &self, - resource: &str, - action: &str, - model: Option<&str>, - request: Option, - org_id: Option<&str>, - project_id: Option<&str>, - ) -> Result<(), AuthzError> { - // Build the policy context - let mut context = PolicyContext::new(resource, action).with_current_time(); - - if let Some(m) = model { - context = context.with_model(m); - } - if let Some(req) = request.clone() { - context = context.with_request(req); - } - if let Some(id) = org_id { - context = context.with_org_id(id); - } - if let Some(id) = project_id { - context = context.with_project_id(id); - } - - // Evaluate using registry if available (includes org policies), otherwise engine only - // Use the API-specific default effect when no policy matches - let result = if let Some(ref registry) = self.registry { - let parsed_org_id = org_id.and_then(|id| Uuid::parse_str(id).ok()); - registry - .authorize_with_org_and_default( - parsed_org_id, - &self.subject, - &context, - self.api_default_effect, - ) - .await - } else { - // No registry available, use system policies only - self.engine.authorize(&self.subject, &context) - }; - - if result.allowed { - if self.audit_config.log_allowed { - self.log_api_authorization_decision( - resource, - action, - model, - request.as_ref(), - org_id, - project_id, - &result, - ); - } - Ok(()) - } else { - if self.audit_config.log_denied { - self.log_api_authorization_decision( - resource, - action, - model, - request.as_ref(), - org_id, - project_id, - &result, - ); - } - Err(AuthzError::AccessDenied( - result.reason.unwrap_or_else(|| "Access denied".to_string()), - )) - } - } - - /// Log an API authorization decision asynchronously. 
- #[allow(clippy::too_many_arguments)] - fn log_api_authorization_decision( - &self, - resource: &str, - action: &str, - model: Option<&str>, - request: Option<&RequestContext>, - org_id: Option<&str>, - project_id: Option<&str>, - result: &AuthzResult, - ) { - let (Some(audit_service), Some(task_tracker)) = - (self.audit_service.clone(), self.task_tracker.clone()) - else { - return; - }; - - let actor_type = if self.subject.user_id.is_some() { - AuditActorType::User - } else { - AuditActorType::System - }; - - let actor_id = self - .subject - .user_id - .as_ref() - .and_then(|id| Uuid::parse_str(id).ok()); - - let parsed_org_id = org_id.and_then(|id| Uuid::parse_str(id).ok()); - let parsed_project_id = project_id.and_then(|id| Uuid::parse_str(id).ok()); - - let details = json!({ - "decision": if result.allowed { "allow" } else { "deny" }, - "policy_name": result.policy_name, - "reason": result.reason, - "resource": resource, - "action": action, - "model": model, - "request": request.map(|r| json!({ - "max_tokens": r.max_tokens, - "messages_count": r.messages_count, - "has_tools": r.has_tools, - "has_file_search": r.has_file_search, - "stream": r.stream, - })), - "org_id": org_id, - "project_id": project_id, - "subject": { - "user_id": self.subject.user_id, - "external_id": self.subject.external_id, - "email": self.subject.email, - "roles": self.subject.roles, - } - }); - - let audit_action = format!( - "api_authz.{}", - if result.allowed { "allow" } else { "deny" } - ); - let ip_address = self.request_ip.clone(); - let user_agent = self.request_user_agent.clone(); - let resource_type = resource.to_string(); - - task_tracker.spawn(async move { - let entry = CreateAuditLog { - actor_type, - actor_id, - action: audit_action, - resource_type, - resource_id: Uuid::nil(), // API requests don't have a single resource ID - org_id: parsed_org_id, - project_id: parsed_project_id, - details, - ip_address, - user_agent, - }; - - if let Err(e) = 
audit_service.create(entry).await { - tracing::warn!( - error = %e, - "Failed to log API authorization decision to audit log" - ); - } - }); - } -} +// AuthzContext struct and impl are in crate::middleware::types +// (always available on all targets). Middleware functions below use it. /// Middleware that builds authorization context from the authenticated request. /// This must run after admin_auth_middleware. @@ -544,17 +87,17 @@ pub async fn authz_middleware( // Always add authz context to request (fail-closed pattern) // Admin endpoints use the main RBAC default effect (typically "deny") - req.extensions_mut().insert(AuthzContext { + req.extensions_mut().insert(AuthzContext::new( subject, engine, - registry: state.policy_registry.clone(), + state.policy_registry.clone(), audit_service, - task_tracker: Some(state.task_tracker.clone()), + Some(state.task_tracker.clone()), request_ip, request_user_agent, - audit_config: state.config.auth.rbac.audit.clone(), - api_default_effect: state.config.auth.rbac.default_effect, - }); + state.config.auth.rbac.audit.clone(), + state.config.auth.rbac.default_effect, + )); Ok(next.run(req).await) } @@ -622,17 +165,8 @@ pub async fn permissive_authz_middleware( // Insert permissive AuthzContext with empty subject // Since RBAC is disabled, all authorization checks will pass (no denials to log) - req.extensions_mut().insert(AuthzContext { - subject: Subject::new(), - engine, - registry: None, // No registry in permissive mode - audit_service: None, // No audit logging in permissive mode - task_tracker: None, - request_ip: None, - request_user_agent: None, - audit_config: AuthzAuditConfig::default(), // Use defaults (no logging in permissive mode anyway) - api_default_effect: PolicyEffect::Allow, // Permissive mode always allows - }); + req.extensions_mut() + .insert(AuthzContext::permissive(engine)); Ok(next.run(req).await) } @@ -712,17 +246,17 @@ pub async fn api_authz_middleware( // Always add authz context to request // API 
endpoints use the API-specific default effect (typically "allow" for backwards compatibility) - req.extensions_mut().insert(AuthzContext { + req.extensions_mut().insert(AuthzContext::new( subject, engine, - registry: state.policy_registry.clone(), + state.policy_registry.clone(), audit_service, - task_tracker: Some(state.task_tracker.clone()), + Some(state.task_tracker.clone()), request_ip, request_user_agent, - audit_config: state.config.auth.rbac.audit.clone(), - api_default_effect: api_rbac_config.default_effect, - }); + state.config.auth.rbac.audit.clone(), + api_rbac_config.default_effect, + )); Ok(next.run(req).await) } diff --git a/src/middleware/layers/rate_limit.rs b/src/middleware/layers/rate_limit.rs index 92b72fa..d3542d6 100644 --- a/src/middleware/layers/rate_limit.rs +++ b/src/middleware/layers/rate_limit.rs @@ -1,8 +1,10 @@ use std::{net::IpAddr, sync::Arc, time::Duration}; +#[cfg(feature = "server")] +use axum::extract::ConnectInfo; use axum::{ Json, - extract::{ConnectInfo, Request, State}, + extract::{Request, State}, http::{HeaderValue, StatusCode}, middleware::Next, response::{IntoResponse, Response}, @@ -166,10 +168,13 @@ pub async fn rate_limit_middleware( #[allow(clippy::collapsible_if)] pub fn extract_client_ip(req: &Request, trusted_proxies: &TrustedProxiesConfig) -> Option { // Get the direct connecting IP (from TCP connection) + #[cfg(feature = "server")] let connecting_ip = req .extensions() .get::>() .map(|ci| ci.0.ip()); + #[cfg(not(feature = "server"))] + let connecting_ip: Option = None; // If no proxy trust is configured, just return the connecting IP if !trusted_proxies.is_configured() { diff --git a/src/middleware/layers/request_id.rs b/src/middleware/layers/request_id.rs index af9d387..ff989da 100644 --- a/src/middleware/layers/request_id.rs +++ b/src/middleware/layers/request_id.rs @@ -11,44 +11,12 @@ use axum::{ response::{IntoResponse, Response}, }; use http_body_util::BodyExt; -use uuid::Uuid; + +use 
crate::middleware::RequestId; /// Header name for the request ID. pub const REQUEST_ID_HEADER: &str = "X-Request-Id"; -/// Extension containing the request ID for the current request. -#[derive(Debug, Clone)] -pub struct RequestId(pub String); - -impl RequestId { - /// Generate a new request ID. - pub fn new() -> Self { - Self(Uuid::new_v4().to_string()) - } - - /// Create from an existing ID. - pub fn from_string(id: String) -> Self { - Self(id) - } - - /// Get the ID as a string slice. - pub fn as_str(&self) -> &str { - &self.0 - } -} - -impl Default for RequestId { - fn default() -> Self { - Self::new() - } -} - -impl std::fmt::Display for RequestId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - /// Middleware that adds a request ID to each request. /// /// If the request already has an X-Request-Id header, it's used. diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 58bc299..fbeb3ac 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -21,29 +21,28 @@ //! ## Unprotected admin routes (login, session info) //! - [`permissive_authz_middleware`] — Injects allow-all authz context -// ── True middleware (Axum middleware layers) ──────────────────────────────────── +// ── Types extracted by middleware (used by route handlers via Extension) ──── +// Always available on all targets (including WASM). 
+mod types; +pub use types::{AdminAuth, AuthzContext, ClientInfo, RequestId}; + +// ── True middleware (Axum middleware layers) — server only ─────────────────── +#[cfg(feature = "server")] mod layers; // ── Internal utilities (budget, scope, usage helpers for combined middleware) ── +#[cfg(feature = "server")] pub(crate) mod util; -// ── Middleware layer exports ─────────────────────────────────────────────────── +// ── Middleware layer exports — server only ─────────────────────────────────── #[cfg(feature = "sso")] pub use layers::rate_limit::extract_client_ip_from_parts; +#[cfg(feature = "server")] pub use layers::{ - admin::{AdminAuth, admin_auth_middleware}, + admin::admin_auth_middleware, api::api_middleware, - authz::{AuthzContext, api_authz_middleware, authz_middleware, permissive_authz_middleware}, + authz::{AuthzResponse, api_authz_middleware, authz_middleware, permissive_authz_middleware}, rate_limit::rate_limit_middleware, - request_id::{RequestId, request_id_middleware}, + request_id::request_id_middleware, security_headers::security_headers_middleware, }; - -// ── Types extracted by middleware (used by route handlers via Extension) ──── - -/// Client connection metadata extracted by middleware for audit logging. -#[derive(Debug, Clone, Default)] -pub struct ClientInfo { - pub ip_address: Option, - pub user_agent: Option, -} diff --git a/src/middleware/types.rs b/src/middleware/types.rs new file mode 100644 index 0000000..ae86b39 --- /dev/null +++ b/src/middleware/types.rs @@ -0,0 +1,592 @@ +//! Middleware types extracted by middleware layers and consumed by route handlers. +//! +//! These types are always available on all targets (including WASM) so that +//! route handlers can compile without the server-only middleware *functions*. 
+ +use std::sync::Arc; + +#[cfg(feature = "server")] +use serde_json::json; +#[cfg(feature = "server")] +use tokio_util::task::TaskTracker; +use uuid::Uuid; + +#[cfg(feature = "server")] +use crate::models::{AuditActorType, CreateAuditLog}; +use crate::{ + auth::Identity, + authz::{ + AuthzEngine, AuthzError, AuthzResult, PolicyContext, PolicyRegistry, RequestContext, + Subject, + }, + config::{AuthzAuditConfig, PolicyEffect}, + services::AuditLogService, +}; + +/// Client connection metadata extracted by middleware for audit logging. +#[derive(Debug, Clone, Default)] +pub struct ClientInfo { + pub ip_address: Option, + pub user_agent: Option, +} + +/// Admin authentication result. +#[derive(Debug, Clone)] +pub struct AdminAuth { + /// The authenticated identity + pub identity: Identity, +} + +/// Extension containing the request ID for the current request. +#[derive(Debug, Clone)] +pub struct RequestId(pub String); + +impl RequestId { + /// Generate a new request ID. + pub fn new() -> Self { + Self(Uuid::new_v4().to_string()) + } + + /// Create from an existing ID. + pub fn from_string(id: String) -> Self { + Self(id) + } + + /// Get the ID as a string slice. + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl Default for RequestId { + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Display for RequestId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Authorization context extracted from request. +#[derive(Clone)] +pub struct AuthzContext { + pub subject: Subject, + pub engine: Arc, + /// Per-organization policy registry for org-scoped authorization. + /// When available, allows evaluation of org-specific RBAC policies. 
+ pub registry: Option>, + /// Audit log service for logging authorization decisions (optional) + audit_service: Option, + /// Task tracker for async logging + #[cfg(feature = "server")] + task_tracker: Option, + /// Request metadata for audit logs + request_ip: Option, + request_user_agent: Option, + /// Audit logging configuration + audit_config: AuthzAuditConfig, + /// Default effect for API authorization when no policy matches. + /// This allows API endpoints to have a different default (e.g., "allow") + /// than admin endpoints (e.g., "deny"). + api_default_effect: PolicyEffect, +} + +impl AuthzContext { + /// Create a permissive authorization context (RBAC disabled, all checks pass). + /// + /// Always available on all targets. Used by WASM and development routes. + pub fn permissive(engine: Arc) -> Self { + Self { + subject: Subject::new(), + engine, + registry: None, + audit_service: None, + #[cfg(feature = "server")] + task_tracker: None, + request_ip: None, + request_user_agent: None, + audit_config: AuthzAuditConfig::default(), + api_default_effect: PolicyEffect::Allow, + } + } + + /// Create a full authorization context (server middleware only). + #[cfg(feature = "server")] + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + subject: Subject, + engine: Arc, + registry: Option>, + audit_service: Option, + task_tracker: Option, + request_ip: Option, + request_user_agent: Option, + audit_config: AuthzAuditConfig, + api_default_effect: PolicyEffect, + ) -> Self { + Self { + subject, + engine, + registry, + audit_service, + task_tracker, + request_ip, + request_user_agent, + audit_config, + api_default_effect, + } + } + + /// Check if the subject is authorized for an action on a resource. 
+ /// + /// Parameters: + /// - `resource`: The type of resource being accessed (e.g., "team", "project") + /// - `action`: The action being performed (e.g., "read", "create", "delete") + /// - `resource_id`: The specific resource ID being accessed + /// - `org_id`: The organization scope + /// - `team_id`: The team scope (if applicable) + /// - `project_id`: The project scope (if applicable) + pub fn authorize( + &self, + resource: &str, + action: &str, + resource_id: Option<&str>, + org_id: Option<&str>, + team_id: Option<&str>, + project_id: Option<&str>, + ) -> AuthzResult { + let mut context = PolicyContext::new(resource, action); + if let Some(id) = resource_id { + context = context.with_resource_id(id); + } + if let Some(id) = org_id { + context = context.with_org_id(id); + } + if let Some(id) = team_id { + context = context.with_team_id(id); + } + if let Some(id) = project_id { + context = context.with_project_id(id); + } + self.engine.authorize(&self.subject, &context) + } + + /// Check authorization and return an error if denied. + /// Logs authorization decisions based on audit configuration. + /// + /// This method evaluates **system policies only** (from config file). It does NOT + /// evaluate per-organization policies from the database. This is by design: + /// + /// - **Admin endpoints** use `require()` - controlled by platform operators via system policies + /// - **API endpoints** use `require_api()` - also evaluates org policies for customer-specific rules + /// + /// This separation ensures that: + /// 1. Admin operations are governed by platform-wide rules (simpler, more predictable) + /// 2. Org admins can customize API access (model usage, rate limits) without affecting admin operations + /// 3. 
Synchronous evaluation avoids async complexity in admin handlers + /// + /// Parameters: + /// - `resource`: The type of resource being accessed (e.g., "team", "project") + /// - `action`: The action being performed (e.g., "read", "create", "delete") + /// - `resource_id`: The specific resource ID being accessed + /// - `org_id`: The organization scope + /// - `team_id`: The team scope (if applicable) + /// - `project_id`: The project scope (if applicable) + pub fn require( + &self, + resource: &str, + action: &str, + resource_id: Option<&str>, + org_id: Option<&str>, + team_id: Option<&str>, + project_id: Option<&str>, + ) -> Result<(), AuthzError> { + let result = self.authorize(resource, action, resource_id, org_id, team_id, project_id); + if result.allowed { + // Log allowed decisions if configured + if self.audit_config.log_allowed { + self.log_authorization_decision( + resource, + action, + resource_id, + org_id, + team_id, + project_id, + &result, + ); + } + Ok(()) + } else { + // Log denied decisions if configured + if self.audit_config.log_denied { + self.log_authorization_decision( + resource, + action, + resource_id, + org_id, + team_id, + project_id, + &result, + ); + } + Err(AuthzError::AccessDenied( + result.reason.unwrap_or_else(|| "Access denied".to_string()), + )) + } + } + + /// Log an authorization decision asynchronously. + /// Logs to the audit log with full context for security monitoring. 
+ #[allow(clippy::too_many_arguments)] + fn log_authorization_decision( + &self, + resource: &str, + action: &str, + resource_id: Option<&str>, + org_id: Option<&str>, + team_id: Option<&str>, + project_id: Option<&str>, + result: &AuthzResult, + ) { + // Only log if audit service and task tracker are available + #[cfg(not(feature = "server"))] + { + let _ = ( + resource, + action, + resource_id, + org_id, + team_id, + project_id, + result, + ); + return; + } + #[cfg(feature = "server")] + { + let (Some(audit_service), Some(task_tracker)) = + (self.audit_service.clone(), self.task_tracker.clone()) + else { + return; + }; + + // Build audit log entry + let actor_type = if self.subject.user_id.is_some() { + AuditActorType::User + } else { + AuditActorType::System + }; + + let actor_id = self + .subject + .user_id + .as_ref() + .and_then(|id| Uuid::parse_str(id).ok()); + + // Use provided resource_id or generate a nil UUID for the audit log + let audit_resource_id = resource_id + .and_then(|id| Uuid::parse_str(id).ok()) + .unwrap_or_else(Uuid::nil); + + let parsed_org_id = org_id.and_then(|id| Uuid::parse_str(id).ok()); + let parsed_project_id = project_id.and_then(|id| Uuid::parse_str(id).ok()); + + // Build details JSON with authorization context + let details = json!({ + "decision": if result.allowed { "allow" } else { "deny" }, + "policy_name": result.policy_name, + "reason": result.reason, + "resource": resource, + "action": action, + "org_id": org_id, + "team_id": team_id, + "project_id": project_id, + "resource_id": resource_id, + "subject": { + "user_id": self.subject.user_id, + "external_id": self.subject.external_id, + "email": self.subject.email, + "roles": self.subject.roles, + "team_ids": self.subject.team_ids, + } + }); + + let audit_action = format!("authz.{}", if result.allowed { "allow" } else { "deny" }); + let ip_address = self.request_ip.clone(); + let user_agent = self.request_user_agent.clone(); + let resource_type = resource.to_string(); + + // 
Spawn async task to write audit log (non-blocking) + task_tracker.spawn(async move { + let entry = CreateAuditLog { + actor_type, + actor_id, + action: audit_action, + resource_type, + resource_id: audit_resource_id, + org_id: parsed_org_id, + project_id: parsed_project_id, + details, + ip_address, + user_agent, + }; + + if let Err(e) = audit_service.create(entry).await { + tracing::warn!( + error = %e, + "Failed to log authorization decision to audit log" + ); + } + }); + } + } + + /// Check if the subject has a specific role. + #[allow(dead_code)] // Public API for CEL policy evaluation + pub fn has_role(&self, role: &str) -> bool { + self.subject.has_role(role) + } + + /// Check if the subject is a member of an organization. + #[allow(dead_code)] // Public API for CEL policy evaluation + pub fn is_org_member(&self, org_id: &str) -> bool { + self.subject.is_org_member(org_id) + } + + /// Check if the subject is a member of a team. + #[allow(dead_code)] // Public API for CEL policy evaluation + pub fn is_team_member(&self, team_id: &str) -> bool { + self.subject.is_team_member(team_id) + } + + /// Check if the subject is a member of a project. + #[allow(dead_code)] // Public API for CEL policy evaluation + pub fn is_project_member(&self, project_id: &str) -> bool { + self.subject.is_project_member(project_id) + } + + /// Authorize an API request with model and request-specific context. + /// + /// This is used for `/v1/*` API endpoints where authorization depends on + /// the specific model, request parameters, and time of day. + /// + /// Parameters: + /// - `resource`: The resource type (e.g., "model", "chat", "embeddings") + /// - `action`: The action being performed (e.g., "use", "complete") + /// - `model`: The model being requested (e.g., "gpt-4o", "claude-3-opus") + /// - `request`: Request-specific context (tokens, tools, etc.) 
+ /// - `org_id`: Organization scope (from API key or identity) + /// - `project_id`: Project scope (from API key) + #[allow(dead_code)] // Public API for CEL policy evaluation on API endpoints + pub fn authorize_api( + &self, + resource: &str, + action: &str, + model: Option<&str>, + request: Option, + org_id: Option<&str>, + project_id: Option<&str>, + ) -> AuthzResult { + let mut context = PolicyContext::new(resource, action).with_current_time(); + + if let Some(m) = model { + context = context.with_model(m); + } + if let Some(req) = request { + context = context.with_request(req); + } + if let Some(id) = org_id { + context = context.with_org_id(id); + } + if let Some(id) = project_id { + context = context.with_project_id(id); + } + + self.engine.authorize(&self.subject, &context) + } + + /// Check API authorization and return an error if denied. + /// Logs authorization decisions based on audit configuration. + /// + /// This method evaluates both system policies (from config) and org policies + /// (from database) when a PolicyRegistry is available and org_id is provided. + /// + /// Parameters: + /// - `resource`: The resource type (e.g., "model", "chat", "embeddings") + /// - `action`: The action being performed (e.g., "use", "complete") + /// - `model`: The model being requested + /// - `request`: Request-specific context (tokens, tools, etc.) 
+ /// - `org_id`: Organization scope + /// - `project_id`: Project scope + pub async fn require_api( + &self, + resource: &str, + action: &str, + model: Option<&str>, + request: Option, + org_id: Option<&str>, + project_id: Option<&str>, + ) -> Result<(), AuthzError> { + // Build the policy context + let mut context = PolicyContext::new(resource, action).with_current_time(); + + if let Some(m) = model { + context = context.with_model(m); + } + if let Some(req) = request.clone() { + context = context.with_request(req); + } + if let Some(id) = org_id { + context = context.with_org_id(id); + } + if let Some(id) = project_id { + context = context.with_project_id(id); + } + + // Evaluate using registry if available (includes org policies), otherwise engine only + // Use the API-specific default effect when no policy matches + let result = if let Some(ref registry) = self.registry { + let parsed_org_id = org_id.and_then(|id| Uuid::parse_str(id).ok()); + registry + .authorize_with_org_and_default( + parsed_org_id, + &self.subject, + &context, + self.api_default_effect, + ) + .await + } else { + // No registry available, use system policies only + self.engine.authorize(&self.subject, &context) + }; + + if result.allowed { + if self.audit_config.log_allowed { + self.log_api_authorization_decision( + resource, + action, + model, + request.as_ref(), + org_id, + project_id, + &result, + ); + } + Ok(()) + } else { + if self.audit_config.log_denied { + self.log_api_authorization_decision( + resource, + action, + model, + request.as_ref(), + org_id, + project_id, + &result, + ); + } + Err(AuthzError::AccessDenied( + result.reason.unwrap_or_else(|| "Access denied".to_string()), + )) + } + } + + /// Log an API authorization decision asynchronously. 
+ #[allow(clippy::too_many_arguments)] + fn log_api_authorization_decision( + &self, + resource: &str, + action: &str, + model: Option<&str>, + request: Option<&RequestContext>, + org_id: Option<&str>, + project_id: Option<&str>, + result: &AuthzResult, + ) { + #[cfg(not(feature = "server"))] + { + let _ = (resource, action, model, request, org_id, project_id, result); + return; + } + #[cfg(feature = "server")] + { + let (Some(audit_service), Some(task_tracker)) = + (self.audit_service.clone(), self.task_tracker.clone()) + else { + return; + }; + + let actor_type = if self.subject.user_id.is_some() { + AuditActorType::User + } else { + AuditActorType::System + }; + + let actor_id = self + .subject + .user_id + .as_ref() + .and_then(|id| Uuid::parse_str(id).ok()); + + let parsed_org_id = org_id.and_then(|id| Uuid::parse_str(id).ok()); + let parsed_project_id = project_id.and_then(|id| Uuid::parse_str(id).ok()); + + let details = json!({ + "decision": if result.allowed { "allow" } else { "deny" }, + "policy_name": result.policy_name, + "reason": result.reason, + "resource": resource, + "action": action, + "model": model, + "request": request.map(|r| json!({ + "max_tokens": r.max_tokens, + "messages_count": r.messages_count, + "has_tools": r.has_tools, + "has_file_search": r.has_file_search, + "stream": r.stream, + })), + "org_id": org_id, + "project_id": project_id, + "subject": { + "user_id": self.subject.user_id, + "external_id": self.subject.external_id, + "email": self.subject.email, + "roles": self.subject.roles, + } + }); + + let audit_action = format!( + "api_authz.{}", + if result.allowed { "allow" } else { "deny" } + ); + let ip_address = self.request_ip.clone(); + let user_agent = self.request_user_agent.clone(); + let resource_type = resource.to_string(); + + task_tracker.spawn(async move { + let entry = CreateAuditLog { + actor_type, + actor_id, + action: audit_action, + resource_type, + resource_id: Uuid::nil(), // API requests don't have a single 
resource ID + org_id: parsed_org_id, + project_id: parsed_project_id, + details, + ip_address, + user_agent, + }; + + if let Err(e) = audit_service.create(entry).await { + tracing::warn!( + error = %e, + "Failed to log API authorization decision to audit log" + ); + } + }); + } + } +} diff --git a/src/middleware/util/usage.rs b/src/middleware/util/usage.rs index a0a73a9..79856bb 100644 --- a/src/middleware/util/usage.rs +++ b/src/middleware/util/usage.rs @@ -98,7 +98,7 @@ pub fn extract_full_usage_from_response(response: &Response) -> ExtractedUsage { let pricing_source = headers .get("X-Pricing-Source") .and_then(|v| v.to_str().ok()) - .map(crate::pricing::CostPricingSource::from_str) + .map(crate::pricing::CostPricingSource::parse) .unwrap_or_default(); let image_count = headers diff --git a/src/models/model_pricing.rs b/src/models/model_pricing.rs index 94a4a32..21d0e65 100644 --- a/src/models/model_pricing.rs +++ b/src/models/model_pricing.rs @@ -43,7 +43,7 @@ impl PricingSource { } } - pub fn from_str(s: &str) -> Self { + pub fn parse(s: &str) -> Self { match s { "provider_api" => Self::ProviderApi, "default" => Self::Default, diff --git a/src/models/user.rs b/src/models/user.rs index 107cfa8..c24d312 100644 --- a/src/models/user.rs +++ b/src/models/user.rs @@ -35,7 +35,7 @@ impl MembershipSource { } /// Parse from database string - pub fn from_str(s: &str) -> Option { + pub fn parse(s: &str) -> Option { match s { "manual" => Some(Self::Manual), "jit" => Some(Self::Jit), diff --git a/src/observability/mod.rs b/src/observability/mod.rs index e2eb71d..e31e6e7 100644 --- a/src/observability/mod.rs +++ b/src/observability/mod.rs @@ -7,9 +7,12 @@ //! - SIEM integration for enterprise security monitoring pub mod metrics; +#[cfg(feature = "server")] pub mod siem; +#[cfg(feature = "server")] mod tracing_init; +#[cfg(feature = "server")] pub use tracing_init::*; /// Set the current span's OpenTelemetry status to `Ok`. 
diff --git a/src/observability/tracing_init.rs b/src/observability/tracing_init.rs index 14dbe6f..880b9fc 100644 --- a/src/observability/tracing_init.rs +++ b/src/observability/tracing_init.rs @@ -3,20 +3,27 @@ //! OpenTelemetry distributed tracing can be enabled via configuration. //! Requires the `otlp` feature for OTLP export support. +#[cfg(feature = "server")] #[cfg(feature = "otlp")] use opentelemetry::trace::TracerProvider as _; +#[cfg(feature = "server")] #[cfg(feature = "otlp")] use opentelemetry_sdk::trace::{Sampler, SdkTracerProvider}; +#[cfg(feature = "server")] use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt}; // Stub types for when OTLP feature is disabled +#[cfg(feature = "server")] #[cfg(not(feature = "otlp"))] struct SdkTracerProviderStub; +#[cfg(feature = "server")] #[cfg(not(feature = "otlp"))] struct TracerStub; +#[cfg(feature = "server")] #[cfg(feature = "otlp")] use crate::config::{OtlpProtocol, PropagationFormat, SamplingStrategy}; +#[cfg(feature = "server")] use crate::{ config::{LogFormat, LoggingConfig, ObservabilityConfig}, observability::siem::{CefConfig, CefLayer, LeefConfig, LeefLayer, SyslogConfig, SyslogLayer}, @@ -28,6 +35,7 @@ use crate::{ /// - Console logging with configurable format (pretty, compact, JSON) /// - Environment-based log filtering /// - OpenTelemetry distributed tracing (if configured) +#[cfg(feature = "server")] pub fn init_tracing(config: &ObservabilityConfig) -> Result { let logging = &config.logging; let filter = build_env_filter(logging); @@ -494,6 +502,7 @@ fn install_propagator(format: &PropagationFormat) { } /// Build the environment filter from logging config. +#[cfg(feature = "server")] fn build_env_filter(config: &LoggingConfig) -> EnvFilter { // Start with the configured level let base_level = match config.level { @@ -521,6 +530,7 @@ fn build_env_filter(config: &LoggingConfig) -> EnvFilter { } /// Guard that ensures OpenTelemetry is properly shut down. 
+#[cfg(feature = "server")] pub struct TracingGuard { #[cfg(feature = "otlp")] provider: Option, @@ -529,6 +539,7 @@ pub struct TracingGuard { provider: Option, } +#[cfg(feature = "server")] impl Drop for TracingGuard { fn drop(&mut self) { // Shutdown the tracer provider to flush pending spans @@ -542,6 +553,7 @@ impl Drop for TracingGuard { } /// Tracing initialization errors. +#[cfg(feature = "server")] #[derive(Debug, thiserror::Error)] pub enum TracingError { #[error("Failed to initialize tracing: {0}")] diff --git a/src/pricing/mod.rs b/src/pricing/mod.rs index efbdf61..562c422 100644 --- a/src/pricing/mod.rs +++ b/src/pricing/mod.rs @@ -30,7 +30,7 @@ impl CostPricingSource { } } - pub fn from_str(s: &str) -> Self { + pub fn parse(s: &str) -> Self { match s { "provider" => Self::Provider, "provider_config" => Self::ProviderConfig, diff --git a/src/providers/anthropic/mod.rs b/src/providers/anthropic/mod.rs index fda3e9a..a89e494 100644 --- a/src/providers/anthropic/mod.rs +++ b/src/providers/anthropic/mod.rs @@ -18,9 +18,12 @@ use convert::{ convert_stop, convert_tool_choice, convert_tools, supports_adaptive_thinking, }; use serde::Deserialize; +#[cfg(not(target_arch = "wasm32"))] use stream::{AnthropicToOpenAIStream, AnthropicToResponsesStream}; use types::{AnthropicMetadata, AnthropicRequest, AnthropicResponse}; +#[cfg(not(target_arch = "wasm32"))] +use crate::providers::response::streaming_response; use crate::{ api_types::{ CreateChatCompletionPayload, CreateCompletionPayload, CreateEmbeddingPayload, @@ -32,7 +35,7 @@ use crate::{ circuit_breaker::CircuitBreaker, error::AnthropicErrorParser, image::{ImageFetchConfig, preprocess_messages_for_images}, - response::{error_response, json_response, streaming_response}, + response::{error_response, json_response}, retry::with_circuit_breaker_and_retry, }, }; @@ -115,7 +118,8 @@ impl AnthropicProvider { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", 
async_trait(?Send))] impl Provider for AnthropicProvider { fn default_health_check_model(&self) -> Option<&str> { Some("claude-haiku-4-5-20251001") @@ -237,19 +241,29 @@ impl Provider for AnthropicProvider { } if stream { - // Transform Anthropic SSE events to OpenAI-compatible format - use futures_util::StreamExt; - - let byte_stream = - response - .bytes_stream() - .map(|result| -> Result { - result.map_err(std::io::Error::other) - }); - let transformed_stream = - AnthropicToOpenAIStream::new(byte_stream, &self.streaming_buffer); - - streaming_response(status, transformed_stream) + #[cfg(not(target_arch = "wasm32"))] + { + // Transform Anthropic SSE events to OpenAI-compatible format + use futures_util::StreamExt; + + let byte_stream = + response + .bytes_stream() + .map(|result| -> Result { + result.map_err(std::io::Error::other) + }); + let transformed_stream = + AnthropicToOpenAIStream::new(byte_stream, &self.streaming_buffer); + + streaming_response(status, transformed_stream) + } + #[cfg(target_arch = "wasm32")] + { + // WASM reqwest streams are !Send; buffer the full response + let anthropic_response: AnthropicResponse = response.json().await?; + let openai_response = convert_response(anthropic_response); + json_response(status, &openai_response) + } } else { let anthropic_response: AnthropicResponse = response.json().await?; let openai_response = convert_response(anthropic_response); @@ -366,19 +380,33 @@ impl Provider for AnthropicProvider { } if stream { - // Transform Anthropic SSE events to OpenAI Responses API format - use futures_util::StreamExt; - - let byte_stream = - response - .bytes_stream() - .map(|result| -> Result { - result.map_err(std::io::Error::other) - }); - let transformed_stream = - AnthropicToResponsesStream::new(byte_stream, &self.streaming_buffer); - - streaming_response(status, transformed_stream) + #[cfg(not(target_arch = "wasm32"))] + { + // Transform Anthropic SSE events to OpenAI Responses API format + use 
futures_util::StreamExt; + + let byte_stream = + response + .bytes_stream() + .map(|result| -> Result { + result.map_err(std::io::Error::other) + }); + let transformed_stream = + AnthropicToResponsesStream::new(byte_stream, &self.streaming_buffer); + + streaming_response(status, transformed_stream) + } + #[cfg(target_arch = "wasm32")] + { + // WASM reqwest streams are !Send; buffer the full response + let anthropic_response: AnthropicResponse = response.json().await?; + let responses_response = convert_anthropic_to_responses_response( + anthropic_response, + payload.reasoning.as_ref(), + payload.user, + ); + json_response(status, &responses_response) + } } else { let anthropic_response: AnthropicResponse = response.json().await?; let responses_response = convert_anthropic_to_responses_response( diff --git a/src/providers/azure_openai/mod.rs b/src/providers/azure_openai/mod.rs index b91bb97..aa1f223 100644 --- a/src/providers/azure_openai/mod.rs +++ b/src/providers/azure_openai/mod.rs @@ -153,7 +153,8 @@ impl AzureOpenAIProvider { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl Provider for AzureOpenAIProvider { #[tracing::instrument( skip(self, client, payload), diff --git a/src/providers/bedrock/mod.rs b/src/providers/bedrock/mod.rs index 3380aba..475159b 100644 --- a/src/providers/bedrock/mod.rs +++ b/src/providers/bedrock/mod.rs @@ -514,7 +514,8 @@ impl BedrockProvider { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl Provider for BedrockProvider { #[tracing::instrument( skip(self, client, payload), diff --git a/src/providers/fallback.rs b/src/providers/fallback.rs index 1308ec2..a288dd1 100644 --- a/src/providers/fallback.rs +++ b/src/providers/fallback.rs @@ -66,6 +66,7 @@ pub fn classify_provider_error(error: &ProviderError) -> FallbackDecision { /// Classifies a `reqwest::Error` for 
fallback purposes. fn classify_reqwest_error(error: &reqwest::Error) -> FallbackDecision { // Connection errors are retryable - different provider might be reachable + #[cfg(not(target_arch = "wasm32"))] if error.is_connect() { return FallbackDecision::Retry; } diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 0c5afa4..b7de3cf 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -75,6 +75,7 @@ use http::{ pub use registry::{CircuitBreakerRegistry, CircuitBreakerStatus}; use serde::{Deserialize, Serialize}; use thiserror::Error; +#[cfg(feature = "server")] use tokio_util::task::TaskTracker; use crate::{ @@ -121,6 +122,7 @@ pub struct CostInjectionParams<'a> { pub pricing: &'a crate::pricing::PricingConfig, pub db: Option<&'a std::sync::Arc>, pub usage_entry: Option, + #[cfg(feature = "server")] pub task_tracker: Option<&'a TaskTracker>, pub max_response_body_bytes: usize, /// Idle timeout for streaming responses in seconds. @@ -199,7 +201,8 @@ impl From for ProviderError { /// at startup and shared across all providers. This works well because reqwest maintains /// per-host connection pools internally, so each provider endpoint gets its own pool. /// See [`crate::config::HttpClientConfig`] for connection pool tuning options. -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait Provider: Send + Sync { async fn create_chat_completion( &self, @@ -523,7 +526,15 @@ async fn build_response( let status = response.status(); let body = if stream { - Body::from_stream(response.bytes_stream()) + #[cfg(not(target_arch = "wasm32"))] + { + Body::from_stream(response.bytes_stream()) + } + #[cfg(target_arch = "wasm32")] + { + // WASM reqwest streams are !Send; buffer the full response instead + Body::from(response.bytes().await?) + } } else { Body::from(response.bytes().await?) 
}; @@ -543,6 +554,8 @@ async fn build_response( /// For non-streaming: adds usage/cost headers by parsing the body /// For streaming: wraps body to track tokens as they arrive via SSE parsing pub async fn inject_cost_into_response(params: CostInjectionParams<'_>) -> Response { + #[cfg(feature = "server")] + let task_tracker = params.task_tracker; let CostInjectionParams { response, provider, @@ -550,11 +563,11 @@ pub async fn inject_cost_into_response(params: CostInjectionParams<'_>) -> Respo pricing, db, usage_entry, - task_tracker, max_response_body_bytes, streaming_idle_timeout_secs, validation_config, response_type, + .. } = params; // Only process successful JSON responses if !response.status().is_success() { @@ -587,78 +600,88 @@ pub async fn inject_cost_into_response(params: CostInjectionParams<'_>) -> Respo .is_some_and(|s| s.contains("chunked")); if is_streaming { - // For streaming responses, wrap the body to track tokens as they arrive - if let (Some(db_pool), Some(entry), Some(tracker)) = (db, usage_entry, task_tracker) { - use futures_util::StreamExt; + #[cfg(feature = "server")] + { + // For streaming responses, wrap the body to track tokens as they arrive + if let (Some(db_pool), Some(entry), Some(tracker)) = (db, usage_entry, task_tracker) { + use futures_util::StreamExt; - let (parts, body) = response.into_parts(); + let (parts, body) = response.into_parts(); - // Convert body to byte stream with proper type annotations - let stream = body.into_data_stream().map( + // Convert body to byte stream with proper type annotations + let stream = body.into_data_stream().map( |result: Result| -> Result { result.map_err(std::io::Error::other) }, ); - // Apply validation wrapper if enabled - // This validates each SSE chunk against the OpenAPI schema - let validated_stream = if validation_config.enabled { - let validating = crate::validation::stream::ValidatingStream::new( - stream, - response_type, - validation_config.mode, - ); - // Box to unify the stream 
types - Box::new(validating) - as Box< - dyn futures_util::Stream> - + Send - + Unpin, - > - } else { - Box::new(stream) - as Box< - dyn futures_util::Stream> - + Send - + Unpin, - > - }; - - // Apply idle timeout wrapper if enabled (timeout > 0) - // This terminates the stream if no chunk is received within the timeout, - // protecting against stalled providers and slow client attacks. - let idle_timeout = std::time::Duration::from_secs(streaming_idle_timeout_secs); - let timeout_stream = - crate::streaming::IdleTimeoutStream::new(validated_stream, idle_timeout); - - // Wrap with usage tracking (after idle timeout so usage is still logged on timeout) - let tracking_stream = crate::streaming::UsageTrackingStream::new( - timeout_stream, - db_pool.clone(), - std::sync::Arc::new(pricing.clone()), - entry, - provider.to_string(), - model.to_string(), - tracker.clone(), - ); + // Apply validation wrapper if enabled + // This validates each SSE chunk against the OpenAPI schema + let validated_stream = if validation_config.enabled { + let validating = crate::validation::stream::ValidatingStream::new( + stream, + response_type, + validation_config.mode, + ); + // Box to unify the stream types + Box::new(validating) + as Box< + dyn futures_util::Stream> + + Send + + Unpin, + > + } else { + Box::new(stream) + as Box< + dyn futures_util::Stream> + + Send + + Unpin, + > + }; - let new_body = axum::body::Body::from_stream(tracking_stream); - if streaming_idle_timeout_secs > 0 { - tracing::debug!( - idle_timeout_secs = streaming_idle_timeout_secs, - validation_enabled = validation_config.enabled, - "Streaming response wrapped with idle timeout and usage tracking" + // Apply idle timeout wrapper if enabled (timeout > 0) + // This terminates the stream if no chunk is received within the timeout, + // protecting against stalled providers and slow client attacks. 
+ let idle_timeout = std::time::Duration::from_secs(streaming_idle_timeout_secs); + let timeout_stream = + crate::streaming::IdleTimeoutStream::new(validated_stream, idle_timeout); + + // Wrap with usage tracking (after idle timeout so usage is still logged on timeout) + let tracking_stream = crate::streaming::UsageTrackingStream::new( + timeout_stream, + db_pool.clone(), + std::sync::Arc::new(pricing.clone()), + entry, + provider.to_string(), + model.to_string(), + tracker.clone(), ); + + let new_body = axum::body::Body::from_stream(tracking_stream); + if streaming_idle_timeout_secs > 0 { + tracing::debug!( + idle_timeout_secs = streaming_idle_timeout_secs, + validation_enabled = validation_config.enabled, + "Streaming response wrapped with idle timeout and usage tracking" + ); + } else { + tracing::debug!( + validation_enabled = validation_config.enabled, + "Streaming response wrapped with usage tracking (idle timeout disabled)" + ); + } + return Response::from_parts(parts, new_body); } else { - tracing::debug!( - validation_enabled = validation_config.enabled, - "Streaming response wrapped with usage tracking (idle timeout disabled)" + // No DB, entry, or tracker - return untracked streaming + tracing::warn!( + "Streaming response without DB/entry/tracker - cost tracking disabled" ); + return response; } - return Response::from_parts(parts, new_body); - } else { - // No DB, entry, or tracker - return untracked streaming - tracing::warn!("Streaming response without DB/entry/tracker - cost tracking disabled"); + } + #[cfg(not(feature = "server"))] + { + // No task tracker available - return untracked streaming return response; } } @@ -878,6 +901,7 @@ pub struct MediaUsageParams<'a> { pub pricing: &'a crate::pricing::PricingConfig, pub db: Option<&'a std::sync::Arc>, pub api_key_id: Option, + #[cfg(feature = "server")] pub task_tracker: &'a TaskTracker, pub usage: crate::pricing::TokenUsage, } @@ -891,14 +915,16 @@ pub struct MediaUsageParams<'a> { /// /// Returns 
(cost_microcents, usage_logged) tuple pub async fn log_media_usage(params: MediaUsageParams<'_>) -> (Option, bool) { + #[cfg(feature = "server")] + let task_tracker = params.task_tracker; let MediaUsageParams { provider, model, pricing, db, api_key_id, - task_tracker, usage, + .. } = params; // Calculate cost @@ -940,6 +966,7 @@ pub async fn log_media_usage(params: MediaUsageParams<'_>) -> (Option, bool }; let db = db_pool.clone(); + #[cfg(feature = "server")] task_tracker.spawn(async move { for attempt in 0..3 { match db.usage().log(entry.clone()).await { diff --git a/src/providers/open_ai/mod.rs b/src/providers/open_ai/mod.rs index c6ac4f4..3a51914 100644 --- a/src/providers/open_ai/mod.rs +++ b/src/providers/open_ai/mod.rs @@ -144,7 +144,8 @@ impl OpenAICompatibleProvider { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl Provider for OpenAICompatibleProvider { fn default_health_check_model(&self) -> Option<&str> { Some("gpt-4o-mini") diff --git a/src/providers/response.rs b/src/providers/response.rs index 8ff0152..c2f5eaf 100644 --- a/src/providers/response.rs +++ b/src/providers/response.rs @@ -20,7 +20,9 @@ //! 
``` use axum::{body::Body, response::Response}; +#[cfg(not(target_arch = "wasm32"))] use bytes::Bytes; +#[cfg(not(target_arch = "wasm32"))] use futures_util::Stream; use http::StatusCode; use serde::Serialize; @@ -73,6 +75,7 @@ pub fn json_response( /// let transformed = AnthropicToOpenAIStream::new(byte_stream, &streaming_buffer); /// streaming_response(StatusCode::OK, transformed) /// ``` +#[cfg(not(target_arch = "wasm32"))] pub fn streaming_response(status: StatusCode, stream: S) -> Result where S: Stream> + Send + 'static, diff --git a/src/providers/retry.rs b/src/providers/retry.rs index 9c2417a..bf461bb 100644 --- a/src/providers/retry.rs +++ b/src/providers/retry.rs @@ -21,9 +21,12 @@ use crate::{ /// Connection errors, timeouts, and other transient issues are retryable. pub fn is_retryable_error(error: &reqwest::Error) -> bool { // Connection errors, timeouts, and other transient issues - error.is_connect() - || error.is_timeout() - || error.is_request() + let mut retryable = error.is_timeout() || error.is_request(); + #[cfg(not(target_arch = "wasm32"))] + { + retryable = retryable || error.is_connect(); + } + retryable // Status errors where we got a response but it was a server error || error .status() diff --git a/src/providers/test/mod.rs b/src/providers/test/mod.rs index 6436b20..99199fd 100644 --- a/src/providers/test/mod.rs +++ b/src/providers/test/mod.rs @@ -276,7 +276,8 @@ fn generate_word_based_embedding(text: &str, dims: usize) -> Vec { embedding } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl Provider for TestProvider { fn default_health_check_model(&self) -> Option<&str> { Some("test-model") diff --git a/src/providers/test_utils.rs b/src/providers/test_utils.rs index 5d85358..116bf48 100644 --- a/src/providers/test_utils.rs +++ b/src/providers/test_utils.rs @@ -1059,7 +1059,37 @@ pub mod schema { .ok_or_else(|| format!("Schema '{}' not found in OpenAPI spec", 
name))?; // Resolve $ref references within the schema - self.resolve_refs(raw_schema.clone(), 0) + let mut resolved = self.resolve_refs(raw_schema.clone(), 0)?; + // Strip OpenAPI extension keys (x-*) from the entire tree — they aren't + // valid JSON Schema and can cause compilation failures (e.g. x-stainless-const: true) + Self::strip_openapi_extensions(&mut resolved); + Ok(resolved) + } + + /// Recursively strip non-JSON-Schema keywords from a resolved value tree. + /// Removes OpenAPI extension keys (x-*) and Draft 2019-09 keywords not in Draft 2020-12. + fn strip_openapi_extensions(value: &mut Value) { + match value { + Value::Object(map) => { + map.retain(|k, _| { + !k.starts_with("x-") && k != "$recursiveAnchor" && k != "discriminator" + }); + // Replace $recursiveRef with a permissive schema (circular ref) + if map.contains_key("$recursiveRef") { + map.clear(); + return; + } + for v in map.values_mut() { + Self::strip_openapi_extensions(v); + } + } + Value::Array(arr) => { + for v in arr { + Self::strip_openapi_extensions(v); + } + } + _ => {} + } } /// Recursively resolve $ref references in a schema. 
diff --git a/src/providers/vertex/mod.rs b/src/providers/vertex/mod.rs index 988f983..5b0e6b4 100644 --- a/src/providers/vertex/mod.rs +++ b/src/providers/vertex/mod.rs @@ -324,7 +324,8 @@ impl VertexProvider { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl Provider for VertexProvider { #[tracing::instrument( skip(self, client, payload), diff --git a/src/routes/admin/mod.rs b/src/routes/admin/mod.rs index a14d4a0..f2f8b49 100644 --- a/src/routes/admin/mod.rs +++ b/src/routes/admin/mod.rs @@ -7,6 +7,7 @@ mod csv_export; pub mod dlq; #[cfg(feature = "sso")] pub mod domain_verifications; +#[cfg(feature = "server")] pub mod dynamic_providers; mod error; pub mod me; @@ -35,14 +36,19 @@ pub mod teams; pub mod ui_config; pub mod usage; pub mod users; -use axum::{ - Router, - routing::{delete, get, patch, post, put}, -}; + +#[cfg(any(feature = "server", feature = "wasm"))] +use axum::Router; +#[cfg(feature = "server")] +use axum::routing::{delete, get, patch, post, put}; pub use error::{AdminError, AuditActor}; +#[cfg(any(feature = "server", feature = "wasm"))] use crate::AppState; +#[cfg(feature = "wasm")] +use crate::compat::wasm_routing::{delete, get, patch, post, put}; +#[cfg(any(feature = "server", feature = "wasm"))] pub fn get_admin_routes() -> Router { Router::new().nest("/v1", admin_v1_routes()) } @@ -50,6 +56,7 @@ pub fn get_admin_routes() -> Router { /// Get admin routes with authentication middleware applied. /// This requires UI auth (Zero Trust or OIDC) to be configured. /// Note: The middleware layer is applied in main.rs where state is available. 
+#[cfg(any(feature = "server", feature = "wasm"))] pub fn get_protected_admin_routes() -> Router { // The protection is applied in build_app via route_layer Router::new().nest("/v1", admin_v1_routes()) @@ -57,30 +64,33 @@ pub fn get_protected_admin_routes() -> Router { /// Get public admin routes that don't require authentication. /// These are needed for frontend bootstrap before the user logs in. +#[cfg(any(feature = "server", feature = "wasm"))] pub fn get_public_admin_routes() -> Router { Router::new().nest("/v1", public_admin_v1_routes()) } -fn public_admin_v1_routes() -> Router { +#[cfg(any(feature = "server", feature = "wasm"))] +pub(crate) fn public_admin_v1_routes() -> Router { Router::new() // UI Configuration (unauthenticated - needed for frontend bootstrap) .route("/ui/config", get(ui_config::get_ui_config)) } -fn admin_v1_routes() -> Router { +#[cfg(any(feature = "server", feature = "wasm"))] +pub(crate) fn admin_v1_routes() -> Router { let router = Router::new() // Self-service endpoints (current user) .route("/me", delete(me::delete)) .route("/me/export", get(me::export)) .route( "/me/providers", - get(me_providers::list).post(me_providers::create), + get(me_providers::list).merge(post(me_providers::create)), ) .route( "/me/providers/{id}", get(me_providers::get) - .patch(me_providers::update) - .delete(me_providers::delete), + .merge(patch(me_providers::update)) + .merge(delete(me_providers::delete)), ) .route( "/me/providers/test-credentials", @@ -96,88 +106,92 @@ fn admin_v1_routes() -> Router { ) .route( "/me/api-keys", - get(me_api_keys::list).post(me_api_keys::create), + get(me_api_keys::list).merge(post(me_api_keys::create)), ) .route( "/me/api-keys/{key_id}", - get(me_api_keys::get).delete(me_api_keys::revoke), + get(me_api_keys::get).merge(delete(me_api_keys::revoke)), ) .route("/me/api-keys/{key_id}/rotate", post(me_api_keys::rotate)) // Organizations .route( "/organizations", - post(organizations::create).get(organizations::list), + 
post(organizations::create).merge(get(organizations::list)), ) .route( "/organizations/{slug}", get(organizations::get) - .patch(organizations::update) - .delete(organizations::delete), + .merge(patch(organizations::update)) + .merge(delete(organizations::delete)), ) // Projects .route( "/organizations/{org_slug}/projects", - post(projects::create).get(projects::list), + post(projects::create).merge(get(projects::list)), ) .route( "/organizations/{org_slug}/projects/{project_slug}", get(projects::get) - .patch(projects::update) - .delete(projects::delete), + .merge(patch(projects::update)) + .merge(delete(projects::delete)), ) // Teams .route( "/organizations/{org_slug}/teams", - post(teams::create).get(teams::list), + post(teams::create).merge(get(teams::list)), ) .route( "/organizations/{org_slug}/teams/{team_slug}", - get(teams::get).patch(teams::update).delete(teams::delete), + get(teams::get) + .merge(patch(teams::update)) + .merge(delete(teams::delete)), ) // Team memberships .route( "/organizations/{org_slug}/teams/{team_slug}/members", - get(teams::list_members).post(teams::add_member), + get(teams::list_members).merge(post(teams::add_member)), ) .route( "/organizations/{org_slug}/teams/{team_slug}/members/{user_id}", - patch(teams::update_member).delete(teams::remove_member), + patch(teams::update_member).merge(delete(teams::remove_member)), ) // Service Accounts .route( "/organizations/{org_slug}/service-accounts", - post(service_accounts::create).get(service_accounts::list), + post(service_accounts::create).merge(get(service_accounts::list)), ) .route( "/organizations/{org_slug}/service-accounts/{sa_slug}", get(service_accounts::get) - .patch(service_accounts::update) - .delete(service_accounts::delete), + .merge(patch(service_accounts::update)) + .merge(delete(service_accounts::delete)), ) // Users (top-level) - .route("/users", post(users::create).get(users::list)) + .route("/users", post(users::create).merge(get(users::list))) .route( 
"/users/{user_id}", - get(users::get).patch(users::update).delete(users::delete), + get(users::get) + .merge(patch(users::update)) + .merge(delete(users::delete)), ) .route("/users/{user_id}/export", get(users::export)) // Organization memberships .route( "/organizations/{org_slug}/members", - get(users::list_org_members).post(users::add_org_member), + get(users::list_org_members).merge(post(users::add_org_member)), ) .route( "/organizations/{org_slug}/members/{user_id}", - delete(users::remove_org_member).patch(users::update_org_member), + delete(users::remove_org_member).merge(patch(users::update_org_member)), ) // Project memberships .route( "/organizations/{org_slug}/projects/{project_slug}/members", - get(users::list_project_members).post(users::add_project_member), + get(users::list_project_members).merge(post(users::add_project_member)), ) .route( "/organizations/{org_slug}/projects/{project_slug}/members/{user_id}", - delete(users::remove_project_member).patch(users::update_project_member), + delete(users::remove_project_member).merge(patch(users::update_project_member)), ) // API Keys .route("/api-keys", post(api_keys::create)) @@ -195,8 +209,10 @@ fn admin_v1_routes() -> Router { .route( "/organizations/{org_slug}/service-accounts/{sa_slug}/api-keys", get(api_keys::list_by_service_account), - ) - // Dynamic Providers + ); + // Dynamic Providers (requires server feature — module is cfg-gated) + #[cfg(feature = "server")] + let router = router .route("/dynamic-providers", post(dynamic_providers::create)) .route( "/dynamic-providers/{id}", @@ -223,8 +239,9 @@ fn admin_v1_routes() -> Router { .route( "/users/{user_id}/dynamic-providers", get(dynamic_providers::list_by_user), - ) - // Usage endpoints - API Key level + ); + // Usage endpoints - API Key level + let router = router .route("/api-keys/{key_id}/usage", get(usage::get_summary)) .route("/api-keys/{key_id}/usage/by-date", get(usage::get_by_date)) .route( @@ -516,15 +533,15 @@ fn admin_v1_routes() -> 
Router { // Model Pricing .route( "/model-pricing", - post(model_pricing::create).get(model_pricing::list_global), + post(model_pricing::create).merge(get(model_pricing::list_global)), ) .route("/model-pricing/upsert", post(model_pricing::upsert)) .route("/model-pricing/bulk", post(model_pricing::bulk_upsert)) .route( "/model-pricing/{id}", get(model_pricing::get) - .patch(model_pricing::update) - .delete(model_pricing::delete), + .merge(patch(model_pricing::update)) + .merge(delete(model_pricing::delete)), ) .route( "/model-pricing/provider/{provider}", @@ -547,8 +564,8 @@ fn admin_v1_routes() -> Router { .route( "/conversations/{id}", get(conversations::get) - .patch(conversations::update) - .delete(conversations::delete), + .merge(patch(conversations::update)) + .merge(delete(conversations::delete)), ) .route( "/conversations/{id}/messages", @@ -572,8 +589,8 @@ fn admin_v1_routes() -> Router { .route( "/prompts/{id}", get(prompts::get) - .patch(prompts::update) - .delete(prompts::delete), + .merge(patch(prompts::update)) + .merge(delete(prompts::delete)), ) .route( "/organizations/{org_slug}/prompts", @@ -613,10 +630,10 @@ fn admin_v1_routes() -> Router { get(providers::get_provider_stats_history), ) // Dead Letter Queue - .route("/dlq", get(dlq::list).delete(dlq::purge)) + .route("/dlq", get(dlq::list).merge(delete(dlq::purge))) .route("/dlq/stats", get(dlq::stats)) .route("/dlq/prune", post(dlq::prune)) - .route("/dlq/{id}", get(dlq::get).delete(dlq::delete)) + .route("/dlq/{id}", get(dlq::get).merge(delete(dlq::delete))) .route("/dlq/{id}/retry", post(dlq::retry)) // Audit Logs .route("/audit-logs", get(audit_logs::list)) @@ -641,13 +658,13 @@ fn admin_v1_routes() -> Router { // Organization RBAC Policies .route( "/organizations/{org_slug}/rbac-policies", - get(org_rbac_policies::list).post(org_rbac_policies::create), + get(org_rbac_policies::list).merge(post(org_rbac_policies::create)), ) .route( "/organizations/{org_slug}/rbac-policies/{policy_id}", 
get(org_rbac_policies::get) - .patch(org_rbac_policies::update) - .delete(org_rbac_policies::delete), + .merge(patch(org_rbac_policies::update)) + .merge(delete(org_rbac_policies::delete)), ) .route( "/organizations/{org_slug}/rbac-policies/{policy_id}/versions", @@ -755,7 +772,7 @@ fn admin_v1_routes() -> Router { router } -#[cfg(all(test, feature = "database-sqlite"))] +#[cfg(all(test, feature = "database-sqlite", feature = "server"))] mod tests { use axum::{ body::Body, @@ -793,8 +810,8 @@ api_key = "sk-test-key" db_id ); - let config = crate::config::GatewayConfig::from_str(&config_str) - .expect("Failed to parse test config"); + let config = + crate::config::GatewayConfig::parse(&config_str).expect("Failed to parse test config"); let state = crate::AppState::new(config.clone()) .await .expect("Failed to create AppState"); @@ -4097,8 +4114,8 @@ ttl_secs = 86400 db_id ); - let config = crate::config::GatewayConfig::from_str(&config_str) - .expect("Failed to parse test config"); + let config = + crate::config::GatewayConfig::parse(&config_str).expect("Failed to parse test config"); let state = crate::AppState::new(config.clone()) .await .expect("Failed to create AppState"); @@ -5292,8 +5309,8 @@ ttl_secs = 86400 /// Create a test application with a custom config string async fn test_app_with_config(config_str: &str) -> axum::Router { - let config = crate::config::GatewayConfig::from_str(config_str) - .expect("Failed to parse test config"); + let config = + crate::config::GatewayConfig::parse(config_str).expect("Failed to parse test config"); let state = crate::AppState::new(config.clone()) .await .expect("Failed to create AppState"); diff --git a/src/routes/api/audio.rs b/src/routes/api/audio.rs index ff4811a..cc3d687 100644 --- a/src/routes/api/audio.rs +++ b/src/routes/api/audio.rs @@ -1,9 +1,6 @@ -use axum::{ - Extension, Json, - body::Bytes, - extract::{Multipart, State}, - response::Response, -}; +#[cfg(feature = "server")] +use axum::extract::Multipart; 
+use axum::{Extension, Json, body::Bytes, extract::State, response::Response}; use axum_valid::Valid; use http::StatusCode; @@ -185,6 +182,7 @@ pub async fn api_v1_audio_speech( pricing: &state.pricing, db: state.db.as_ref(), api_key_id, + #[cfg(feature = "server")] task_tracker: &state.task_tracker, usage: crate::pricing::TokenUsage::for_tts_characters(character_count), }) @@ -214,6 +212,7 @@ pub async fn api_v1_audio_speech( Ok(response) } +#[cfg(feature = "server")] /// Transcribe audio to text /// /// POST /v1/audio/transcriptions @@ -523,6 +522,7 @@ pub async fn api_v1_audio_transcriptions( pricing: &state.pricing, db: state.db.as_ref(), api_key_id, + #[cfg(feature = "server")] task_tracker: &state.task_tracker, usage: crate::pricing::TokenUsage::for_audio_seconds(estimated_seconds), }) @@ -552,6 +552,7 @@ pub async fn api_v1_audio_transcriptions( Ok(response) } +#[cfg(feature = "server")] /// Translate audio to English text /// /// POST /v1/audio/translations @@ -819,6 +820,7 @@ pub async fn api_v1_audio_translations( pricing: &state.pricing, db: state.db.as_ref(), api_key_id, + #[cfg(feature = "server")] task_tracker: &state.task_tracker, usage: crate::pricing::TokenUsage::for_audio_seconds(estimated_seconds), }) diff --git a/src/routes/api/chat.rs b/src/routes/api/chat.rs index b7f0f35..6b626b0 100644 --- a/src/routes/api/chat.rs +++ b/src/routes/api/chat.rs @@ -900,6 +900,7 @@ pub async fn api_v1_chat_completions( .and_then(|a| a.project_id()) .map(|id| id.to_string()); + #[cfg(feature = "server")] state.task_tracker.spawn(async move { let params = StoreParams { payload: &payload_clone, @@ -925,6 +926,7 @@ pub async fn api_v1_chat_completions( let provider_clone = provider_name.clone(); let content_type_clone = content_type; let body_clone = body_vec.clone(); + #[cfg(feature = "server")] state.task_tracker.spawn(async move { cache .store( @@ -975,6 +977,7 @@ pub async fn api_v1_chat_completions( pricing: &state.pricing, db: state.db.as_ref(), usage_entry, 
+ #[cfg(feature = "server")] task_tracker: Some(&state.task_tracker), max_response_body_bytes: state.config.server.max_response_body_bytes, streaming_idle_timeout_secs: state.config.server.streaming_idle_timeout_secs, @@ -1526,6 +1529,7 @@ pub async fn api_v1_responses( let provider_clone = provider_name.clone(); let content_type_clone = content_type; let body_clone = body_vec.clone(); + #[cfg(feature = "server")] state.task_tracker.spawn(async move { cache .store_responses( @@ -1576,6 +1580,7 @@ pub async fn api_v1_responses( pricing: &state.pricing, db: state.db.as_ref(), usage_entry, + #[cfg(feature = "server")] task_tracker: Some(&state.task_tracker), max_response_body_bytes: state.config.server.max_response_body_bytes, streaming_idle_timeout_secs: state.config.server.streaming_idle_timeout_secs, @@ -2112,6 +2117,7 @@ pub async fn api_v1_completions( let provider_clone = provider_name.clone(); let content_type_clone = content_type; let body_clone = body_vec.clone(); + #[cfg(feature = "server")] state.task_tracker.spawn(async move { cache .store_completions( @@ -2162,6 +2168,7 @@ pub async fn api_v1_completions( pricing: &state.pricing, db: state.db.as_ref(), usage_entry, + #[cfg(feature = "server")] task_tracker: Some(&state.task_tracker), max_response_body_bytes: state.config.server.max_response_body_bytes, streaming_idle_timeout_secs: state.config.server.streaming_idle_timeout_secs, diff --git a/src/routes/api/embeddings.rs b/src/routes/api/embeddings.rs index ca9ecb0..fd3d1d3 100644 --- a/src/routes/api/embeddings.rs +++ b/src/routes/api/embeddings.rs @@ -238,6 +238,7 @@ pub async fn api_v1_embeddings( let provider_clone = provider_name.clone(); let content_type_clone = content_type; let body_clone = body_vec.clone(); + #[cfg(feature = "server")] state.task_tracker.spawn(async move { cache .store_embeddings( @@ -277,6 +278,7 @@ pub async fn api_v1_embeddings( pricing: &state.pricing, db: state.db.as_ref(), usage_entry: None, + #[cfg(feature = "server")] 
task_tracker: Some(&state.task_tracker), max_response_body_bytes: state.config.server.max_response_body_bytes, streaming_idle_timeout_secs: 0, // Embeddings don't stream diff --git a/src/routes/api/files.rs b/src/routes/api/files.rs index 65a81c8..98ee84f 100644 --- a/src/routes/api/files.rs +++ b/src/routes/api/files.rs @@ -1,7 +1,9 @@ +#[cfg(feature = "server")] +use axum::extract::Multipart; use axum::{ Extension, Json, body::Bytes, - extract::{Multipart, Path, Query, State}, + extract::{Path, Query, State}, http::header, response::{IntoResponse, Response}, }; @@ -79,6 +81,7 @@ pub struct DeleteFileResponse { pub deleted: bool, } +#[cfg(feature = "server")] /// Upload a file /// /// Uploads a file that can be used with vector stores for RAG. diff --git a/src/routes/api/images.rs b/src/routes/api/images.rs index 31a24f4..7a3dd03 100644 --- a/src/routes/api/images.rs +++ b/src/routes/api/images.rs @@ -1,7 +1,9 @@ +#[cfg(feature = "server")] +use axum::extract::Multipart; use axum::{ Extension, Json, body::Bytes, - extract::{Multipart, State}, + extract::State, response::{IntoResponse, Response}, }; use axum_valid::Valid; @@ -198,6 +200,7 @@ pub async fn api_v1_images_generations( pricing: &state.pricing, db: state.db.as_ref(), api_key_id, + #[cfg(feature = "server")] task_tracker: &state.task_tracker, usage: crate::pricing::TokenUsage::for_images( image_count, @@ -233,6 +236,7 @@ pub async fn api_v1_images_generations( Ok(response) } +#[cfg(feature = "server")] /// Edit image with text instructions /// /// POST /v1/images/edits @@ -538,6 +542,7 @@ pub async fn api_v1_images_edits( pricing: &state.pricing, db: state.db.as_ref(), api_key_id, + #[cfg(feature = "server")] task_tracker: &state.task_tracker, usage: crate::pricing::TokenUsage::for_images( image_count, @@ -573,6 +578,7 @@ pub async fn api_v1_images_edits( Ok(response) } +#[cfg(feature = "server")] /// Create image variations /// /// POST /v1/images/variations @@ -838,6 +844,7 @@ pub async fn 
api_v1_images_variations( pricing: &state.pricing, db: state.db.as_ref(), api_key_id, + #[cfg(feature = "server")] task_tracker: &state.task_tracker, usage: crate::pricing::TokenUsage::for_images( image_count, diff --git a/src/routes/api/mod.rs b/src/routes/api/mod.rs index 817e83a..9d09cb6 100644 --- a/src/routes/api/mod.rs +++ b/src/routes/api/mod.rs @@ -1,15 +1,22 @@ +#[cfg(any(feature = "server", feature = "wasm"))] +use axum::Router; +#[cfg(feature = "server")] +use axum::middleware::from_fn_with_state; +#[cfg(feature = "server")] +use axum::routing::{delete, get, post}; use axum::{ - Extension, Json, Router, + Extension, Json, http::HeaderMap, - middleware::from_fn_with_state, response::{IntoResponse, Response}, - routing::{get, post}, }; use http::StatusCode; use serde::Deserialize; +#[cfg(feature = "server")] use tower::ServiceBuilder; use uuid::Uuid; +#[cfg(feature = "wasm")] +use crate::compat::wasm_routing::{delete, get, post}; use crate::{ AppState, api_types, auth::AuthenticatedRequest, @@ -566,6 +573,7 @@ fn log_guardrails_evaluation( }) .collect(); + #[cfg(feature = "server")] state.task_tracker.spawn(async move { let result = db .audit_logs() @@ -685,6 +693,7 @@ fn log_output_guardrails_evaluation( }) .collect(); + #[cfg(feature = "server")] state.task_tracker.spawn(async move { let mut details = serde_json::json!({ "provider": provider, @@ -746,52 +755,66 @@ fn get_services(state: &AppState) -> Result<&Services, ApiError> { }) } -pub fn get_api_routes(state: AppState) -> Router { - Router::new() +/// Route definitions for the OpenAI-compatible API. +/// +/// Shared between server and WASM builds. The server wraps these with auth/rate-limit +/// middleware in [`get_api_routes`]; the WASM build uses them directly. 
+#[cfg(any(feature = "server", feature = "wasm"))] +pub(crate) fn api_v1_routes() -> Router { + let router = Router::new() .route("/v1/chat/completions", post(api_v1_chat_completions)) .route("/v1/responses", post(api_v1_responses)) .route("/v1/completions", post(api_v1_completions)) .route("/v1/embeddings", post(api_v1_embeddings)) .route("/v1/models", get(api_v1_models)) // Images API (OpenAI-compatible) - .route("/v1/images/generations", post(api_v1_images_generations)) + .route("/v1/images/generations", post(api_v1_images_generations)); + #[cfg(feature = "server")] + let router = router .route("/v1/images/edits", post(api_v1_images_edits)) - .route("/v1/images/variations", post(api_v1_images_variations)) + .route("/v1/images/variations", post(api_v1_images_variations)); + let router = router // Audio API (OpenAI-compatible) - .route("/v1/audio/speech", post(api_v1_audio_speech)) + .route("/v1/audio/speech", post(api_v1_audio_speech)); + #[cfg(feature = "server")] + let router = router .route( "/v1/audio/transcriptions", post(api_v1_audio_transcriptions), ) - .route("/v1/audio/translations", post(api_v1_audio_translations)) - // Files API (OpenAI-compatible) - .route( - "/v1/files", - post(api_v1_files_upload).get(api_v1_files_list), - ) + .route("/v1/audio/translations", post(api_v1_audio_translations)); + // Files API (OpenAI-compatible) + #[cfg(feature = "server")] + let router = router.route( + "/v1/files", + post(api_v1_files_upload).merge(get(api_v1_files_list)), + ); + #[cfg(not(feature = "server"))] + let router = router.route("/v1/files", get(api_v1_files_list)); + router .route( "/v1/files/{file_id}", - get(api_v1_files_get).delete(api_v1_files_delete), + get(api_v1_files_get).merge(delete(api_v1_files_delete)), ) .route("/v1/files/{file_id}/content", get(api_v1_files_get_content)) // Vector Stores API (OpenAI-compatible) .route( "/v1/vector_stores", - post(api_v1_vector_stores_create).get(api_v1_vector_stores_list), + 
post(api_v1_vector_stores_create).merge(get(api_v1_vector_stores_list)), ) .route( "/v1/vector_stores/{vector_store_id}", get(api_v1_vector_stores_get) - .post(api_v1_vector_stores_modify) - .delete(api_v1_vector_stores_delete), + .merge(post(api_v1_vector_stores_modify)) + .merge(delete(api_v1_vector_stores_delete)), ) .route( "/v1/vector_stores/{vector_store_id}/files", - post(api_v1_vector_stores_create_file).get(api_v1_vector_stores_list_files), + post(api_v1_vector_stores_create_file).merge(get(api_v1_vector_stores_list_files)), ) .route( "/v1/vector_stores/{vector_store_id}/files/{file_id}", - get(api_v1_vector_stores_get_file).delete(api_v1_vector_stores_delete_file), + get(api_v1_vector_stores_get_file).merge(delete(api_v1_vector_stores_delete_file)), ) // Hadrian extension: chunk inspection (not in OpenAI API) .route( @@ -810,12 +833,19 @@ pub fn get_api_routes(state: AppState) -> Router { ) .route( "/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}", - get(api_v1_vector_stores_get_file_batch).delete(api_v1_vector_stores_cancel_file_batch), + get(api_v1_vector_stores_get_file_batch) + .merge(delete(api_v1_vector_stores_cancel_file_batch)), ) .route( "/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files", get(api_v1_vector_stores_list_batch_files), ) +} + +/// Server-only: wraps [`api_v1_routes`] with auth, rate-limit, and authz middleware. +#[cfg(feature = "server")] +pub fn get_api_routes(state: AppState) -> Router { + api_v1_routes() // Apply middleware layers in order (ServiceBuilder runs top-to-bottom): // 1. Rate limiting - reject requests early before auth overhead // 2. 
Auth, budget, usage - authenticates and sets AuthenticatedRequest @@ -883,8 +913,8 @@ model_name = "secondary-model" db_id ); - let config = crate::config::GatewayConfig::from_str(&config_str) - .expect("Failed to parse test config"); + let config = + crate::config::GatewayConfig::parse(&config_str).expect("Failed to parse test config"); let state = crate::AppState::new(config.clone()) .await .expect("Failed to create AppState"); @@ -2751,8 +2781,8 @@ max_file_size_mb = {} db_id, max_file_size_mb ); - let config = crate::config::GatewayConfig::from_str(&config_str) - .expect("Failed to parse test config"); + let config = + crate::config::GatewayConfig::parse(&config_str).expect("Failed to parse test config"); let state = crate::AppState::new(config.clone()) .await .expect("Failed to create AppState"); @@ -2800,8 +2830,8 @@ model_name = "test-model" db_id ); - let config = crate::config::GatewayConfig::from_str(&config_str) - .expect("Failed to parse test config"); + let config = + crate::config::GatewayConfig::parse(&config_str).expect("Failed to parse test config"); let mut state = crate::AppState::new(config.clone()) .await .expect("Failed to create AppState"); @@ -2886,8 +2916,8 @@ model_name = "test-model" db_id ); - let config = crate::config::GatewayConfig::from_str(&config_str) - .expect("Failed to parse test config"); + let config = + crate::config::GatewayConfig::parse(&config_str).expect("Failed to parse test config"); let mut state = crate::AppState::new(config.clone()) .await .expect("Failed to create AppState"); diff --git a/src/routes/api/models.rs b/src/routes/api/models.rs index cd81868..e0bc4a3 100644 --- a/src/routes/api/models.rs +++ b/src/routes/api/models.rs @@ -313,6 +313,7 @@ pub async fn api_v1_models( }; // Collect all enabled providers across scopes, auto-paginating through cursor pages + #[cfg(not(target_arch = "wasm32"))] type ProviderPageFn = Box< dyn Fn( crate::db::repos::ListParams, @@ -326,6 +327,20 @@ pub async fn api_v1_models( >, 
> + Send, >; + #[cfg(target_arch = "wasm32")] + type ProviderPageFn = Box< + dyn Fn( + crate::db::repos::ListParams, + ) -> std::pin::Pin< + Box< + dyn std::future::Future< + Output = crate::db::DbResult< + crate::db::repos::ListResult, + >, + >, + >, + >, + >; let collect_all_enabled = |fetch_page: ProviderPageFn| async move { let mut all = Vec::new(); let mut params = crate::db::repos::ListParams { diff --git a/src/routes/auth.rs b/src/routes/auth.rs index 2e71382..cc4810a 100644 --- a/src/routes/auth.rs +++ b/src/routes/auth.rs @@ -1428,8 +1428,8 @@ type = "test" db_id ); - let config = crate::config::GatewayConfig::from_str(&config_str) - .expect("Failed to parse test config"); + let config = + crate::config::GatewayConfig::parse(&config_str).expect("Failed to parse test config"); let state = crate::AppState::new(config.clone()) .await .expect("Failed to create AppState"); @@ -1673,8 +1673,8 @@ type = "test" db_id ); - let config = crate::config::GatewayConfig::from_str(&config_str) - .expect("Failed to parse test config"); + let config = + crate::config::GatewayConfig::parse(&config_str).expect("Failed to parse test config"); let state = crate::AppState::new(config.clone()) .await .expect("Failed to create AppState"); @@ -2080,8 +2080,8 @@ type = "test" db_id ); - let config = crate::config::GatewayConfig::from_str(&config_str) - .expect("Failed to parse test config"); + let config = + crate::config::GatewayConfig::parse(&config_str).expect("Failed to parse test config"); let state = crate::AppState::new(config.clone()) .await .expect("Failed to create AppState"); @@ -2145,8 +2145,8 @@ type = "test" db_id ); - let config = crate::config::GatewayConfig::from_str(&config_str) - .expect("Failed to parse test config"); + let config = + crate::config::GatewayConfig::parse(&config_str).expect("Failed to parse test config"); let state = crate::AppState::new(config.clone()) .await .expect("Failed to create AppState"); diff --git a/src/routes/execution.rs 
b/src/routes/execution.rs index dcdf56c..843cb00 100644 --- a/src/routes/execution.rs +++ b/src/routes/execution.rs @@ -130,11 +130,15 @@ impl ApiPayload for api_types::CreateEmbeddingPayload { /// /// This trait is implemented for marker types that represent each API operation, /// allowing us to dispatch to the correct provider method generically. +/// +/// On native targets, futures must be `Send` for use with multi-threaded runtimes. +/// On WASM, `Send` is not required (single-threaded). pub trait ProviderExecutor: Send + Sync + 'static { /// The payload type for this operation. type Payload: ApiPayload; /// Execute the request against the given provider. + #[cfg(not(target_arch = "wasm32"))] fn execute( state: &AppState, provider_name: &str, @@ -142,6 +146,15 @@ pub trait ProviderExecutor: Send + Sync + 'static { payload: Self::Payload, ) -> impl std::future::Future> + Send; + /// Execute the request against the given provider. + #[cfg(target_arch = "wasm32")] + fn execute( + state: &AppState, + provider_name: &str, + provider_config: &ProviderConfig, + payload: Self::Payload, + ) -> impl std::future::Future>; + /// Name of the operation for logging/tracing. fn operation_name() -> &'static str; } @@ -781,7 +794,7 @@ mod tests { /// Create a minimal AppState for testing with the given providers config. 
fn create_test_state(providers: ProvidersConfig) -> AppState { - let mut config = GatewayConfig::from_str("").expect("Empty config should parse"); + let mut config = GatewayConfig::parse("").expect("Empty config should parse"); config.providers = providers; let config = Arc::new(config); diff --git a/src/routes/health.rs b/src/routes/health.rs index 579a0e5..5f3aac1 100644 --- a/src/routes/health.rs +++ b/src/routes/health.rs @@ -278,8 +278,8 @@ api_key = "sk-test-key" db_id ); - let config = crate::config::GatewayConfig::from_str(&config_str) - .expect("Failed to parse test config"); + let config = + crate::config::GatewayConfig::parse(&config_str).expect("Failed to parse test config"); let state = crate::AppState::new(config.clone()) .await .expect("Failed to create AppState"); @@ -296,8 +296,8 @@ type = "open_ai" api_key = "sk-test-key" "#; - let config = crate::config::GatewayConfig::from_str(config_str) - .expect("Failed to parse test config"); + let config = + crate::config::GatewayConfig::parse(config_str).expect("Failed to parse test config"); let state = crate::AppState::new(config.clone()) .await .expect("Failed to create AppState"); diff --git a/src/routes/mod.rs b/src/routes/mod.rs index c750b6f..5fbd00e 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -6,6 +6,7 @@ pub mod execution; pub mod health; #[cfg(feature = "sso")] pub mod scim; +#[cfg(feature = "server")] pub mod ws; pub use api::*; @@ -13,4 +14,5 @@ pub use api::*; pub use auth as auth_routes; #[cfg(feature = "sso")] pub use scim::scim_routes; +#[cfg(feature = "server")] pub use ws::ws_handler; diff --git a/src/scim/filter.rs b/src/scim/filter.rs index bd22b0d..5886fd0 100644 --- a/src/scim/filter.rs +++ b/src/scim/filter.rs @@ -231,7 +231,7 @@ impl std::error::Error for FilterParseError {} /// # Examples /// /// ``` -/// use gateway::scim::filter::parse_filter; +/// use hadrian::scim::filter::parse_filter; /// /// let filter = parse_filter("userName eq \"john\"").unwrap(); /// let 
filter = parse_filter("active eq true and emails pr").unwrap(); diff --git a/src/scim/patch.rs b/src/scim/patch.rs index 2d926db..13c429a 100644 --- a/src/scim/patch.rs +++ b/src/scim/patch.rs @@ -259,7 +259,7 @@ impl From for PatchError { /// # Examples /// /// ``` -/// use gateway::scim::patch::parse_path; +/// use hadrian::scim::patch::parse_path; /// /// let path = parse_path("displayName").unwrap(); /// let path = parse_path("name.familyName").unwrap(); diff --git a/src/secrets/aws.rs b/src/secrets/aws.rs index 761ce55..cc866eb 100644 --- a/src/secrets/aws.rs +++ b/src/secrets/aws.rs @@ -92,7 +92,8 @@ impl AwsSecretsManager { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl SecretManager for AwsSecretsManager { async fn get(&self, key: &str) -> SecretResult> { let name = self.full_name(key); diff --git a/src/secrets/azure.rs b/src/secrets/azure.rs index ec5ecbf..579656e 100644 --- a/src/secrets/azure.rs +++ b/src/secrets/azure.rs @@ -109,7 +109,8 @@ impl AzureKeyVaultManager { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl SecretManager for AzureKeyVaultManager { async fn get(&self, key: &str) -> SecretResult> { let name = self.full_name(key); diff --git a/src/secrets/gcp.rs b/src/secrets/gcp.rs index be2fa68..80c721b 100644 --- a/src/secrets/gcp.rs +++ b/src/secrets/gcp.rs @@ -116,7 +116,8 @@ impl GcpSecretManager { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl SecretManager for GcpSecretManager { async fn get(&self, key: &str) -> SecretResult> { let name = self.secret_version_name(key); diff --git a/src/secrets/mod.rs b/src/secrets/mod.rs index b8d290e..277f978 100644 --- a/src/secrets/mod.rs +++ b/src/secrets/mod.rs @@ -46,7 +46,8 @@ pub enum SecretError { pub type SecretResult = Result; /// 
Trait for managing secrets (provider API keys, etc.) -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait SecretManager: Send + Sync { /// Get a secret by key. Returns None if not found. async fn get(&self, key: &str) -> SecretResult>; @@ -82,7 +83,8 @@ impl Default for MemorySecretManager { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl SecretManager for MemorySecretManager { async fn get(&self, key: &str) -> SecretResult> { Ok(self.secrets.get(key).map(|v| v.value().clone())) @@ -114,7 +116,8 @@ impl Default for EnvSecretManager { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl SecretManager for EnvSecretManager { async fn get(&self, key: &str) -> SecretResult> { Ok(std::env::var(key).ok()) diff --git a/src/secrets/vault.rs b/src/secrets/vault.rs index 6fa47c5..89384c0 100644 --- a/src/secrets/vault.rs +++ b/src/secrets/vault.rs @@ -206,7 +206,8 @@ impl VaultSecretManager { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl SecretManager for VaultSecretManager { async fn get(&self, key: &str) -> SecretResult> { let path = self.full_path(key); diff --git a/src/services/file_search_tool.rs b/src/services/file_search_tool.rs index 2f52ab7..914b600 100644 --- a/src/services/file_search_tool.rs +++ b/src/services/file_search_tool.rs @@ -431,6 +431,7 @@ impl FileSearchAuthContext { /// /// This callback is used to send continuation requests to the provider /// after executing file_search tool calls. 
+#[cfg(not(target_arch = "wasm32"))] pub type ProviderCallback = Arc< dyn Fn( CreateResponsesPayload, @@ -439,6 +440,13 @@ pub type ProviderCallback = Arc< + Sync, >; +#[cfg(target_arch = "wasm32")] +pub type ProviderCallback = Arc< + dyn Fn( + CreateResponsesPayload, + ) -> Pin, ProviderError>>>>, +>; + /// Errors that can occur during file search middleware processing. #[derive(Debug, Error)] #[allow(dead_code)] // Variants will be used as implementation grows @@ -1536,7 +1544,7 @@ pub fn wrap_streaming_with_file_search( let (tx, rx) = mpsc::channel::>(32); // Spawn a task to process the stream, propagating the span context - tokio::spawn( + crate::compat::spawn_detached( async move { let mut iteration = 0; let mut current_body = body; diff --git a/src/services/file_storage.rs b/src/services/file_storage.rs index e721c74..55d2736 100644 --- a/src/services/file_storage.rs +++ b/src/services/file_storage.rs @@ -51,7 +51,8 @@ pub type FileStorageResult = Result; /// Trait for pluggable file storage backends. /// /// Implementations must be `Send + Sync` to support async contexts. -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait FileStorage: Send + Sync { /// Store file content and return the storage path/key. /// @@ -89,7 +90,8 @@ impl DatabaseFileStorage { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl FileStorage for DatabaseFileStorage { #[instrument(skip(self, content), fields(size = content.len()))] async fn store(&self, file_id: &str, content: &[u8]) -> FileStorageResult> { @@ -151,10 +153,12 @@ impl FileStorage for DatabaseFileStorage { /// /// Stores file content on the local filesystem. /// Files are stored as `{base_path}/{file_id}`. 
+#[cfg(feature = "server")] pub struct FilesystemFileStorage { config: FilesystemStorageConfig, } +#[cfg(feature = "server")] impl FilesystemFileStorage { pub fn new(config: FilesystemStorageConfig) -> FileStorageResult { let storage = Self { config }; @@ -186,7 +190,9 @@ impl FilesystemFileStorage { } } -#[async_trait] +#[cfg(feature = "server")] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl FileStorage for FilesystemFileStorage { #[instrument(skip(self, content), fields(size = content.len()))] async fn store(&self, file_id: &str, content: &[u8]) -> FileStorageResult> { @@ -348,7 +354,8 @@ impl S3FileStorage { } #[cfg(feature = "s3-storage")] -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl FileStorage for S3FileStorage { #[instrument(skip(self, content), fields(size = content.len(), bucket = %self.config.bucket))] async fn store(&self, file_id: &str, content: &[u8]) -> FileStorageResult> { @@ -497,13 +504,22 @@ pub async fn create_file_storage( Ok(Arc::new(DatabaseFileStorage::new(db))) } FileStorageBackend::Filesystem => { - let fs_config = config.filesystem.clone().ok_or_else(|| { - FileStorageError::Config( - "Filesystem backend requires [storage.files.filesystem] config".to_string(), - ) - })?; - info!(path = %fs_config.path, "Using filesystem file storage backend"); - Ok(Arc::new(FilesystemFileStorage::new(fs_config)?)) + #[cfg(feature = "server")] + { + let fs_config = config.filesystem.clone().ok_or_else(|| { + FileStorageError::Config( + "Filesystem backend requires [storage.files.filesystem] config".to_string(), + ) + })?; + info!(path = %fs_config.path, "Using filesystem file storage backend"); + Ok(Arc::new(FilesystemFileStorage::new(fs_config)?)) + } + #[cfg(not(feature = "server"))] + { + Err(FileStorageError::Config( + "Filesystem backend requires the 'server' feature.".to_string(), + )) + } } 
#[cfg(feature = "s3-storage")] FileStorageBackend::S3 => { diff --git a/src/services/mod.rs b/src/services/mod.rs index 20b0f17..fb1131f 100644 --- a/src/services/mod.rs +++ b/src/services/mod.rs @@ -67,11 +67,12 @@ pub use file_search_tool::{ FileSearchAuthContext, FileSearchContext, FileSearchToolArguments, ProviderCallback, preprocess_file_search_tools, wrap_streaming_with_file_search, }; +#[cfg(feature = "server")] +pub use file_storage::FilesystemFileStorage; #[cfg(feature = "s3-storage")] pub use file_storage::S3FileStorage; pub use file_storage::{ - DatabaseFileStorage, FileStorage, FileStorageError, FileStorageResult, FilesystemFileStorage, - create_file_storage, + DatabaseFileStorage, FileStorage, FileStorageError, FileStorageResult, create_file_storage, }; pub use files::{FilesService, FilesServiceError, FilesServiceResult}; pub use model_pricing::ModelPricingService; @@ -86,8 +87,8 @@ pub use provider_metrics::{ StatsGranularity, TimeBucketStats, }; pub use providers::{ - DynamicProviderError, DynamicProviderService, validate_provider_config, - validate_provider_config_with_url, validate_provider_type, + DynamicProviderError, DynamicProviderService, validate_provider_config_with_url, + validate_provider_type, }; pub use reranker::{ LlmReranker, NoOpReranker, RankedResult, RerankError, RerankRequest, RerankResponse, diff --git a/src/services/providers.rs b/src/services/providers.rs index d1dde04..9b0c0ad 100644 --- a/src/services/providers.rs +++ b/src/services/providers.rs @@ -54,27 +54,10 @@ const FORBIDDEN_AWS_CREDENTIAL_TYPES: &[&str] = &["default", "profile", "assume_ #[cfg(feature = "provider-vertex")] const FORBIDDEN_GCP_CREDENTIAL_TYPES: &[&str] = &["default", "service_account"]; -/// Validate provider-specific configuration. 
-/// -/// Different provider types require different config fields: -/// - Bedrock: requires `config.region` and `config.credentials.type` = "static" -/// - Vertex: requires either `api_key` OR (`config.project` + `config.region` -/// with `config.credentials.type` = "service_account_json") -/// - Other types: no config validation needed -/// -/// Dynamic providers must not use credential types that source from the server's -/// environment or filesystem (e.g., "default", "profile", "assume_role" for AWS; -/// "default", "service_account" for GCP). Only explicitly-provided credentials are -/// allowed to prevent users from accessing the gateway's own cloud identity. -pub fn validate_provider_config( - provider_type: &str, - config: Option<&serde_json::Value>, - api_key: Option<&str>, -) -> Result<(), AdminError> { - validate_provider_config_inner(provider_type, config, api_key, false) -} - /// Validate provider-specific configuration with SSRF protection. +/// +/// On wasm32 SSRF validation is skipped — the browser enforces its own CORS/security, +/// and `std::net::ToSocketAddrs` (DNS resolution) is not available. pub fn validate_provider_config_with_url( provider_type: &str, base_url: &str, @@ -82,7 +65,9 @@ pub fn validate_provider_config_with_url( api_key: Option<&str>, allow_loopback: bool, ) -> Result<(), AdminError> { - // Validate base URL against SSRF if non-empty + // Validate base URL against SSRF if non-empty. + // Skip on wasm32: browser enforces CORS and DNS resolution is unavailable. 
+ #[cfg(not(target_arch = "wasm32"))] if !base_url.is_empty() { crate::validation::validate_base_url(base_url, allow_loopback) .map_err(|e| AdminError::Validation(format!("Invalid base URL: {e}")))?; @@ -469,7 +454,7 @@ impl DynamicProviderService { state: &AppState, secrets: Option<&Arc>, ) -> ConnectivityTestResponse { - let start = std::time::Instant::now(); + let start = now_ms(); let config_result = crate::routing::resolver::dynamic_provider_to_config(provider, secrets).await; @@ -486,7 +471,7 @@ impl DynamicProviderService { return ConnectivityTestResponse { status: "error".to_string(), message: "Failed to resolve provider configuration".to_string(), - latency_ms: Some(start.elapsed().as_millis() as u64), + latency_ms: Some(elapsed_ms(start)), }; } }; @@ -499,7 +484,7 @@ impl DynamicProviderService { ) .await; - let latency_ms = start.elapsed().as_millis() as u64; + let latency_ms = elapsed_ms(start); match result { Ok(models) => ConnectivityTestResponse { @@ -526,3 +511,22 @@ impl DynamicProviderService { } } } + +/// Get current time in milliseconds (cross-platform: works on both native and wasm32). +#[cfg(not(target_arch = "wasm32"))] +fn now_ms() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as u64 +} + +#[cfg(target_arch = "wasm32")] +fn now_ms() -> u64 { + js_sys::Date::now() as u64 +} + +/// Compute elapsed milliseconds since a `now_ms()` timestamp. +fn elapsed_ms(start: u64) -> u64 { + now_ms().saturating_sub(start) +} diff --git a/src/services/reranker.rs b/src/services/reranker.rs index 8f11629..0947605 100644 --- a/src/services/reranker.rs +++ b/src/services/reranker.rs @@ -281,7 +281,8 @@ impl RerankResponse { /// } /// } /// ``` -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait Reranker: Send + Sync { /// Re-rank the given search results based on relevance to the query. 
/// @@ -319,7 +320,8 @@ impl fmt::Debug for dyn Reranker { /// Useful as a fallback or for testing. pub struct NoOpReranker; -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl Reranker for NoOpReranker { async fn rerank(&self, request: RerankRequest) -> Result { if request.results.is_empty() { @@ -665,7 +667,8 @@ struct LlmScore { score: f64, } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl Reranker for LlmReranker { #[instrument(skip(self, request), fields( query_len = request.query.len(), diff --git a/src/services/virus_scan.rs b/src/services/virus_scan.rs index b387546..10062d1 100644 --- a/src/services/virus_scan.rs +++ b/src/services/virus_scan.rs @@ -101,7 +101,8 @@ impl ScanResult { /// Trait for virus scanning backends. /// /// Implementations must be `Send + Sync` to support async contexts. -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait VirusScanner: Send + Sync { /// Scan file content for viruses/malware. /// @@ -170,7 +171,8 @@ impl ClamAvScanner { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl VirusScanner for ClamAvScanner { #[instrument(skip(self, content), fields(size = content.len()))] async fn scan(&self, content: &[u8]) -> VirusScanResult { @@ -256,7 +258,8 @@ impl VirusScanner for ClamAvScanner { /// Used when virus scanning is disabled. 
pub struct NoOpScanner; -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl VirusScanner for NoOpScanner { async fn scan(&self, content: &[u8]) -> VirusScanResult { Ok(ScanResult::clean(content.len(), 0)) diff --git a/src/streaming/mod.rs b/src/streaming/mod.rs index 973217f..2af4fb2 100644 --- a/src/streaming/mod.rs +++ b/src/streaming/mod.rs @@ -13,6 +13,7 @@ use bytes::Bytes; use futures_util::stream::Stream; use serde_json::Value; use tokio::time::Sleep; +#[cfg(feature = "server")] use tokio_util::task::TaskTracker; use crate::{db::DbPool, models::UsageLogEntry, observability::metrics, pricing::PricingConfig}; @@ -335,7 +336,7 @@ pub struct TokenAccumulator { /// Uses NONE_SENTINEL for None. provider_cost_nanodollars: AtomicI64, /// How the generation ended (stop, length, etc.) - finish_reason: parking_lot::Mutex>, + finish_reason: crate::compat::Mutex>, } impl Default for TokenAccumulator { @@ -348,7 +349,7 @@ impl Default for TokenAccumulator { cached_tokens: AtomicI64::new(NONE_SENTINEL), reasoning_tokens: AtomicI64::new(NONE_SENTINEL), provider_cost_nanodollars: AtomicI64::new(NONE_SENTINEL), - finish_reason: parking_lot::Mutex::new(None), + finish_reason: crate::compat::Mutex::new(None), } } } @@ -456,6 +457,7 @@ pub struct UsageTrackingStream { accumulated_tokens: Arc, usage_logger: Arc, stream_ended: bool, + #[cfg(feature = "server")] task_tracker: TaskTracker, /// Streaming metrics tracking streaming_metrics: Arc, @@ -572,6 +574,7 @@ pub struct UsageLogger { usage_entry: UsageLogEntry, provider: String, model: String, + #[cfg(feature = "server")] task_tracker: TaskTracker, } @@ -582,7 +585,7 @@ impl UsageLogger { usage_entry: UsageLogEntry, provider: String, model: String, - task_tracker: TaskTracker, + #[cfg(feature = "server")] task_tracker: TaskTracker, ) -> Self { Self { db, @@ -590,6 +593,7 @@ impl UsageLogger { usage_entry, provider, model, + #[cfg(feature = 
"server")] task_tracker, } } @@ -685,6 +689,7 @@ impl UsageLogger { // Log to database with retry logic, using task_tracker to ensure completion on shutdown let db = self.db.clone(); + #[cfg(feature = "server")] self.task_tracker.spawn(async move { for attempt in 0..3 { match db.usage().log(entry.clone()).await { @@ -732,7 +737,7 @@ where usage_entry: UsageLogEntry, provider: String, model: String, - task_tracker: TaskTracker, + #[cfg(feature = "server")] task_tracker: TaskTracker, ) -> Self { let logger = Arc::new(UsageLogger::new( db, @@ -740,6 +745,7 @@ where usage_entry, provider.clone(), model.clone(), + #[cfg(feature = "server")] task_tracker.clone(), )); @@ -748,6 +754,7 @@ where accumulated_tokens: Arc::new(TokenAccumulator::default()), usage_logger: logger, stream_ended: false, + #[cfg(feature = "server")] task_tracker: task_tracker.clone(), streaming_metrics: Arc::new(StreamingMetrics::new(provider, model)), } @@ -817,6 +824,7 @@ where streaming_metrics.report("completed"); // Use task_tracker to ensure usage logging completes during graceful shutdown + #[cfg(feature = "server")] self.task_tracker.spawn(async move { logger.log_usage(&tokens).await; }); @@ -836,6 +844,7 @@ where streaming_metrics.report("error"); // Use task_tracker to ensure usage logging completes during graceful shutdown + #[cfg(feature = "server")] self.task_tracker.spawn(async move { tracing::warn!("Stream ended with error, logging partial usage"); logger.log_usage(&tokens).await; @@ -872,6 +881,7 @@ impl Drop for UsageTrackingStream { // Spawn async task to log usage // Note: We can't await here since Drop is sync, so we spawn a task. // The task_tracker ensures this completes during graceful shutdown. 
+ #[cfg(feature = "server")] self.task_tracker.spawn(async move { tracing::warn!( "Stream dropped without completing - logging partial usage for budget accuracy" diff --git a/src/tests/provider_e2e.rs b/src/tests/provider_e2e.rs index 3bdaa58..11fc1c8 100644 --- a/src/tests/provider_e2e.rs +++ b/src/tests/provider_e2e.rs @@ -879,7 +879,7 @@ default_provider = "mock-provider" db_id, provider_config, spec.extra_config ); - let config = GatewayConfig::from_str(&config_str).expect("Failed to parse test config"); + let config = GatewayConfig::parse(&config_str).expect("Failed to parse test config"); let state = crate::AppState::new(config.clone()) .await .expect("Failed to create AppState"); @@ -2487,7 +2487,7 @@ retryable_status_codes = [429, 500, 502, 503, 504] initial_delay_ms ); - let config = GatewayConfig::from_str(&config_str).expect("Failed to parse test config"); + let config = GatewayConfig::parse(&config_str).expect("Failed to parse test config"); let state = crate::AppState::new(config.clone()) .await .expect("Failed to create AppState"); diff --git a/src/usage_buffer.rs b/src/usage_buffer.rs index fd9b1d8..650b898 100644 --- a/src/usage_buffer.rs +++ b/src/usage_buffer.rs @@ -15,18 +15,13 @@ //! At high request rates, batching reduces write pressure: //! - 100K requests/sec → ~100 batch operations/sec (1000x reduction) //! - Lock-free MPSC channel eliminates mutex contention on push() +//! +//! ## Availability +//! `UsageLogBuffer` requires the `concurrency` feature (crossbeam channels). +//! Without it (e.g. WASM builds), only `UsageBufferConfig` is available and +//! usage entries are written directly to the database without buffering. -use std::{sync::Arc, time::Duration}; - -use chrono::Utc; -use crossbeam_channel::{Receiver, Sender, TrySendError}; -use uuid::Uuid; - -use crate::{ - events::{EventBus, ServerEvent}, - models::UsageLogEntry, - usage_sink::UsageSink, -}; +use std::time::Duration; /// Configuration for the usage log buffer. 
#[derive(Debug, Clone)] @@ -63,376 +58,405 @@ impl From<&crate::config::UsageBufferConfig> for UsageBufferConfig { } } -/// Async buffer for usage log entries. -/// -/// Entries are collected and flushed to configured sinks in batches. -/// The buffer flushes when: -/// - The buffer reaches `max_size` entries -/// - The `flush_interval` timer expires -/// - `shutdown()` is called (during graceful shutdown) -/// -/// Uses a lock-free MPSC channel for push operations, eliminating mutex -/// contention under high load. If the channel is full (exceeds `max_pending_entries`), -/// new entries are dropped to prevent OOM. -pub struct UsageLogBuffer { - /// Lock-free sender for push operations. - sender: Sender, - /// Receiver for the background worker (only used by start_worker). - receiver: Receiver, - config: UsageBufferConfig, - /// Flag to signal shutdown. - shutdown: Arc, - /// Optional event bus for publishing usage events. - event_bus: Option>, - /// Count of entries dropped due to buffer overflow. - dropped_count: std::sync::atomic::AtomicU64, -} - -impl UsageLogBuffer { - /// Create a new usage log buffer with the given configuration. 
- pub fn new(config: UsageBufferConfig) -> Self { - // Use max_pending_entries as channel capacity, or a reasonable default if 0 - let capacity = if config.max_pending_entries > 0 { - config.max_pending_entries - } else { - // Unbounded is risky; use a large but bounded capacity - 1_000_000 - }; - let (sender, receiver) = crossbeam_channel::bounded(capacity); +// ───────────────────────────────────────────────────────────────────────────── +// UsageLogBuffer — requires `concurrency` feature (crossbeam channels) +// ───────────────────────────────────────────────────────────────────────────── - Self { - sender, - receiver, - config, - shutdown: Arc::new(std::sync::atomic::AtomicBool::new(false)), - event_bus: None, - dropped_count: std::sync::atomic::AtomicU64::new(0), - } - } +#[cfg(feature = "concurrency")] +mod buffer { + use std::sync::Arc; - /// Create a new usage log buffer with EventBus for real-time notifications. - pub fn with_event_bus(config: UsageBufferConfig, event_bus: Arc) -> Self { - let capacity = if config.max_pending_entries > 0 { - config.max_pending_entries - } else { - 1_000_000 - }; - let (sender, receiver) = crossbeam_channel::bounded(capacity); + use chrono::Utc; + use crossbeam_channel::{Receiver, Sender, TrySendError}; + use uuid::Uuid; - Self { - sender, - receiver, - config, - shutdown: Arc::new(std::sync::atomic::AtomicBool::new(false)), - event_bus: Some(event_bus), - dropped_count: std::sync::atomic::AtomicU64::new(0), - } - } + use super::UsageBufferConfig; + use crate::{ + events::{EventBus, ServerEvent}, + models::UsageLogEntry, + usage_sink::UsageSink, + }; - /// Add a usage entry to the buffer. + /// Async buffer for usage log entries. /// - /// This is a **lock-free** operation using a crossbeam channel. - /// Multiple threads can call this concurrently without contention. + /// Entries are collected and flushed to configured sinks in batches. 
+ /// The buffer flushes when: + /// - The buffer reaches `max_size` entries + /// - The `flush_interval` timer expires + /// - `shutdown()` is called (during graceful shutdown) /// - /// If the channel has exceeded `max_pending_entries`, the entry is dropped. - pub fn push(&self, entry: UsageLogEntry) { - match self.sender.try_send(entry) { - Ok(()) => {} - Err(TrySendError::Full(_)) => { - #[cfg(feature = "prometheus")] - metrics::counter!("hadrian_usage_buffer_entries_dropped_total").increment(1); - let count = self - .dropped_count - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - // Log periodically to avoid log spam (every 100 drops) - if count.is_multiple_of(100) { - tracing::warn!( - dropped_count = count + 1, - max_pending = self.config.max_pending_entries, - "Usage buffer overflow: dropping entries (sink may be slow/unavailable)" - ); - } - } - Err(TrySendError::Disconnected(_)) => { - // Channel closed - worker has shut down, silently drop + /// Uses a lock-free MPSC channel for push operations, eliminating mutex + /// contention under high load. If the channel is full (exceeds `max_pending_entries`), + /// new entries are dropped to prevent OOM. + pub struct UsageLogBuffer { + /// Lock-free sender for push operations. + sender: Sender, + /// Receiver for the background worker (only used by start_worker). + receiver: Receiver, + config: UsageBufferConfig, + /// Flag to signal shutdown. + shutdown: Arc, + /// Optional event bus for publishing usage events. + event_bus: Option>, + /// Count of entries dropped due to buffer overflow. + dropped_count: std::sync::atomic::AtomicU64, + } + + impl UsageLogBuffer { + /// Create a new usage log buffer with the given configuration. 
+ pub fn new(config: UsageBufferConfig) -> Self { + // Use max_pending_entries as channel capacity, or a reasonable default if 0 + let capacity = if config.max_pending_entries > 0 { + config.max_pending_entries + } else { + // Unbounded is risky; use a large but bounded capacity + 1_000_000 + }; + let (sender, receiver) = crossbeam_channel::bounded(capacity); + + Self { + sender, + receiver, + config, + shutdown: Arc::new(std::sync::atomic::AtomicBool::new(false)), + event_bus: None, + dropped_count: std::sync::atomic::AtomicU64::new(0), } } - } - /// Get the count of entries dropped due to buffer overflow. - pub fn dropped_count(&self) -> u64 { - self.dropped_count - .load(std::sync::atomic::Ordering::Relaxed) - } + /// Create a new usage log buffer with EventBus for real-time notifications. + pub fn with_event_bus(config: UsageBufferConfig, event_bus: Arc) -> Self { + let capacity = if config.max_pending_entries > 0 { + config.max_pending_entries + } else { + 1_000_000 + }; + let (sender, receiver) = crossbeam_channel::bounded(capacity); + + Self { + sender, + receiver, + config, + shutdown: Arc::new(std::sync::atomic::AtomicBool::new(false)), + event_bus: Some(event_bus), + dropped_count: std::sync::atomic::AtomicU64::new(0), + } + } - /// Start the background flush worker. - /// - /// This spawns a task that periodically flushes the buffer to all - /// configured sinks. The worker will run until `shutdown()` is called. - pub fn start_worker(self: &Arc, sink: Arc) -> tokio::task::JoinHandle<()> { - let buffer = Arc::clone(self); - let flush_interval = self.config.flush_interval; - let max_batch_size = self.config.max_size; - - tokio::spawn(async move { - let mut batch = Vec::with_capacity(max_batch_size); - - loop { - // Drain available entries up to batch size - buffer.drain_entries(&mut batch, max_batch_size); - - // If we have entries, flush them - if !batch.is_empty() { - buffer.flush_batch(&sink, &mut batch).await; + /// Add a usage entry to the buffer. 
+ /// + /// This is a **lock-free** operation using a crossbeam channel. + /// Multiple threads can call this concurrently without contention. + /// + /// If the channel has exceeded `max_pending_entries`, the entry is dropped. + pub fn push(&self, entry: UsageLogEntry) { + match self.sender.try_send(entry) { + Ok(()) => {} + Err(TrySendError::Full(_)) => { + #[cfg(feature = "prometheus")] + metrics::counter!("hadrian_usage_buffer_entries_dropped_total").increment(1); + let count = self + .dropped_count + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + // Log periodically to avoid log spam (every 100 drops) + if count.is_multiple_of(100) { + tracing::warn!( + dropped_count = count + 1, + max_pending = self.config.max_pending_entries, + "Usage buffer overflow: dropping entries (sink may be slow/unavailable)" + ); + } + } + Err(TrySendError::Disconnected(_)) => { + // Channel closed - worker has shut down, silently drop } + } + } - // Check for shutdown - if buffer.shutdown.load(std::sync::atomic::Ordering::Acquire) { - // Final drain and flush before exiting - buffer.drain_all(&mut batch); + /// Get the count of entries dropped due to buffer overflow. + pub fn dropped_count(&self) -> u64 { + self.dropped_count + .load(std::sync::atomic::Ordering::Relaxed) + } + + /// Start the background flush worker. + /// + /// This spawns a task that periodically flushes the buffer to all + /// configured sinks. The worker will run until `shutdown()` is called. 
+ pub fn start_worker( + self: &Arc, + sink: Arc, + ) -> tokio::task::JoinHandle<()> { + let buffer = Arc::clone(self); + let flush_interval = self.config.flush_interval; + let max_batch_size = self.config.max_size; + + tokio::spawn(async move { + let mut batch = Vec::with_capacity(max_batch_size); + + loop { + // Drain available entries up to batch size + buffer.drain_entries(&mut batch, max_batch_size); + + // If we have entries, flush them if !batch.is_empty() { buffer.flush_batch(&sink, &mut batch).await; } - tracing::info!("Usage log buffer worker shutting down"); - break; - } - // Wait for flush interval or shutdown - tokio::time::sleep(flush_interval).await; - } - }) - } + // Check for shutdown + if buffer.shutdown.load(std::sync::atomic::Ordering::Acquire) { + // Final drain and flush before exiting + buffer.drain_all(&mut batch); + if !batch.is_empty() { + buffer.flush_batch(&sink, &mut batch).await; + } + tracing::info!("Usage log buffer worker shutting down"); + break; + } - /// Drain entries from the channel into the batch vector. - fn drain_entries(&self, batch: &mut Vec, max_size: usize) { - while batch.len() < max_size { - match self.receiver.try_recv() { - Ok(entry) => batch.push(entry), - Err(crossbeam_channel::TryRecvError::Empty) => break, - Err(crossbeam_channel::TryRecvError::Disconnected) => break, - } + // Wait for flush interval or shutdown + tokio::time::sleep(flush_interval).await; + } + }) } - } - /// Drain all remaining entries from the channel. - fn drain_all(&self, batch: &mut Vec) { - while let Ok(entry) = self.receiver.try_recv() { - batch.push(entry); + /// Drain entries from the channel into the batch vector. 
+ fn drain_entries(&self, batch: &mut Vec, max_size: usize) { + while batch.len() < max_size { + match self.receiver.try_recv() { + Ok(entry) => batch.push(entry), + Err(crossbeam_channel::TryRecvError::Empty) => break, + Err(crossbeam_channel::TryRecvError::Disconnected) => break, + } + } } - } - /// Signal the worker to shut down. - pub fn shutdown(&self) { - self.shutdown - .store(true, std::sync::atomic::Ordering::Release); - } - - /// Flush a batch of entries to the sink. - async fn flush_batch(&self, sink: &Arc, batch: &mut Vec) { - let entry_count = batch.len(); - tracing::debug!(count = entry_count, "Flushing usage log buffer"); - - // Publish usage events to WebSocket subscribers before writing to sink - if let Some(event_bus) = &self.event_bus { - for entry in batch.iter() { - event_bus.publish(ServerEvent::UsageRecorded { - request_id: Uuid::parse_str(&entry.request_id).unwrap_or_else(|_| Uuid::nil()), - timestamp: Utc::now(), - model: entry.model.clone(), - provider: entry.provider.clone(), - input_tokens: entry.input_tokens, - output_tokens: entry.output_tokens, - cost_microcents: entry.cost_microcents, - user_id: entry.user_id, - org_id: entry.org_id, - project_id: entry.project_id, - team_id: entry.team_id, - service_account_id: entry.service_account_id, - }); + /// Drain all remaining entries from the channel. + fn drain_all(&self, batch: &mut Vec) { + while let Ok(entry) = self.receiver.try_recv() { + batch.push(entry); } } - // Write to sink(s) - match sink.write_batch(batch).await { - Ok(written) => { - tracing::debug!( - written = written, - total = entry_count, - "Usage log flush successful" - ); - } - Err(e) => { - tracing::error!( - error = %e, - count = entry_count, - "Usage log flush failed" - ); - } + /// Signal the worker to shut down. + pub fn shutdown(&self) { + self.shutdown + .store(true, std::sync::atomic::Ordering::Release); } - batch.clear(); - } + /// Flush a batch of entries to the sink. 
+ async fn flush_batch(&self, sink: &Arc, batch: &mut Vec) { + let entry_count = batch.len(); + tracing::debug!(count = entry_count, "Flushing usage log buffer"); + + // Publish usage events to WebSocket subscribers before writing to sink + if let Some(event_bus) = &self.event_bus { + for entry in batch.iter() { + event_bus.publish(ServerEvent::UsageRecorded { + request_id: Uuid::parse_str(&entry.request_id) + .unwrap_or_else(|_| Uuid::nil()), + timestamp: Utc::now(), + model: entry.model.clone(), + provider: entry.provider.clone(), + input_tokens: entry.input_tokens, + output_tokens: entry.output_tokens, + cost_microcents: entry.cost_microcents, + user_id: entry.user_id, + org_id: entry.org_id, + project_id: entry.project_id, + team_id: entry.team_id, + service_account_id: entry.service_account_id, + }); + } + } - /// Get the current number of buffered entries. - #[allow(dead_code)] // Used in tests; public API for buffer introspection - pub fn len(&self) -> usize { - self.receiver.len() - } + // Write to sink(s) + match sink.write_batch(batch).await { + Ok(written) => { + tracing::debug!( + written = written, + total = entry_count, + "Usage log flush successful" + ); + } + Err(e) => { + tracing::error!( + error = %e, + count = entry_count, + "Usage log flush failed" + ); + } + } - /// Check if the buffer is empty. - #[allow(dead_code)] // Used in tests; public API for buffer introspection - pub fn is_empty(&self) -> bool { - self.receiver.is_empty() - } -} + batch.clear(); + } -#[cfg(test)] -mod tests { - use chrono::Utc; - use uuid::Uuid; + /// Get the current number of buffered entries. 
+ #[allow(dead_code)] // Used in tests; public API for buffer introspection + pub fn len(&self) -> usize { + self.receiver.len() + } - use super::*; - - fn make_test_entry() -> UsageLogEntry { - UsageLogEntry { - request_id: Uuid::new_v4().to_string(), - api_key_id: Some(Uuid::new_v4()), - user_id: None, - org_id: None, - project_id: None, - team_id: None, - service_account_id: None, - model: "test-model".to_string(), - provider: "test-provider".to_string(), - input_tokens: 100, - output_tokens: 50, - cost_microcents: Some(1000), - http_referer: None, - request_at: Utc::now(), - streamed: false, - cached_tokens: 0, - reasoning_tokens: 0, - finish_reason: Some("stop".to_string()), - latency_ms: Some(100), - cancelled: false, - status_code: Some(200), - pricing_source: crate::pricing::CostPricingSource::None, - image_count: None, - audio_seconds: None, - character_count: None, - provider_source: None, + /// Check if the buffer is empty. + #[allow(dead_code)] // Used in tests; public API for buffer introspection + pub fn is_empty(&self) -> bool { + self.receiver.is_empty() } } - #[test] - fn test_buffer_push_and_len() { - let buffer = UsageLogBuffer::new(UsageBufferConfig::default()); - - assert!(buffer.is_empty()); - assert_eq!(buffer.len(), 0); + #[cfg(test)] + mod tests { + use std::time::Duration; + + use chrono::Utc; + use uuid::Uuid; + + use super::*; + + fn make_test_entry() -> UsageLogEntry { + UsageLogEntry { + request_id: Uuid::new_v4().to_string(), + api_key_id: Some(Uuid::new_v4()), + user_id: None, + org_id: None, + project_id: None, + team_id: None, + service_account_id: None, + model: "test-model".to_string(), + provider: "test-provider".to_string(), + input_tokens: 100, + output_tokens: 50, + cost_microcents: Some(1000), + http_referer: None, + request_at: Utc::now(), + streamed: false, + cached_tokens: 0, + reasoning_tokens: 0, + finish_reason: Some("stop".to_string()), + latency_ms: Some(100), + cancelled: false, + status_code: Some(200), + 
pricing_source: crate::pricing::CostPricingSource::None, + image_count: None, + audio_seconds: None, + character_count: None, + provider_source: None, + } + } - buffer.push(make_test_entry()); - assert_eq!(buffer.len(), 1); + #[test] + fn test_buffer_push_and_len() { + let buffer = UsageLogBuffer::new(UsageBufferConfig::default()); - buffer.push(make_test_entry()); - assert_eq!(buffer.len(), 2); - } + assert!(buffer.is_empty()); + assert_eq!(buffer.len(), 0); - #[test] - fn test_buffer_config_defaults() { - let config = UsageBufferConfig::default(); - assert_eq!(config.max_size, 1000); - assert_eq!(config.flush_interval, Duration::from_secs(1)); - assert_eq!(config.max_pending_entries, 10_000); - } + buffer.push(make_test_entry()); + assert_eq!(buffer.len(), 1); - #[test] - fn test_buffer_with_custom_config() { - let config = UsageBufferConfig { - max_size: 100, - flush_interval: Duration::from_millis(500), - max_pending_entries: 1000, - }; - let buffer = UsageLogBuffer::new(config); - - // Push entries up to the limit - for _ in 0..99 { buffer.push(make_test_entry()); + assert_eq!(buffer.len(), 2); } - assert_eq!(buffer.len(), 99); - } - #[test] - fn test_buffer_overflow_drops_new_entries() { - let config = UsageBufferConfig { - max_size: 10, - flush_interval: Duration::from_secs(60), // Long interval so no auto-flush - max_pending_entries: 5, // Small limit for testing - }; - let buffer = UsageLogBuffer::new(config); - - // Push 5 entries (reaches channel capacity) - for _ in 0..5 { - buffer.push(make_test_entry()); + #[test] + fn test_buffer_config_defaults() { + let config = UsageBufferConfig::default(); + assert_eq!(config.max_size, 1000); + assert_eq!(config.flush_interval, Duration::from_secs(1)); + assert_eq!(config.max_pending_entries, 10_000); } - assert_eq!(buffer.len(), 5); - assert_eq!(buffer.dropped_count(), 0); - // Push one more - should be dropped (channel full) - buffer.push(make_test_entry()); - assert_eq!(buffer.len(), 5); // Still 5 (new entry 
dropped) - assert_eq!(buffer.dropped_count(), 1); + #[test] + fn test_buffer_with_custom_config() { + let config = UsageBufferConfig { + max_size: 100, + flush_interval: Duration::from_millis(500), + max_pending_entries: 1000, + }; + let buffer = UsageLogBuffer::new(config); + + // Push entries up to the limit + for _ in 0..99 { + buffer.push(make_test_entry()); + } + assert_eq!(buffer.len(), 99); + } + + #[test] + fn test_buffer_overflow_drops_new_entries() { + let config = UsageBufferConfig { + max_size: 10, + flush_interval: Duration::from_secs(60), // Long interval so no auto-flush + max_pending_entries: 5, // Small limit for testing + }; + let buffer = UsageLogBuffer::new(config); + + // Push 5 entries (reaches channel capacity) + for _ in 0..5 { + buffer.push(make_test_entry()); + } + assert_eq!(buffer.len(), 5); + assert_eq!(buffer.dropped_count(), 0); - // Push 3 more - all should be dropped - for _ in 0..3 { + // Push one more - should be dropped (channel full) buffer.push(make_test_entry()); + assert_eq!(buffer.len(), 5); // Still 5 (new entry dropped) + assert_eq!(buffer.dropped_count(), 1); + + // Push 3 more - all should be dropped + for _ in 0..3 { + buffer.push(make_test_entry()); + } + assert_eq!(buffer.len(), 5); // Still capped at 5 + assert_eq!(buffer.dropped_count(), 4); } - assert_eq!(buffer.len(), 5); // Still capped at 5 - assert_eq!(buffer.dropped_count(), 4); - } - #[test] - fn test_buffer_large_capacity_when_zero() { - let config = UsageBufferConfig { - max_size: 100, - flush_interval: Duration::from_secs(60), - max_pending_entries: 0, // Uses large default capacity - }; - let buffer = UsageLogBuffer::new(config); - - // Push many entries - should not drop any (large capacity) - for _ in 0..200 { - buffer.push(make_test_entry()); + #[test] + fn test_buffer_large_capacity_when_zero() { + let config = UsageBufferConfig { + max_size: 100, + flush_interval: Duration::from_secs(60), + max_pending_entries: 0, // Uses large default capacity + }; 
+ let buffer = UsageLogBuffer::new(config); + + // Push many entries - should not drop any (large capacity) + for _ in 0..200 { + buffer.push(make_test_entry()); + } + assert_eq!(buffer.len(), 200); + assert_eq!(buffer.dropped_count(), 0); } - assert_eq!(buffer.len(), 200); - assert_eq!(buffer.dropped_count(), 0); - } - #[test] - fn test_drain_entries() { - let config = UsageBufferConfig { - max_size: 10, - flush_interval: Duration::from_secs(60), - max_pending_entries: 100, - }; - let buffer = UsageLogBuffer::new(config); - - // Push 15 entries - for _ in 0..15 { - buffer.push(make_test_entry()); + #[test] + fn test_drain_entries() { + let config = UsageBufferConfig { + max_size: 10, + flush_interval: Duration::from_secs(60), + max_pending_entries: 100, + }; + let buffer = UsageLogBuffer::new(config); + + // Push 15 entries + for _ in 0..15 { + buffer.push(make_test_entry()); + } + assert_eq!(buffer.len(), 15); + + // Drain up to 10 (max_size) + let mut batch = Vec::new(); + buffer.drain_entries(&mut batch, 10); + assert_eq!(batch.len(), 10); + assert_eq!(buffer.len(), 5); // 5 remaining + + // Drain the rest + batch.clear(); + buffer.drain_entries(&mut batch, 10); + assert_eq!(batch.len(), 5); + assert_eq!(buffer.len(), 0); } - assert_eq!(buffer.len(), 15); - - // Drain up to 10 (max_size) - let mut batch = Vec::new(); - buffer.drain_entries(&mut batch, 10); - assert_eq!(batch.len(), 10); - assert_eq!(buffer.len(), 5); // 5 remaining - - // Drain the rest - batch.clear(); - buffer.drain_entries(&mut batch, 10); - assert_eq!(batch.len(), 5); - assert_eq!(buffer.len(), 0); } } + +#[cfg(feature = "concurrency")] +pub use buffer::UsageLogBuffer; diff --git a/src/usage_sink.rs b/src/usage_sink.rs index 3ba5a95..3dbbf46 100644 --- a/src/usage_sink.rs +++ b/src/usage_sink.rs @@ -39,7 +39,8 @@ use crate::{ /// Trait for usage data sinks. /// /// Implementations can write usage data to various backends. 
-#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait UsageSink: Send + Sync { /// Write a batch of usage entries. /// @@ -79,7 +80,8 @@ impl DatabaseSink { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl UsageSink for DatabaseSink { async fn write_batch(&self, entries: &[UsageLogEntry]) -> Result { if entries.is_empty() { @@ -275,7 +277,8 @@ impl OtlpSink { } #[cfg(feature = "otlp")] -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl UsageSink for OtlpSink { async fn write_batch(&self, entries: &[UsageLogEntry]) -> Result { use opentelemetry::{ @@ -487,7 +490,8 @@ impl CompositeSink { } } -#[async_trait] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl UsageSink for CompositeSink { async fn write_batch(&self, entries: &[UsageLogEntry]) -> Result { if entries.is_empty() { diff --git a/src/validation/schema.rs b/src/validation/schema.rs index 7c11537..5caf939 100644 --- a/src/validation/schema.rs +++ b/src/validation/schema.rs @@ -217,7 +217,37 @@ impl OpenApiSchemas { .ok_or_else(|| format!("Schema '{}' not found in OpenAPI spec", name))?; // Resolve $ref references within the schema - self.resolve_refs(raw_schema.clone(), 0) + let mut resolved = self.resolve_refs(raw_schema.clone(), 0)?; + // Strip OpenAPI extension keys (x-*) from the entire tree — they aren't + // valid JSON Schema and can cause compilation failures (e.g. x-stainless-const: true) + Self::strip_openapi_extensions(&mut resolved); + Ok(resolved) + } + + /// Recursively strip non-JSON-Schema keywords from a resolved value tree. + /// Removes OpenAPI extension keys (x-*) and Draft 2019-09 keywords not in Draft 2020-12. 
+ fn strip_openapi_extensions(value: &mut Value) { + match value { + Value::Object(map) => { + map.retain(|k, _| { + !k.starts_with("x-") && k != "$recursiveAnchor" && k != "discriminator" + }); + // Replace $recursiveRef with a permissive schema (circular ref) + if map.contains_key("$recursiveRef") { + map.clear(); + return; + } + for v in map.values_mut() { + Self::strip_openapi_extensions(v); + } + } + Value::Array(arr) => { + for v in arr { + Self::strip_openapi_extensions(v); + } + } + _ => {} + } } /// Recursively resolve $ref references in a schema. diff --git a/src/wasm.rs b/src/wasm.rs new file mode 100644 index 0000000..54a80ed --- /dev/null +++ b/src/wasm.rs @@ -0,0 +1,424 @@ +//! WASM entry point for running Hadrian in the browser. +//! +//! Exports a [`HadrianGateway`] struct that can be instantiated from JavaScript +//! (service worker). Requests are dispatched via an Axum [`Router`] — the same +//! routing engine used by the native server — so path parameters, method matching, +//! and fallback handling all work identically. +//! +//! # Architecture +//! +//! The gateway runs entirely in the browser's service worker thread: +//! - HTTP requests are intercepted by the service worker's `fetch` event handler +//! - Converted from `web_sys::Request` → `http::Request` → Axum Router → service calls +//! - Responses converted back to `web_sys::Response` for the browser +//! - Provider API calls (OpenAI, Anthropic) go through `reqwest` which uses +//! the browser's `fetch()` API on wasm32 +//! - SQLite database via sql.js (in-memory) through JS FFI bridge +//! +//! # Route Handler Reuse +//! +//! Route handlers are shared with the native server via [`admin_v1_routes()`] and +//! [`api_v1_routes()`]. The WASM router injects `Extension`, +//! `Extension`, and `Extension` layers so handlers can +//! extract them identically. Only handlers with genuinely different WASM behavior +//! (health check, auth stub, conversations stub) are defined here. +//! +//! 
# Axum Send compatibility +//! +//! Axum requires handler futures to be `Send`, but on wasm32 `reqwest`/`wasm-bindgen` +//! futures are `!Send`. The [`crate::compat::wasm_routing`] module provides drop-in +//! replacements for `axum::routing::{get, post, ...}` that wrap handlers in +//! [`crate::compat::WasmHandler`], asserting `Send` since wasm32 is single-threaded. + +use std::sync::Arc; + +use axum::{ + Extension, Json, Router, + extract::State, + response::{IntoResponse, Response}, +}; +use wasm_bindgen::prelude::*; + +use crate::{ + auth::Identity, + authz::AuthzEngine, + catalog, + compat::wasm_routing::get, + config, db, events, jobs, + middleware::{AdminAuth, AuthzContext, ClientInfo}, + pricing, providers, services, +}; + +/// Browser-based Hadrian gateway. +/// +/// Instantiated once in the service worker and reused for all requests. +#[wasm_bindgen] +pub struct HadrianGateway { + router: Router, +} + +#[wasm_bindgen] +impl HadrianGateway { + /// Initialize the gateway with sql.js database. Called once from the service worker. + #[wasm_bindgen(constructor)] + pub async fn new() -> Result { + tracing_wasm::set_as_global_default(); + tracing::info!("Initializing Hadrian WASM gateway"); + + let config = wasm_default_config(); + let http_client = reqwest::Client::new(); + + // No secret manager in WASM — API keys are stored directly in SQLite + // (which is persisted to IndexedDB). Using MemorySecretManager would lose + // secrets when the service worker restarts. 
+ + let event_bus = Arc::new(events::EventBus::with_capacity( + config.features.websocket.channel_capacity, + )); + + // Initialize sql.js database via JS bridge + let pool = db::wasm_sqlite::WasmSqlitePool::new(); + pool.init() + .await + .map_err(|e| JsError::new(&format!("DB init failed: {e}")))?; + + tracing::info!("Running database migrations"); + pool.run_migrations() + .await + .map_err(|e| JsError::new(&format!("Migrations failed: {e}")))?; + + let db = Arc::new(db::DbPool::from_wasm_sqlite(pool)); + let file_storage: Arc = + Arc::new(services::DatabaseFileStorage::new(db.clone())); + let svc = services::Services::new(db.clone(), file_storage, 1024); + + // Bootstrap default user and org (auth=none) + let default_user_id = match crate::app::AppState::ensure_default_user(&svc).await { + Ok(id) => { + tracing::info!(user_id = %id, "Default anonymous user available"); + Some(id) + } + Err(e) => { + tracing::warn!(error = %e, "Failed to create default user"); + None + } + }; + + let default_org_id = match crate::app::AppState::ensure_default_org(&svc).await { + Ok(id) => { + tracing::info!(org_id = %id, "Default local organization available"); + Some(id) + } + Err(e) => { + tracing::warn!(error = %e, "Failed to create default organization"); + None + } + }; + + if let (Some(uid), Some(oid)) = (default_user_id, default_org_id) { + if let Err(e) = + crate::app::AppState::ensure_default_org_membership(&svc, uid, oid).await + { + tracing::warn!(error = %e, "Failed to add user to default organization"); + } + } + + let state = crate::app::AppState { + http_client, + config: Arc::new(config.clone()), + db: Some(db), + services: Some(svc), + cache: None, + secrets: None, + dlq: None, + pricing: Arc::new(config.pricing.clone()), + circuit_breakers: providers::CircuitBreakerRegistry::new(), + provider_health: jobs::ProviderHealthStateRegistry::new(), + #[cfg(feature = "sso")] + oidc_registry: None, + #[cfg(feature = "saml")] + saml_registry: None, + #[cfg(feature = 
"jwt")] + gateway_jwt_registry: None, + policy_registry: None, + response_cache: None, + semantic_cache: None, + input_guardrails: None, + output_guardrails: None, + event_bus, + file_search_service: None, + #[cfg(any( + feature = "document-extraction-basic", + feature = "document-extraction-full" + ))] + document_processor: None, + default_user_id, + default_org_id, + provider_metrics: Arc::new(services::ProviderMetricsService::new()), + model_catalog: catalog::ModelCatalogRegistry::new(), + }; + + let router = build_wasm_router(state, default_user_id, default_org_id); + + tracing::info!("Hadrian WASM gateway initialized (with database)"); + Ok(HadrianGateway { router }) + } + + /// Handle a fetch request from the service worker. + /// + /// Converts `web_sys::Request` → Axum Router dispatch → `web_sys::Response`. + pub async fn handle(&self, request: web_sys::Request) -> Result { + let http_request = convert_request(&request).await?; + + let response = tower::ServiceExt::oneshot(self.router.clone(), http_request) + .await + .map_err(|e| JsError::new(&format!("Router error: {e}")))?; + + convert_response(response).await + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Router +// ───────────────────────────────────────────────────────────────────────────── + +fn build_wasm_router( + state: crate::app::AppState, + default_user_id: Option, + default_org_id: Option, +) -> Router { + // Build permissive authz context for WASM (no RBAC in browser) + let engine = Arc::new( + AuthzEngine::new(config::RbacConfig { + enabled: false, + ..Default::default() + }) + .expect("Failed to create disabled RBAC engine"), + ); + let authz = AuthzContext::permissive(engine); + + let admin_auth = AdminAuth { + identity: Identity { + external_id: "anonymous".to_string(), + email: Some("anonymous@localhost".to_string()), + name: Some("Anonymous User".to_string()), + user_id: default_user_id, + roles: vec!["admin".to_string()], + idp_groups: 
Vec::new(), + org_ids: default_org_id + .map(|id| vec![id.to_string()]) + .unwrap_or_default(), + team_ids: Vec::new(), + project_ids: Vec::new(), + }, + }; + + // Shared route builders from the actual server code. + // Merge public admin routes (ui config) into the admin router so we can nest once. + let admin_routes = crate::routes::admin::admin_v1_routes() + .merge(crate::routes::admin::public_admin_v1_routes()); + let api_routes = crate::routes::api::api_v1_routes(); + + Router::new() + // WASM-specific handlers (genuinely different behavior) + .route("/health", get(health_check)) + .route("/auth/me", get(auth_me)) + // Shared routes from actual server code + .nest("/admin/v1", admin_routes) + .merge(api_routes) + // Inject extensions that middleware would normally provide + .layer(Extension(admin_auth)) + .layer(Extension(authz)) + .layer(Extension(ClientInfo::default())) + .fallback(fallback_handler) + .with_state(state) +} + +// ───────────────────────────────────────────────────────────────────────────── +// WASM-specific handlers (genuinely different behavior) +// ───────────────────────────────────────────────────────────────────────────── + +async fn health_check() -> Response { + Json(serde_json::json!({"status": "ok", "mode": "wasm"})).into_response() +} + +async fn auth_me(State(state): State) -> Response { + Json(serde_json::json!({ + "external_id": "anonymous", + "email": "anonymous@localhost", + "name": "Anonymous User", + "user_id": state.default_user_id, + "roles": ["admin"], + "idp_groups": [], + })) + .into_response() +} + +async fn fallback_handler() -> Response { + error_response(404, "Not found") +} + +// ───────────────────────────────────────────────────────────────────────────── +// Request / Response conversion +// ───────────────────────────────────────────────────────────────────────────── + +/// Convert `web_sys::Request` → `http::Request`. 
+async fn convert_request( + req: &web_sys::Request, +) -> Result, JsError> { + let method_str = req.method(); + let url = web_sys::Url::new(&req.url()).map_err(|_| JsError::new("Invalid request URL"))?; + + // The frontend uses /api/v1/ but backend routes are /v1/ + let raw_path = url.pathname(); + let path = raw_path + .strip_prefix("/api/v1/") + .map(|rest| format!("/v1/{rest}")) + .unwrap_or(raw_path); + + let search = url.search(); + let uri = if search.is_empty() { + path + } else { + format!("{path}{search}") + }; + + tracing::debug!(method = %method_str, uri = %uri, "WASM gateway handling request"); + + let method: http::Method = method_str + .parse() + .map_err(|_| JsError::new("Invalid HTTP method"))?; + + // Read body for methods that may have one. + // Use array_buffer() instead of text() to correctly handle binary content + // (multipart form-data, file uploads, audio). + let body = if method == http::Method::POST + || method == http::Method::PUT + || method == http::Method::PATCH + { + let buf = wasm_bindgen_futures::JsFuture::from( + req.array_buffer() + .map_err(|_| JsError::new("Failed to read request body"))?, + ) + .await + .map_err(|_| JsError::new("Failed to read request body"))?; + + let uint8 = js_sys::Uint8Array::new(&buf); + let bytes = uint8.to_vec(); + if bytes.is_empty() { + axum::body::Body::empty() + } else { + axum::body::Body::from(bytes) + } + } else { + axum::body::Body::empty() + }; + + let mut builder = http::Request::builder().method(method).uri(&uri); + + // Copy headers + let headers = req.headers(); + let entries = js_sys::try_iter(&headers).ok().flatten(); + if let Some(iter) = entries { + for entry in iter { + if let Ok(entry) = entry { + let pair = js_sys::Array::from(&entry); + if let (Some(key), Some(value)) = (pair.get(0).as_string(), pair.get(1).as_string()) + { + if let (Ok(name), Ok(val)) = ( + http::header::HeaderName::from_bytes(key.as_bytes()), + http::header::HeaderValue::from_str(&value), + ) { + builder = 
builder.header(name, val); + } + } + } + } + } + + builder + .body(body) + .map_err(|e| JsError::new(&format!("Failed to build request: {e}"))) +} + +/// Convert `axum::Response` → `web_sys::Response`. +async fn convert_response(response: Response) -> Result { + let (parts, body) = response.into_parts(); + + let bytes = http_body_util::BodyExt::collect(body) + .await + .map_err(|e| JsError::new(&format!("Failed to read response body: {e}")))? + .to_bytes(); + + let init = web_sys::ResponseInit::new(); + init.set_status(parts.status.as_u16()); + + let headers = web_sys::Headers::new().unwrap(); + for (key, value) in &parts.headers { + if let Ok(v) = value.to_str() { + let _ = headers.set(key.as_str(), v); + } + } + // Ensure content-type is set for JSON responses + if !parts.headers.contains_key(http::header::CONTENT_TYPE) && !bytes.is_empty() { + let _ = headers.set("content-type", "application/json"); + } + init.set_headers(&headers.into()); + + let body_js = if bytes.is_empty() { + None + } else { + let uint8 = js_sys::Uint8Array::from(bytes.as_ref()); + Some(uint8.into()) + }; + + web_sys::Response::new_with_opt_buffer_source_and_init(body_js.as_ref(), &init) + .map_err(|_| JsError::new("Failed to create response")) +} + +// ───────────────────────────────────────────────────────────────────────────── +// Helpers +// ───────────────────────────────────────────────────────────────────────────── + +fn error_response(status: u16, message: &str) -> Response { + let code = axum::http::StatusCode::from_u16(status) + .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR); + ( + code, + Json(serde_json::json!({ + "error": { + "message": message, + "type": "error", + "code": status, + } + })), + ) + .into_response() +} + +/// Create a minimal config suitable for WASM browser operation. 
+fn wasm_default_config() -> config::GatewayConfig { + config::GatewayConfig { + server: config::ServerConfig { + allow_loopback_urls: true, + allow_private_urls: true, + ..Default::default() + }, + database: config::DatabaseConfig::None, + cache: config::CacheConfig::None, + auth: config::AuthConfig { + mode: config::AuthMode::None, + ..Default::default() + }, + providers: config::ProvidersConfig::default(), + limits: config::LimitsConfig::default(), + features: config::FeaturesConfig::default(), + observability: config::ObservabilityConfig::default(), + ui: config::UiConfig::default(), + docs: config::DocsConfig::default(), + pricing: pricing::PricingConfig::default(), + secrets: config::SecretsConfig::None, + retention: config::RetentionConfig::default(), + storage: config::StorageConfig::default(), + } +} diff --git a/ui/package.json b/ui/package.json index 13b9277..cfb4aed 100644 --- a/ui/package.json +++ b/ui/package.json @@ -46,6 +46,7 @@ "recharts": "^3.7.0", "remark-gfm": "^4.0.1", "shiki": "^3.22.0", + "sql.js": "^1.12.0", "streamdown": "^2.3.0", "tailwind-merge": "^3.5.0", "use-debounce": "^10.1.0", @@ -116,7 +117,8 @@ "workbox-build>ajv": ">=8.18.0", "@modelcontextprotocol/sdk>ajv": ">=8.18.0", "@modelcontextprotocol/sdk>hono": ">=4.12.4", - "@modelcontextprotocol/sdk>@hono/node-server": ">=1.19.10" + "@modelcontextprotocol/sdk>@hono/node-server": ">=1.19.10", + "@modelcontextprotocol/sdk>express-rate-limit": ">=8.2.2" } } } \ No newline at end of file diff --git a/ui/pnpm-lock.yaml b/ui/pnpm-lock.yaml index 1074725..514c870 100644 --- a/ui/pnpm-lock.yaml +++ b/ui/pnpm-lock.yaml @@ -18,6 +18,7 @@ overrides: '@modelcontextprotocol/sdk>ajv': '>=8.18.0' '@modelcontextprotocol/sdk>hono': '>=4.12.4' '@modelcontextprotocol/sdk>@hono/node-server': '>=1.19.10' + '@modelcontextprotocol/sdk>express-rate-limit': '>=8.2.2' importers: @@ -98,6 +99,9 @@ importers: shiki: specifier: ^3.22.0 version: 3.22.0 + sql.js: + specifier: ^1.12.0 + version: 1.14.1 streamdown: 
specifier: ^2.3.0 version: 2.3.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4) @@ -2981,8 +2985,8 @@ packages: resolution: {integrity: sha512-knvyeauYhqjOYvQ66MznSMs83wmHrCycNEN6Ao+2AeYEfxUIkuiVxdEa1qlGEPK+We3n0THiDciYSsCcgW/DoA==} engines: {node: '>=12.0.0'} - express-rate-limit@8.2.1: - resolution: {integrity: sha512-PCZEIEIxqwhzw4KF0n7QF4QqruVTcF73O5kFKUnGOyjbCCgizBBiFaYpd/fnBLUMPw/BWw9OsiN7GgrNYr7j6g==} + express-rate-limit@8.3.1: + resolution: {integrity: sha512-D1dKN+cmyPWuvB+G2SREQDzPY1agpBIcTa9sJxOPMCNeH3gwzhqJRDWCXW3gg0y//+LQ/8j52JbMROWyrKdMdw==} engines: {node: '>= 16'} peerDependencies: express: '>= 4.11' @@ -3307,8 +3311,8 @@ packages: resolution: {integrity: sha512-5Hh7Y1wQbvY5ooGgPbDaL5iYLAPzMTUrjMulskHLH6wnv/A+1q5rgEaiuqEjB+oxGXIVZs1FF+R/KPN3ZSQYYg==} engines: {node: '>=12'} - ip-address@10.0.1: - resolution: {integrity: sha512-NWv9YLW4PoW2B7xtzaS3NCot75m6nK7Icdv0o3lfMceJVRfSoQwqD4wEH5rLwoKJwUiZ/rfpiVBhnaF0FK4HoA==} + ip-address@10.1.0: + resolution: {integrity: sha512-XXADHxXmvT9+CRxhXg56LJovE+bmWnEWB78LB83VZTprKTmaC5QfruXocxzTZ2Kl0DNwKuBdlIhjL8LeY8Sf8Q==} engines: {node: '>= 12'} ipaddr.js@1.9.1: @@ -4517,6 +4521,9 @@ packages: space-separated-tokens@2.0.2: resolution: {integrity: sha512-PEGlAwrG8yXGXRjW32fGbg66JAlOAwbObuqVoJpv/mRgoWDQfgH1wDPvtzWyUSNAXBGSk8h755YDbbcEy3SH2Q==} + sql.js@1.14.1: + resolution: {integrity: sha512-gcj8zBWU5cFsi9WUP+4bFNXAyF1iRpA3LLyS/DP5xlrNzGmPIizUeBggKa8DbDwdqaKwUcTEnChtd2grWo/x/A==} + stackback@0.0.2: resolution: {integrity: sha512-1XMJE5fQo1jGH6Y/7ebnwPOBEkIEnT4QF32d5R1+VXdXveM0IBMJt8zfaxX1P3QhVwrYe+576+jkANtSS2mBbw==} @@ -6335,7 +6342,7 @@ snapshots: eventsource: 3.0.7 eventsource-parser: 3.0.6 express: 5.2.1 - express-rate-limit: 8.2.1(express@5.2.1) + express-rate-limit: 8.3.1(express@5.2.1) hono: 4.12.5 jose: 6.1.3 json-schema-typed: 8.0.2 @@ -8224,10 +8231,10 @@ snapshots: expect-type@1.3.0: {} - express-rate-limit@8.2.1(express@5.2.1): + express-rate-limit@8.3.1(express@5.2.1): dependencies: express: 5.2.1 - 
ip-address: 10.0.1 + ip-address: 10.1.0 express@5.2.1: dependencies: @@ -8619,7 +8626,7 @@ snapshots: internmap@2.0.3: {} - ip-address@10.0.1: {} + ip-address@10.1.0: {} ipaddr.js@1.9.1: {} @@ -10082,6 +10089,8 @@ snapshots: space-separated-tokens@2.0.2: {} + sql.js@1.14.1: {} + stackback@0.0.2: {} statuses@2.0.2: {} diff --git a/ui/src/App.tsx b/ui/src/App.tsx index 9ffa450..0eb36c7 100644 --- a/ui/src/App.tsx +++ b/ui/src/App.tsx @@ -1,60 +1,16 @@ -import { BrowserRouter, Routes, Route, Navigate } from "react-router-dom"; +import { BrowserRouter } from "react-router-dom"; import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; import { ConfigProvider } from "@/config/ConfigProvider"; import { PreferencesProvider } from "@/preferences/PreferencesProvider"; -import { AuthProvider, RequireAuth, RequireAdmin } from "@/auth"; +import { AuthProvider } from "@/auth"; import { ApiClientProvider } from "@/api/ApiClientProvider"; import { ToastProvider } from "@/components/Toast/Toast"; import { ConfirmDialogProvider } from "@/components/ConfirmDialog/ConfirmDialog"; import { CommandPaletteProvider } from "@/components/CommandPalette/CommandPalette"; import { ConversationsProvider } from "@/components/ConversationsProvider/ConversationsProvider"; import { ErrorBoundary } from "@/components/ErrorBoundary/ErrorBoundary"; -import { AppLayout } from "@/components/AppLayout/AppLayout"; -import { AdminLayout } from "@/components/AdminLayout/AdminLayout"; - -// Pages - lazy loaded for code splitting -import { lazy, Suspense } from "react"; -import { Spinner } from "@/components/Spinner/Spinner"; - -const LoginPage = lazy(() => import("@/pages/LoginPage")); -const AccountPage = lazy(() => import("@/pages/AccountPage")); -const ProjectsPage = lazy(() => import("@/pages/ProjectsPage")); -const TeamsPage = lazy(() => import("@/pages/TeamsPage")); -const KnowledgeBasesPage = lazy(() => import("@/pages/KnowledgeBasesPage")); -const ApiKeysPage = lazy(() => 
import("@/pages/ApiKeysPage")); -const ApiKeyDetailPage = lazy(() => import("@/pages/ApiKeyDetailPage")); - -const MyUsagePage = lazy(() => import("@/pages/MyUsagePage")); -const MyProvidersPage = lazy(() => import("@/pages/MyProvidersPage")); -const SelfServiceProjectDetailPage = lazy(() => import("@/pages/project/ProjectDetailPage")); -const StudioPage = lazy(() => import("@/pages/studio/StudioPage")); -const ChatPage = lazy(() => import("@/pages/chat/ChatPage")); -const AdminDashboardPage = lazy(() => import("@/pages/admin/DashboardPage")); -const OrganizationsPage = lazy(() => import("@/pages/admin/OrganizationsPage")); -const OrganizationDetailPage = lazy(() => import("@/pages/admin/OrganizationDetailPage")); -const ProjectDetailPage = lazy(() => import("@/pages/admin/ProjectDetailPage")); -const UsersPage = lazy(() => import("@/pages/admin/UsersPage")); -const UserDetailPage = lazy(() => import("@/pages/admin/UserDetailPage")); -const AdminApiKeysPage = lazy(() => import("@/pages/admin/ApiKeysPage")); -const ProvidersPage = lazy(() => import("@/pages/admin/ProvidersPage")); -const ProviderHealthPage = lazy(() => import("@/pages/admin/ProviderHealthPage")); -const ProviderDetailPage = lazy(() => import("@/pages/admin/ProviderDetailPage")); -const PricingPage = lazy(() => import("@/pages/admin/PricingPage")); -const UsagePage = lazy(() => import("@/pages/admin/UsagePage")); -const AdminProjectsPage = lazy(() => import("@/pages/admin/ProjectsPage")); -const AdminTeamsPage = lazy(() => import("@/pages/admin/TeamsPage")); -const ServiceAccountsPage = lazy(() => import("@/pages/admin/ServiceAccountsPage")); -const TeamDetailPage = lazy(() => import("@/pages/admin/TeamDetailPage")); -const SettingsPage = lazy(() => import("@/pages/admin/SettingsPage")); -const AuditLogsPage = lazy(() => import("@/pages/admin/AuditLogsPage")); -const VectorStoresPage = lazy(() => import("@/pages/admin/VectorStoresPage")); -const VectorStoreDetailPage = lazy(() => 
import("@/pages/admin/VectorStoreDetailPage")); -const SsoConnectionsPage = lazy(() => import("@/pages/admin/SsoConnectionsPage")); -const SsoGroupMappingsPage = lazy(() => import("@/pages/admin/SsoGroupMappingsPage")); -const OrgSsoConfigPage = lazy(() => import("@/pages/admin/OrgSsoConfigPage")); -const ScimConfigPage = lazy(() => import("@/pages/admin/ScimConfigPage")); -const OrgRbacPoliciesPage = lazy(() => import("@/pages/admin/OrgRbacPoliciesPage")); -const SessionInfoPage = lazy(() => import("@/pages/admin/SessionInfoPage")); +import { WasmSetupGuard } from "@/components/WasmSetup/WasmSetupGuard"; +import { AppRoutes } from "@/routes/AppRoutes"; const queryClient = new QueryClient({ defaultOptions: { @@ -65,14 +21,6 @@ const queryClient = new QueryClient({ }, }); -function PageLoader() { - return ( -
- -
- ); -} - export default function App() { return ( @@ -84,382 +32,13 @@ export default function App() { - - - - {/* Root redirect */} - } /> - - {/* Login route */} - }> - - - } - /> - - {/* Auth callback route for OIDC */} - }> - - - } - /> - - {/* Protected routes with main AppLayout (chat sidebar) */} - - - - } - > - {/* Chat routes */} - }> - - - } - /> - }> - - - } - /> - - {/* Projects route */} - }> - - - } - /> - - {/* Project detail route */} - }> - - - } - /> - - {/* Teams route */} - }> - - - } - /> - - {/* Knowledge Bases route */} - }> - - - } - /> - - {/* API Keys routes */} - }> - - - } - /> - }> - - - } - /> - - {/* Providers route (self-service) */} - }> - - - } - /> - - {/* Usage route (self-service) */} - }> - - - } - /> - - {/* Studio route */} - }> - - - } - /> - - {/* Account settings route */} - }> - - - } - /> - - {/* Session info route (debugging) */} - }> - - - } - /> - - - {/* Admin routes with AdminLayout (admin sidebar) */} - - - - } - > - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - - - {/* Catch all - redirect to chat */} - } /> - - - + + + + + + + diff --git a/ui/src/components/Header/Header.tsx b/ui/src/components/Header/Header.tsx index 275f283..84a7c35 100644 --- a/ui/src/components/Header/Header.tsx +++ b/ui/src/components/Header/Header.tsx @@ -9,12 +9,14 @@ import { Palette, Server, Shield, + Sparkles, UsersRound, } from "lucide-react"; import { Button } from "@/components/Button/Button"; import { HadrianIcon } from "@/components/HadrianIcon/HadrianIcon"; import { ThemeToggle } from "@/components/ThemeToggle/ThemeToggle"; import { 
UserMenu } from "@/components/UserMenu/UserMenu"; +import { useWasmSetup } from "@/components/WasmSetup/WasmSetupGuard"; import { useConfig } from "@/config/ConfigProvider"; import { usePreferences } from "@/preferences/PreferencesProvider"; import { useAuth, hasAdminAccess } from "@/auth"; @@ -55,6 +57,7 @@ export function Header({ onMenuClick, showMenuButton = false, className }: Heade const { config } = useConfig(); const { resolvedTheme } = usePreferences(); const { user } = useAuth(); + const { isWasm, openSetupWizard } = useWasmSetup(); const location = useLocation(); // Determine which logo to use based on theme @@ -136,6 +139,11 @@ export function Header({ onMenuClick, showMenuButton = false, className }: Heade {/* Right: Theme toggle and user menu */}
+ {isWasm && ( + + )}
diff --git a/ui/src/components/ModelSelector/ModelSelector.tsx b/ui/src/components/ModelSelector/ModelSelector.tsx index 660cb1d..9147aba 100644 --- a/ui/src/components/ModelSelector/ModelSelector.tsx +++ b/ui/src/components/ModelSelector/ModelSelector.tsx @@ -361,9 +361,14 @@ export function ModelSelector({
) : availableModels.length > 0 ? ( + setOpen(false)} /> + + ); +} + +export const Default: Story = { + render: () => , +}; + +function OAuthSuccessStory() { + const [open, setOpen] = useState(true); + return ( + <> + + setOpen(false)} oauthProviderName="openrouter" /> + + ); +} + +export const OAuthSuccess: Story = { + render: () => , +}; + +function OllamaDetectedStory() { + const [open, setOpen] = useState(true); + return ( + <> + + setOpen(false)} + ollamaDetected + onOllamaConnect={() => alert("Connecting Ollama...")} + /> + + ); +} + +export const OllamaDetected: Story = { + render: () => , +}; diff --git a/ui/src/components/WasmSetup/WasmSetup.tsx b/ui/src/components/WasmSetup/WasmSetup.tsx new file mode 100644 index 0000000..3a627ab --- /dev/null +++ b/ui/src/components/WasmSetup/WasmSetup.tsx @@ -0,0 +1,783 @@ +import { useState, useCallback, useEffect, useId } from "react"; +import { + ArrowRight, + ArrowLeft, + CheckCircle2, + XCircle, + Loader2, + Sparkles, + Plus, + Trash2, + ExternalLink, + Server, +} from "lucide-react"; +import { useMutation, useQueryClient } from "@tanstack/react-query"; +import { + meProvidersCreateMutation, + meProvidersListQueryKey, + apiV1ModelsQueryKey, +} from "@/api/generated/@tanstack/react-query.gen"; +import { meProvidersTestCredentials } from "@/api/generated/sdk.gen"; +import type { ConnectivityTestResponse, DynamicProviderResponse } from "@/api/generated/types.gen"; +import { + Modal, + ModalHeader, + ModalTitle, + ModalDescription, + ModalContent, + ModalFooter, +} from "@/components/Modal/Modal"; +import { Button } from "@/components/Button/Button"; +import { Input } from "@/components/Input/Input"; +import { FormField } from "@/components/FormField/FormField"; +import { startOpenRouterOAuth } from "./openrouter-oauth"; + +interface ProviderTemplate { + id: string; + label: string; + providerType: string; + defaultBaseUrl: string; + placeholder: string; + description: string; + docsUrl: string; +} + +const 
PROVIDER_TEMPLATES: ProviderTemplate[] = [ + { + id: "openai", + label: "OpenAI Compatible", + providerType: "open_ai", + defaultBaseUrl: "https://api.openai.com/v1", + placeholder: "sk-...", + description: "OpenAI, OpenRouter, Ollama, and other compatible APIs", + docsUrl: "https://platform.openai.com/api-keys", + }, + { + id: "anthropic", + label: "Anthropic", + providerType: "anthropic", + defaultBaseUrl: "https://api.anthropic.com", + placeholder: "sk-ant-...", + description: "Claude Opus, Sonnet, and Haiku", + docsUrl: "https://console.anthropic.com/settings/keys", + }, +]; + +type Step = "welcome" | "providers" | "done"; + +interface ProviderEntry { + /** Unique key for React rendering and state tracking. */ + key: string; + template: ProviderTemplate; + apiKey: string; + baseUrl: string; + name: string; + testResult: ConnectivityTestResponse | null; + isTesting: boolean; + isSaving: boolean; + saved: boolean; + error: string | null; +} + +let entryCounter = 0; + +function createEntry(template: ProviderTemplate, index: number): ProviderEntry { + entryCounter++; + return { + key: `${template.id}-${entryCounter}`, + template, + apiKey: "", + baseUrl: template.defaultBaseUrl, + name: index === 0 ? 
template.id : `${template.id}-${index + 1}`, + testResult: null, + isTesting: false, + isSaving: false, + saved: false, + error: null, + }; +} + +function initialEntries(): ProviderEntry[] { + return PROVIDER_TEMPLATES.map((t) => createEntry(t, 0)); +} + +export function WasmSetup({ + open, + onComplete, + oauthProviderName, + oauthError, + existingProviders, + ollamaDetected, + ollamaConnecting, + ollamaConnected, + onOllamaConnect, +}: { + open: boolean; + onComplete: () => void; + oauthProviderName?: string | null; + oauthError?: string | null; + existingProviders?: DynamicProviderResponse[]; + ollamaDetected?: boolean; + ollamaConnecting?: boolean; + ollamaConnected?: boolean; + onOllamaConnect?: () => void; +}) { + const [step, setStep] = useState("welcome"); + const [entries, setEntries] = useState(initialEntries); + + // Reset to welcome when the wizard is re-opened + useEffect(() => { + if (open) setStep("welcome"); + }, [open]); + + // Jump to "done" when OAuth completes (oauthProviderName arrives async) + useEffect(() => { + if (oauthProviderName) setStep("done"); + }, [oauthProviderName]); + const [oauthLoading, setOauthLoading] = useState(false); + const queryClient = useQueryClient(); + + // Detect existing OpenRouter provider from the database + const hasExistingOpenRouter = + !!oauthProviderName || + existingProviders?.some((p) => p.base_url.includes("openrouter.ai")) === true; + + const hasExistingOllama = + !!ollamaConnected || + existingProviders?.some((p) => p.base_url.includes("localhost:11434")) === true; + + const createMutation = useMutation({ + ...meProvidersCreateMutation(), + }); + + const updateEntry = useCallback((key: string, update: Partial) => { + setEntries((prev) => prev.map((e) => (e.key === key ? 
{ ...e, ...update } : e))); + }, []); + + const addEntry = useCallback((template: ProviderTemplate) => { + setEntries((prev) => { + const count = prev.filter((e) => e.template.id === template.id).length; + return [...prev, createEntry(template, count)]; + }); + }, []); + + const removeEntry = useCallback((key: string) => { + setEntries((prev) => prev.filter((e) => e.key !== key)); + }, []); + + const handleTest = useCallback( + async (key: string) => { + const entry = entries.find((e) => e.key === key); + if (!entry) return; + + updateEntry(key, { isTesting: true, testResult: null, error: null }); + + try { + const { data } = await meProvidersTestCredentials({ + body: { + name: entry.name, + provider_type: entry.template.providerType, + base_url: entry.baseUrl, + api_key: entry.apiKey, + }, + }); + updateEntry(key, { isTesting: false, testResult: data ?? null }); + } catch (err) { + updateEntry(key, { + isTesting: false, + testResult: { status: "error", message: String(err) }, + }); + } + }, + [entries, updateEntry] + ); + + const handleSave = useCallback( + async (key: string) => { + const entry = entries.find((e) => e.key === key); + if (!entry) return; + + updateEntry(key, { isSaving: true, error: null }); + + try { + await createMutation.mutateAsync({ + body: { + name: entry.name, + provider_type: entry.template.providerType, + base_url: entry.baseUrl, + api_key: entry.apiKey, + }, + }); + queryClient.invalidateQueries({ queryKey: meProvidersListQueryKey() }); + queryClient.invalidateQueries({ queryKey: apiV1ModelsQueryKey() }); + updateEntry(key, { isSaving: false, saved: true }); + } catch (err) { + updateEntry(key, { isSaving: false, error: String(err) }); + } + }, + [entries, createMutation, queryClient, updateEntry] + ); + + const handleOpenRouterOAuth = useCallback(async () => { + setOauthLoading(true); + try { + await startOpenRouterOAuth(); + } catch { + setOauthLoading(false); + } + }, []); + + const savedCount = + entries.filter((e) => e.saved).length 
+ + (hasExistingOpenRouter ? 1 : 0) + + (hasExistingOllama ? 1 : 0); + const hasAnySaved = savedCount > 0; + + return ( + + {step === "welcome" && ( + setStep("providers")} + onReady={() => setStep("done")} + onSkip={onComplete} + onOpenRouterOAuth={handleOpenRouterOAuth} + oauthLoading={oauthLoading} + hasExistingOpenRouter={hasExistingOpenRouter} + ollamaDetected={ollamaDetected} + ollamaConnecting={ollamaConnecting} + hasExistingOllama={hasExistingOllama} + onOllamaConnect={onOllamaConnect} + /> + )} + {step === "providers" && ( + setStep("welcome")} + onNext={() => setStep("done")} + onSkip={onComplete} + hasAnySaved={hasAnySaved} + hasExistingOpenRouter={hasExistingOpenRouter} + oauthError={oauthError ?? null} + onOpenRouterOAuth={handleOpenRouterOAuth} + oauthLoading={oauthLoading} + ollamaDetected={ollamaDetected} + ollamaConnecting={ollamaConnecting} + hasExistingOllama={hasExistingOllama} + onOllamaConnect={onOllamaConnect} + /> + )} + {step === "done" && } + + ); +} + +function WelcomeStep({ + onNext, + onReady, + onSkip, + onOpenRouterOAuth, + oauthLoading, + hasExistingOpenRouter, + ollamaDetected, + ollamaConnecting, + hasExistingOllama, + onOllamaConnect, +}: { + onNext: () => void; + onReady: () => void; + onSkip: () => void; + onOpenRouterOAuth: () => void; + oauthLoading: boolean; + hasExistingOpenRouter: boolean; + ollamaDetected?: boolean; + ollamaConnecting?: boolean; + hasExistingOllama: boolean; + onOllamaConnect?: () => void; +}) { + const hasProvider = hasExistingOpenRouter || hasExistingOllama; + return ( + <> + + Welcome to Hadrian + Open-source AI gateway + + +

+ Hadrian is a free, open-source AI gateway that lets you chat with multiple models side by + side. This is the browser edition: the gateway runs entirely in your browser. +

+

+ For the server version with teams, SSO, guardrails, and more providers, see{" "} + + hadriangateway.com + + . +

+ + {hasExistingOpenRouter ? ( +
+
+
+

OpenRouter

+

https://openrouter.ai/api/v1

+
+
+ + Connected +
+
+
+ ) : ( +
+

OpenRouter

+

+ Sign in to access 200+ models. No manual API key entry required. +

+ +
+ )} + + {hasExistingOllama ? ( +
+
+
+

Ollama

+

http://localhost:11434/v1

+
+
+ + Connected +
+
+
+ ) : ollamaDetected ? ( +
+

Ollama

+

+ Found Ollama running at localhost:11434. +

+ +
+ ) : ( +
+
+
+

Ollama

+

Not detected at localhost:11434

+
+
+ + Not found +
+
+

+ Install and start{" "} + + Ollama + + {" "} + to use local models for free. Re-open this wizard after starting it. +

+
+ )} + +

+ {hasProvider + ? "You can also add API keys from OpenAI, Anthropic, or other providers." + : "Or add your own API keys from OpenAI, Anthropic, or other providers."} +

+
+ + {hasProvider ? ( + <> + + + + ) : ( + <> + + + + )} + + + ); +} + +function ProvidersStep({ + entries, + onUpdate, + onAdd, + onRemove, + onTest, + onSave, + onBack, + onNext, + onSkip, + hasAnySaved, + hasExistingOpenRouter, + oauthError, + onOpenRouterOAuth, + oauthLoading, + ollamaDetected, + ollamaConnecting, + hasExistingOllama, + onOllamaConnect, +}: { + entries: ProviderEntry[]; + onUpdate: (key: string, update: Partial) => void; + onAdd: (template: ProviderTemplate) => void; + onRemove: (key: string) => void; + onTest: (key: string) => void; + onSave: (key: string) => void; + onBack: () => void; + onNext: () => void; + onSkip: () => void; + hasAnySaved: boolean; + hasExistingOpenRouter: boolean; + oauthError: string | null; + onOpenRouterOAuth: () => void; + oauthLoading: boolean; + ollamaDetected?: boolean; + ollamaConnecting?: boolean; + hasExistingOllama: boolean; + onOllamaConnect?: () => void; +}) { + return ( + <> + + Connect your providers + Add at least one API key to start chatting + + + {/* OpenRouter OAuth section */} + {hasExistingOpenRouter ? ( +
+
+
+

OpenRouter

+

https://openrouter.ai/api/v1

+
+
+ + Connected +
+
+
+ ) : ( +
+
+
+

OpenRouter

+

200+ models, one click

+
+ +
+
+ )} + + {oauthError && ( +
+ + OpenRouter connection failed: {oauthError} +
+ )} + + {/* Ollama section */} + {hasExistingOllama ? ( +
+
+
+

Ollama

+

http://localhost:11434/v1

+
+
+ + Connected +
+
+
+ ) : ollamaDetected ? ( +
+
+
+

Ollama

+

Local models detected

+
+ +
+
+ ) : null} + +
+ {entries.map((entry) => ( + e.template.id === entry.template.id).length > 1} + onUpdate={(update) => onUpdate(entry.key, update)} + onRemove={() => onRemove(entry.key)} + onTest={() => onTest(entry.key)} + onSave={() => onSave(entry.key)} + /> + ))} +
+
+ {PROVIDER_TEMPLATES.map((t) => ( + + ))} +
+
+ + +
+ {!hasAnySaved && ( + + )} + {hasAnySaved && ( + + )} +
+
+ + ); +} + +function ProviderKeyEntry({ + entry, + canRemove, + onUpdate, + onRemove, + onTest, + onSave, +}: { + entry: ProviderEntry; + canRemove: boolean; + onUpdate: (update: Partial) => void; + onRemove: () => void; + onTest: () => void; + onSave: () => void; +}) { + const id = useId(); + const { template } = entry; + + if (entry.saved) { + return ( +
+
+
+

{entry.name}

+

{entry.baseUrl}

+
+
+ + Connected +
+
+
+ ); + } + + return ( +
+
+
+

{template.label}

+
+ + Get API key + + {canRemove && ( + + )} +
+
+

{template.description}

+
+ +
+ + onUpdate({ apiKey: e.target.value })} + disabled={entry.isSaving} + /> + + + + onUpdate({ baseUrl: e.target.value })} + disabled={entry.isSaving} + /> + +
+ +
+ + +
+ + {entry.testResult && ( +
+ {entry.testResult.status === "ok" ? ( +
+ + {entry.testResult.message} + {entry.testResult.latency_ms != null && ( + ({entry.testResult.latency_ms}ms) + )} +
+ ) : ( +
+ + {entry.testResult.message} +
+ )} +
+ )} + + {entry.error &&

{entry.error}

} +
+ ); +} + +function DoneStep({ savedCount, onComplete }: { savedCount: number; onComplete: () => void }) { + return ( + <> + + Setup complete + + {savedCount} provider{savedCount !== 1 ? "s" : ""} connected + + + +
+
+ +
+

+ Manage providers from the Providers page in the sidebar, or re-run this + wizard from the user menu. +

+
+
+ + + + + ); +} diff --git a/ui/src/components/WasmSetup/WasmSetupGuard.stories.tsx b/ui/src/components/WasmSetup/WasmSetupGuard.stories.tsx new file mode 100644 index 0000000..7d69e83 --- /dev/null +++ b/ui/src/components/WasmSetup/WasmSetupGuard.stories.tsx @@ -0,0 +1,45 @@ +import type { Meta, StoryObj } from "@storybook/react"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import { WasmSetupGuard } from "./WasmSetupGuard"; +import { ApiClientProvider } from "@/api/ApiClientProvider"; +import { ConfigProvider } from "@/config/ConfigProvider"; +import { AuthProvider } from "@/auth"; + +const queryClient = new QueryClient({ + defaultOptions: { queries: { retry: false } }, +}); + +const meta: Meta = { + title: "Components/WasmSetupGuard", + component: WasmSetupGuard, + parameters: { + layout: "centered", + }, + decorators: [ + (Story) => ( + + + + + + + + + + ), + ], +}; + +export default meta; +type Story = StoryObj; + +export const Default: Story = { + render: () => ( + +
+

App Content

+

In non-WASM mode, children render directly.

+
+
+ ), +}; diff --git a/ui/src/components/WasmSetup/WasmSetupGuard.tsx b/ui/src/components/WasmSetup/WasmSetupGuard.tsx new file mode 100644 index 0000000..235a8b4 --- /dev/null +++ b/ui/src/components/WasmSetup/WasmSetupGuard.tsx @@ -0,0 +1,186 @@ +import { + createContext, + useContext, + useState, + useCallback, + useEffect, + useRef, + type ReactNode, +} from "react"; +import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query"; +import { + meProvidersListOptions, + meProvidersCreateMutation, + meProvidersListQueryKey, + apiV1ModelsQueryKey, +} from "@/api/generated/@tanstack/react-query.gen"; +import { WasmSetup } from "./WasmSetup"; +import { + getOpenRouterCallbackCode, + clearCallbackCode, + exchangeCodeForKey, +} from "./openrouter-oauth"; + +const IS_WASM = import.meta.env.VITE_WASM_MODE === "true"; +const DISMISSED_KEY = "hadrian-wasm-setup-dismissed"; + +interface WasmSetupContextValue { + /** True when running in WASM mode. */ + isWasm: boolean; + /** Open the setup wizard. */ + openSetupWizard: () => void; + /** Name of provider just connected via OAuth, if any. */ + oauthProviderName: string | null; + /** Clear the OAuth success state. */ + clearOAuthSuccess: () => void; +} + +const WasmSetupContext = createContext({ + isWasm: false, + openSetupWizard: () => {}, + oauthProviderName: null, + clearOAuthSuccess: () => {}, +}); + +/** Access WASM setup state. Only meaningful when `isWasm` is true. */ +export function useWasmSetup() { + return useContext(WasmSetupContext); +} + +/** + * In WASM mode, shows the onboarding wizard if no providers are configured + * and the user hasn't dismissed it before. In server mode, renders children directly. + * + * Also handles OpenRouter OAuth callbacks: if the URL contains a `code` param, + * exchanges it for an API key and saves it as a dynamic provider. 
+ */ +export function WasmSetupGuard({ children }: { children: ReactNode }) { + const [dismissed, setDismissed] = useState(() => localStorage.getItem(DISMISSED_KEY) === "true"); + const [manualOpen, setManualOpen] = useState(false); + const [oauthProviderName, setOAuthProviderName] = useState(null); + const [oauthError, setOAuthError] = useState(null); + const oauthHandled = useRef(false); + const [ollamaDetected, setOllamaDetected] = useState(false); + const [ollamaConnecting, setOllamaConnecting] = useState(false); + const [ollamaConnected, setOllamaConnected] = useState(false); + const queryClient = useQueryClient(); + + const createProvider = useMutation({ ...meProvidersCreateMutation() }); + + const { data, isLoading } = useQuery({ + ...meProvidersListOptions(), + enabled: IS_WASM, + }); + + // Detect local Ollama instance + useEffect(() => { + if (!IS_WASM) return; + const controller = new AbortController(); + fetch("http://localhost:11434/v1/models", { signal: controller.signal }) + .then((res) => { + if (res.ok) setOllamaDetected(true); + }) + .catch(() => {}); + return () => controller.abort(); + }, []); + + const handleOllamaConnect = useCallback(async () => { + setOllamaConnecting(true); + try { + await createProvider.mutateAsync({ + body: { + name: "ollama", + provider_type: "open_ai", + base_url: "http://localhost:11434/v1", + api_key: "ollama", + }, + }); + queryClient.invalidateQueries({ queryKey: meProvidersListQueryKey() }); + queryClient.invalidateQueries({ queryKey: apiV1ModelsQueryKey() }); + setOllamaConnected(true); + setManualOpen(true); + } catch (err) { + console.error("Ollama connect failed:", err); + } finally { + setOllamaConnecting(false); + } + }, [createProvider, queryClient]); + + // Handle OpenRouter OAuth callback + useEffect(() => { + if (!IS_WASM || oauthHandled.current) return; + const code = getOpenRouterCallbackCode(); + if (!code) return; + oauthHandled.current = true; + clearCallbackCode(); + + (async () => { + try { + const 
apiKey = await exchangeCodeForKey(code); + await createProvider.mutateAsync({ + body: { + name: "openrouter", + provider_type: "open_ai", + base_url: "https://openrouter.ai/api/v1", + api_key: apiKey, + }, + }); + queryClient.invalidateQueries({ queryKey: meProvidersListQueryKey() }); + queryClient.invalidateQueries({ queryKey: apiV1ModelsQueryKey() }); + setOAuthProviderName("openrouter"); + setManualOpen(true); + } catch (err) { + console.error("OpenRouter OAuth failed:", err); + setOAuthError(String(err)); + setManualOpen(true); + } + })(); + }, [createProvider, queryClient]); + + const openSetupWizard = useCallback(() => setManualOpen(true), []); + + const handleComplete = useCallback(() => { + localStorage.setItem(DISMISSED_KEY, "true"); + setDismissed(true); + setManualOpen(false); + setOAuthProviderName(null); + setOAuthError(null); + setOllamaConnected(false); + }, []); + + const clearOAuthSuccess = useCallback(() => { + setOAuthProviderName(null); + setOAuthError(null); + }, []); + + const contextValue: WasmSetupContextValue = { + isWasm: IS_WASM, + openSetupWizard, + oauthProviderName, + clearOAuthSuccess, + }; + + if (!IS_WASM) { + return {children}; + } + + // Auto-show: no providers and not previously dismissed + const needsOnboarding = !dismissed && !isLoading && (data?.data?.length ?? 0) === 0; + + return ( + + {children} + + + ); +} diff --git a/ui/src/components/WasmSetup/openrouter-oauth.ts b/ui/src/components/WasmSetup/openrouter-oauth.ts new file mode 100644 index 0000000..ab785ed --- /dev/null +++ b/ui/src/components/WasmSetup/openrouter-oauth.ts @@ -0,0 +1,91 @@ +/** + * OpenRouter OAuth PKCE flow. + * + * @see https://openrouter.ai/docs/guides/overview/auth/oauth + */ + +const VERIFIER_KEY = "hadrian-openrouter-verifier"; + +/** Generate a cryptographically random code verifier. 
*/ +function generateCodeVerifier(): string { + const bytes = crypto.getRandomValues(new Uint8Array(32)); + return base64url(bytes); +} + +/** SHA-256 hash the verifier and base64url-encode it. */ +async function generateCodeChallenge(verifier: string): Promise { + const encoded = new TextEncoder().encode(verifier); + const digest = await crypto.subtle.digest("SHA-256", encoded); + return base64url(new Uint8Array(digest)); +} + +function base64url(bytes: Uint8Array): string { + let binary = ""; + for (const b of bytes) binary += String.fromCharCode(b); + return btoa(binary).replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, ""); +} + +/** + * Start the OpenRouter OAuth PKCE flow. + * Stores the code verifier in sessionStorage and redirects to OpenRouter. + */ +export async function startOpenRouterOAuth() { + const verifier = generateCodeVerifier(); + const challenge = await generateCodeChallenge(verifier); + sessionStorage.setItem(VERIFIER_KEY, verifier); + + const callbackUrl = window.location.origin + window.location.pathname; + const params = new URLSearchParams({ + callback_url: callbackUrl, + code_challenge: challenge, + code_challenge_method: "S256", + }); + + window.location.href = `https://openrouter.ai/auth?${params}`; +} + +/** + * Check if we're returning from an OpenRouter OAuth callback. + * Returns the authorization code if present, otherwise null. + */ +export function getOpenRouterCallbackCode(): string | null { + const params = new URLSearchParams(window.location.search); + return params.get("code"); +} + +/** Remove the code param from the URL without a page reload. */ +export function clearCallbackCode() { + const url = new URL(window.location.href); + url.searchParams.delete("code"); + window.history.replaceState({}, "", url.toString()); +} + +/** + * Exchange the authorization code for an OpenRouter API key. + * Uses the stored code verifier from sessionStorage. 
+ */ +export async function exchangeCodeForKey(code: string): Promise { + const verifier = sessionStorage.getItem(VERIFIER_KEY); + if (!verifier) { + throw new Error("Missing code verifier — OAuth flow may have been interrupted"); + } + + const res = await fetch("https://openrouter.ai/api/v1/auth/keys", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + code, + code_verifier: verifier, + code_challenge_method: "S256", + }), + }); + + if (!res.ok) { + const body = await res.text(); + throw new Error(`OpenRouter key exchange failed (${res.status}): ${body}`); + } + + const { key } = await res.json(); + sessionStorage.removeItem(VERIFIER_KEY); + return key; +} diff --git a/ui/src/index.css b/ui/src/index.css index a379d9d..f707a41 100644 --- a/ui/src/index.css +++ b/ui/src/index.css @@ -379,6 +379,31 @@ animation: typing-dot 1.4s ease-in-out 0.4s infinite; } + /* Shimmer sweep for attention */ + @keyframes shimmer { + 0%, + 40% { + transform: translateX(-100%); + } + 70%, + 100% { + transform: translateX(100%); + } + } + + .animate-shimmer { + position: relative; + overflow: hidden; + } + + .animate-shimmer::after { + content: ""; + position: absolute; + inset: 0; + background: linear-gradient(110deg, transparent 30%, hsl(0 0% 100% / 0.3) 50%, transparent 70%); + animation: shimmer 3s ease-in-out infinite; + } + @keyframes animate-in { from { opacity: var(--tw-enter-opacity, 1); diff --git a/ui/src/main.tsx b/ui/src/main.tsx index 12fa35b..322c50d 100644 --- a/ui/src/main.tsx +++ b/ui/src/main.tsx @@ -3,8 +3,20 @@ import { createRoot } from "react-dom/client"; import "./index.css"; import App from "./App"; -createRoot(document.getElementById("root")!).render( - - - -); +async function bootstrap() { + // In WASM mode, register the service worker and wait for it to control the + // page before rendering. This prevents API calls from firing before the SW + // is active (race condition on hard refresh). 
+ if (import.meta.env.VITE_WASM_MODE === "true") { + const { registerWasmServiceWorker } = await import("./service-worker/register"); + await registerWasmServiceWorker(); + } + + createRoot(document.getElementById("root")!).render( + + + + ); +} + +bootstrap(); diff --git a/ui/src/routes/AppRoutes.stories.tsx b/ui/src/routes/AppRoutes.stories.tsx new file mode 100644 index 0000000..8217851 --- /dev/null +++ b/ui/src/routes/AppRoutes.stories.tsx @@ -0,0 +1,60 @@ +import type { Meta, StoryObj } from "@storybook/react"; +import { MemoryRouter } from "react-router-dom"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import { ConfigProvider } from "@/config/ConfigProvider"; +import { AuthProvider } from "@/auth"; +import { ApiClientProvider } from "@/api/ApiClientProvider"; +import { ToastProvider } from "@/components/Toast/Toast"; +import { ConfirmDialogProvider } from "@/components/ConfirmDialog/ConfirmDialog"; +import { CommandPaletteProvider } from "@/components/CommandPalette/CommandPalette"; +import { ConversationsProvider } from "@/components/ConversationsProvider/ConversationsProvider"; +import { AppRoutes } from "./AppRoutes"; + +const queryClient = new QueryClient({ + defaultOptions: { queries: { retry: false } }, +}); + +const meta: Meta = { + title: "Routes/AppRoutes", + component: AppRoutes, + decorators: [ + (Story) => ( + + + + + + + + + + + + + + + + + + + + ), + ], + parameters: { + layout: "fullscreen", + a11y: { + config: { + rules: [ + // Route tree renders lazily — landmark/heading checks are irrelevant in isolation + { id: "region", enabled: false }, + { id: "page-has-heading-one", enabled: false }, + ], + }, + }, + }, +}; + +export default meta; +type Story = StoryObj; + +export const LoginRoute: Story = {}; diff --git a/ui/src/routes/AppRoutes.tsx b/ui/src/routes/AppRoutes.tsx new file mode 100644 index 0000000..ed2d444 --- /dev/null +++ b/ui/src/routes/AppRoutes.tsx @@ -0,0 +1,430 @@ +import { Routes, Route, Navigate 
} from "react-router-dom"; +import { RequireAuth, RequireAdmin } from "@/auth"; +import { AppLayout } from "@/components/AppLayout/AppLayout"; +import { AdminLayout } from "@/components/AdminLayout/AdminLayout"; +import { lazy, Suspense } from "react"; +import { Spinner } from "@/components/Spinner/Spinner"; + +const LoginPage = lazy(() => import("@/pages/LoginPage")); +const AccountPage = lazy(() => import("@/pages/AccountPage")); +const ProjectsPage = lazy(() => import("@/pages/ProjectsPage")); +const TeamsPage = lazy(() => import("@/pages/TeamsPage")); +const KnowledgeBasesPage = lazy(() => import("@/pages/KnowledgeBasesPage")); +const ApiKeysPage = lazy(() => import("@/pages/ApiKeysPage")); +const ApiKeyDetailPage = lazy(() => import("@/pages/ApiKeyDetailPage")); +const MyUsagePage = lazy(() => import("@/pages/MyUsagePage")); +const MyProvidersPage = lazy(() => import("@/pages/MyProvidersPage")); +const SelfServiceProjectDetailPage = lazy(() => import("@/pages/project/ProjectDetailPage")); +const StudioPage = lazy(() => import("@/pages/studio/StudioPage")); +const ChatPage = lazy(() => import("@/pages/chat/ChatPage")); +const AdminDashboardPage = lazy(() => import("@/pages/admin/DashboardPage")); +const OrganizationsPage = lazy(() => import("@/pages/admin/OrganizationsPage")); +const OrganizationDetailPage = lazy(() => import("@/pages/admin/OrganizationDetailPage")); +const ProjectDetailPage = lazy(() => import("@/pages/admin/ProjectDetailPage")); +const UsersPage = lazy(() => import("@/pages/admin/UsersPage")); +const UserDetailPage = lazy(() => import("@/pages/admin/UserDetailPage")); +const AdminApiKeysPage = lazy(() => import("@/pages/admin/ApiKeysPage")); +const ProvidersPage = lazy(() => import("@/pages/admin/ProvidersPage")); +const ProviderHealthPage = lazy(() => import("@/pages/admin/ProviderHealthPage")); +const ProviderDetailPage = lazy(() => import("@/pages/admin/ProviderDetailPage")); +const PricingPage = lazy(() => 
import("@/pages/admin/PricingPage")); +const UsagePage = lazy(() => import("@/pages/admin/UsagePage")); +const AdminProjectsPage = lazy(() => import("@/pages/admin/ProjectsPage")); +const AdminTeamsPage = lazy(() => import("@/pages/admin/TeamsPage")); +const ServiceAccountsPage = lazy(() => import("@/pages/admin/ServiceAccountsPage")); +const TeamDetailPage = lazy(() => import("@/pages/admin/TeamDetailPage")); +const SettingsPage = lazy(() => import("@/pages/admin/SettingsPage")); +const AuditLogsPage = lazy(() => import("@/pages/admin/AuditLogsPage")); +const VectorStoresPage = lazy(() => import("@/pages/admin/VectorStoresPage")); +const VectorStoreDetailPage = lazy(() => import("@/pages/admin/VectorStoreDetailPage")); +const SsoConnectionsPage = lazy(() => import("@/pages/admin/SsoConnectionsPage")); +const SsoGroupMappingsPage = lazy(() => import("@/pages/admin/SsoGroupMappingsPage")); +const OrgSsoConfigPage = lazy(() => import("@/pages/admin/OrgSsoConfigPage")); +const ScimConfigPage = lazy(() => import("@/pages/admin/ScimConfigPage")); +const OrgRbacPoliciesPage = lazy(() => import("@/pages/admin/OrgRbacPoliciesPage")); +const SessionInfoPage = lazy(() => import("@/pages/admin/SessionInfoPage")); + +function PageLoader() { + return ( +
+ +
+ ); +} + +export function AppRoutes() { + return ( + + {/* Root redirect */} + } /> + + {/* Login route */} + }> + + + } + /> + + {/* Auth callback route for OIDC */} + }> + + + } + /> + + {/* Protected routes with main AppLayout (chat sidebar) */} + + + + } + > + {/* Chat routes */} + }> + + + } + /> + }> + + + } + /> + + {/* Projects route */} + }> + + + } + /> + + {/* Project detail route */} + }> + + + } + /> + + {/* Teams route */} + }> + + + } + /> + + {/* Knowledge Bases route */} + }> + + + } + /> + + {/* API Keys routes */} + }> + + + } + /> + }> + + + } + /> + + {/* Providers route (self-service) */} + }> + + + } + /> + + {/* Usage route (self-service) */} + }> + + + } + /> + + {/* Studio route */} + }> + + + } + /> + + {/* Account settings route */} + }> + + + } + /> + + {/* Session info route (debugging) */} + }> + + + } + /> + + + {/* Admin routes with AdminLayout (admin sidebar) */} + + + + } + > + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + + + {/* Catch all - redirect to chat */} + } /> + + ); +} diff --git a/ui/src/service-worker/register.ts b/ui/src/service-worker/register.ts new file mode 100644 index 0000000..f70d8fe --- /dev/null +++ b/ui/src/service-worker/register.ts @@ -0,0 +1,69 @@ +/** + * Service worker registration for WASM mode. + * + * Only registers the service worker when VITE_WASM_MODE is enabled. + * In server mode, the existing vite-plugin-pwa config handles SW lifecycle. 
+ */ + +export async function registerWasmServiceWorker(): Promise { + if (!("serviceWorker" in navigator)) { + console.warn("Service workers not supported in this browser"); + return; + } + + try { + const registration = await navigator.serviceWorker.register("/sw.js", { + type: "module", + scope: "/", + }); + + console.log("Hadrian WASM service worker registered:", registration.scope); + + // Wait for the SW to be active (handles installing, waiting, or already active) + const sw = registration.active || registration.waiting || registration.installing; + if (sw && sw.state !== "activated") { + await Promise.race([ + new Promise((resolve) => { + sw.addEventListener("statechange", function handler() { + if (sw.state === "activated") { + sw.removeEventListener("statechange", handler); + resolve(); + } + }); + }), + new Promise((_, reject) => + setTimeout(() => reject(new Error("Service worker activation timed out")), 10_000) + ), + ]); + } + + // Ensure this page is controlled by the SW — even after activation, + // the page may not be controlled until clients.claim() fires. + // On hard refresh the activate event doesn't re-fire, so we ask the + // SW to re-claim via postMessage and race against a timeout. + if (!navigator.serviceWorker.controller) { + const controllerReady = new Promise((resolve) => { + navigator.serviceWorker.addEventListener("controllerchange", () => resolve(), { + once: true, + }); + }); + + // Ask the already-active SW to call clients.claim() + registration.active?.postMessage({ type: "CLAIM" }); + + await Promise.race([controllerReady, new Promise((r) => setTimeout(r, 2000))]); + } + } catch (error) { + console.error("Failed to register WASM service worker:", error); + } +} + +/** + * Check if the WASM service worker is active and ready. 
+ */ +export async function isWasmReady(): Promise { + if (!("serviceWorker" in navigator)) return false; + + const registration = await navigator.serviceWorker.getRegistration("/"); + return registration?.active != null; +} diff --git a/ui/src/service-worker/sqlite-bridge.ts b/ui/src/service-worker/sqlite-bridge.ts new file mode 100644 index 0000000..a7f7f94 --- /dev/null +++ b/ui/src/service-worker/sqlite-bridge.ts @@ -0,0 +1,185 @@ +/** + * sql.js bridge for the Hadrian WASM gateway. + * + * Implements `globalThis.__hadrian_sqlite` with three methods that the Rust + * WASM FFI bridge (`src/db/wasm_sqlite/bridge.rs`) calls via `wasm_bindgen`: + * + * - `init_database()` — load sql.js WASM and create an in-memory database + * - `query(sql, params)` — run a SELECT, return `Array>` + * - `execute(sql, params)` — run INSERT/UPDATE/DELETE, return `{ changes, last_insert_rowid }` + * + * The database is persisted to IndexedDB so state survives hard refreshes. + * + * This must be imported *before* the Hadrian WASM module so the bridge exists + * when the Rust constructor calls `init_database()`. + */ + +// sql.js ships a factory function that loads its own WASM binary. +// We import the ESM build and point it at the WASM file we serve from /wasm/. 
+import initSqlJs, { type Database } from "sql.js"; + +let db: Database | null = null; + +// --------------------------------------------------------------------------- +// IndexedDB persistence +// --------------------------------------------------------------------------- + +const IDB_NAME = "hadrian-wasm"; +const IDB_STORE = "data"; +const IDB_KEY = "db"; + +async function loadFromIndexedDB(): Promise { + return new Promise((resolve) => { + const req = indexedDB.open(IDB_NAME, 1); + req.onupgradeneeded = () => req.result.createObjectStore(IDB_STORE); + req.onsuccess = () => { + const tx = req.result.transaction(IDB_STORE, "readonly"); + const get = tx.objectStore(IDB_STORE).get(IDB_KEY); + get.onsuccess = () => resolve(get.result ?? null); + get.onerror = () => resolve(null); + }; + req.onerror = () => resolve(null); + }); +} + +function saveToIndexedDB(data: Uint8Array): void { + const req = indexedDB.open(IDB_NAME, 1); + req.onupgradeneeded = () => req.result.createObjectStore(IDB_STORE); + req.onsuccess = () => { + const tx = req.result.transaction(IDB_STORE, "readwrite"); + const putReq = tx.objectStore(IDB_STORE).put(data, IDB_KEY); + putReq.onerror = () => console.error("[sqlite-bridge] IndexedDB put failed:", putReq.error); + tx.onerror = () => console.error("[sqlite-bridge] IndexedDB transaction failed:", tx.error); + }; + req.onerror = () => console.error("[sqlite-bridge] IndexedDB open failed:", req.error); +} + +let saveTimer: ReturnType | null = null; + +function flushSave(): void { + if (!db || !saveTimer) return; + clearTimeout(saveTimer); + saveTimer = null; + const data = db.export(); + saveToIndexedDB(new Uint8Array(data)); +} + +function debouncedSave(): void { + if (!db) return; + if (saveTimer) clearTimeout(saveTimer); + saveTimer = setTimeout(() => { + saveTimer = null; + const data = db!.export(); + saveToIndexedDB(new Uint8Array(data)); + }, 500); +} + +// Flush pending saves when the service worker is about to be evicted. 
+// Service workers don't fire `unload`/`beforeunload`; instead we listen for +// the `activate` event to ensure the DB is persisted on startup, and save +// synchronously on every write when a pending timer exists. +// `beforeunload` is kept as a best-effort fallback for window contexts. +if (typeof globalThis.addEventListener === "function") { + globalThis.addEventListener("beforeunload", flushSave); +} + +// --------------------------------------------------------------------------- +// Parameter binding +// --------------------------------------------------------------------------- + +/** + * Bind values into a prepared statement. + * The Rust bridge sends params via serde_wasm_bindgen with `#[serde(untagged)]`, + * so values arrive as native JS types: string, number, null, or boolean. + * Tagged variants ({ Text: "..." }, etc.) are only possible if serialization + * changes — we handle both forms defensively. + * sql.js accepts `(string | number | Uint8Array | null)[]`. + */ +function bindParams(params: unknown[]): (string | number | Uint8Array | null)[] { + return params.map((p) => { + if (p == null) return null; // == catches both null and undefined + if (typeof p === "string" || typeof p === "number") return p; + if (typeof p === "boolean") return p ? 1 : 0; // SQLite stores booleans as integers + if (typeof p === "object" && p !== null) { + const obj = p as Record; + if ("Text" in obj) return obj.Text as string; + if ("Integer" in obj) return obj.Integer as number; + if ("Real" in obj) return obj.Real as number; + if ("Blob" in obj) return new Uint8Array(obj.Blob as number[]); + if ("Null" in obj) return null; + } + return String(p); + }); +} + +const bridge = { + async init_database(): Promise { + // Fetch the WASM binary ourselves to avoid filename ambiguity. + // esbuild resolves `sql.js` to the browser build which expects + // "sql-wasm-browser.wasm", but we serve "sql-wasm.wasm". 
+ const wasmBinary = await fetch("/wasm/sql-wasm.wasm").then((r) => { + if (!r.ok) throw new Error(`Failed to fetch sql-wasm.wasm: ${r.status}`); + return r.arrayBuffer(); + }); + const SQL = await initSqlJs({ wasmBinary }); + + // Try to restore from IndexedDB, otherwise create fresh + const saved = await loadFromIndexedDB(); + db = saved ? new SQL.Database(saved) : new SQL.Database(); + + // Enable WAL-like pragmas for better performance + db.run("PRAGMA journal_mode = MEMORY"); + db.run("PRAGMA synchronous = OFF"); + db.run("PRAGMA foreign_keys = ON"); + + console.log( + `[sqlite-bridge] Database initialized${saved ? " (restored from IndexedDB)" : " (fresh)"}` + ); + }, + + async query(sql: string, params: unknown[]): Promise[]> { + if (!db) throw new Error("Database not initialized — call init_database() first"); + + const stmt = db.prepare(sql); + try { + stmt.bind(bindParams(params)); + + const rows: Record[] = []; + while (stmt.step()) { + const row = stmt.getAsObject(); + rows.push(row as Record); + } + return rows; + } finally { + stmt.free(); + } + }, + + async execute( + sql: string, + params: unknown[] + ): Promise<{ changes: number; last_insert_rowid: number }> { + if (!db) throw new Error("Database not initialized — call init_database() first"); + + db.run(sql, bindParams(params)); + + // sql.js doesn't return affected rows directly from run(), + // so we query the SQLite functions. + const changes = (db.exec("SELECT changes()")[0]?.values[0]?.[0] as number) ?? 0; + const lastId = (db.exec("SELECT last_insert_rowid()")[0]?.values[0]?.[0] as number) ?? 0; + + debouncedSave(); + return { changes, last_insert_rowid: lastId }; + }, + + async execute_script(sql: string): Promise { + if (!db) throw new Error("Database not initialized — call init_database() first"); + db.exec(sql); + debouncedSave(); + }, +}; + +// Expose on globalThis so the Rust wasm_bindgen FFI can find it. 
+(globalThis as unknown as Record).__hadrian_sqlite = bridge; + +console.log("[sqlite-bridge] Bridge registered on globalThis.__hadrian_sqlite"); diff --git a/ui/src/service-worker/sw.ts b/ui/src/service-worker/sw.ts new file mode 100644 index 0000000..04c4e74 --- /dev/null +++ b/ui/src/service-worker/sw.ts @@ -0,0 +1,109 @@ +/** + * Hadrian WASM Service Worker + * + * Intercepts API requests and routes them through the WASM-compiled + * Hadrian gateway running entirely in the browser. + * + * Intercepted paths: + * - /v1/* — OpenAI-compatible API endpoints + * - /admin/v1/* — Admin API endpoints + * - /health — Health check + * - /api/* — Other API endpoints + */ + +/// +declare const self: ServiceWorkerGlobalScope; + +// Initialize the sql.js bridge BEFORE loading the WASM module. +// This registers globalThis.__hadrian_sqlite so the Rust FFI can use it. +import "./sqlite-bridge"; + +// Static import — dynamic import() is disallowed in service workers. +// The WASM module is served from public/wasm/ at runtime. +import wasmInit, { HadrianGateway } from "/wasm/hadrian.js"; + +let gateway: HadrianGateway | null = null; +let initPromise: Promise | null = null; + +// Path prefixes handled by the WASM gateway +const GATEWAY_PATHS = ["/v1/", "/admin/v1/", "/health", "/auth/", "/api/"]; + +async function ensureGateway(): Promise { + await wasmInit("/wasm/hadrian_bg.wasm"); + gateway = await new HadrianGateway(); +} + +self.addEventListener("install", (event) => { + // Activate immediately, don't wait for existing clients to close + event.waitUntil(self.skipWaiting()); +}); + +self.addEventListener("activate", (event) => { + // Take control of all clients immediately + event.waitUntil(self.clients.claim()); +}); + +// Allow clients to request re-claim (e.g. after hard refresh where +// the activate event doesn't fire again). 
+self.addEventListener("message", (event) => {
+  if (event.data?.type === "CLAIM") {
+    self.clients.claim();
+  }
+});
+
+self.addEventListener("fetch", (event) => {
+  const url = new URL(event.request.url);
+
+  // Only intercept gateway API paths on the same origin
+  if (url.origin !== self.location.origin) return;
+  if (!GATEWAY_PATHS.some((p) => url.pathname.startsWith(p))) return;
+
+  event.respondWith(handleRequest(event.request));
+});
+
+async function handleRequest(request: Request): Promise<Response> {
+  // Lazy-init the WASM gateway on first intercepted request
+  if (!gateway) {
+    if (!initPromise) {
+      initPromise = ensureGateway();
+    }
+    try {
+      await initPromise;
+    } catch (error) {
+      initPromise = null; // Allow retry on next request
+      console.error("Failed to initialize Hadrian WASM gateway:", error);
+      return new Response(
+        JSON.stringify({
+          error: {
+            message: `Gateway initialization failed: ${String(error)}`,
+            type: "server_error",
+            code: 503,
+          },
+        }),
+        {
+          status: 503,
+          headers: { "Content-Type": "application/json" },
+        }
+      );
+    }
+  }
+
+  try {
+    return await gateway!.handle(request);
+  } catch (error) {
+    console.error("Hadrian WASM gateway error:", error);
+    return new Response(
+      JSON.stringify({
+        error: {
+          message: String(error),
+          type: "server_error",
+          code: 500,
+        },
+      }),
+      {
+        status: 500,
+        headers: { "Content-Type": "application/json" },
+      }
+    );
+  }
+}
diff --git a/ui/src/service-worker/wasm.d.ts b/ui/src/service-worker/wasm.d.ts
new file mode 100644
index 0000000..9cc19ca
--- /dev/null
+++ b/ui/src/service-worker/wasm.d.ts
@@ -0,0 +1,16 @@
+/**
+ * Type declarations for the Hadrian WASM module.
+ * Generated types live in the wasm-pack output, but the service worker
+ * needs declarations at build time.
+ */
+
+declare module "/wasm/hadrian.js" {
+  export default function init(wasmUrl?: string | URL): Promise<void>;
+
+  export class HadrianGateway {
+    /** The constructor is async (returns a Promise) via wasm-bindgen.
 */
+    constructor();
+    handle(request: Request): Promise<Response>;
+    free(): void;
+  }
+}
diff --git a/ui/tsconfig.app.json b/ui/tsconfig.app.json
index 9ea1ffd..1159ed5 100644
--- a/ui/tsconfig.app.json
+++ b/ui/tsconfig.app.json
@@ -31,5 +31,5 @@
     }
   },
   "include": ["src"],
-  "exclude": ["src/**/*.stories.tsx", "src/**/*.stories.ts"]
+  "exclude": ["src/**/*.stories.tsx", "src/**/*.stories.ts", "src/service-worker"]
 }
diff --git a/ui/vite.config.ts b/ui/vite.config.ts
index e5a367b..0b7a4c5 100644
--- a/ui/vite.config.ts
+++ b/ui/vite.config.ts
@@ -1,76 +1,152 @@
 /// <reference types="vitest" />
-import { defineConfig } from "vite";
+import { defineConfig, type Plugin } from "vite";
 import react from "@vitejs/plugin-react";
 import tailwindcss from "@tailwindcss/vite";
 import path from "path";
-import { fileURLToPath } from "node:url";
+import { fileURLToPath, pathToFileURL } from "node:url";
+import { createRequire } from "node:module";
 import { storybookTest } from "@storybook/addon-vitest/vitest-plugin";
 import { playwright } from "@vitest/browser-playwright";
-import {VitePWA} from "vite-plugin-pwa";
+import { VitePWA } from "vite-plugin-pwa";
 
 const dirname =
-  typeof __dirname !== "undefined" ? __dirname : path.dirname(fileURLToPath(import.meta.url));
+  typeof __dirname !== "undefined"
+    ? __dirname
+    : path.dirname(fileURLToPath(import.meta.url));
+
+const isWasmMode = process.env.VITE_WASM_MODE === "true";
+
+/**
+ * Builds and serves the WASM service worker.
+ *
+ * Dev: intercepts /sw.js requests and transforms the TS source on the fly.
+ * Build: compiles sw.ts with esbuild (separate from rollup) so the output
+ * is a standalone file without Vite's preload helpers.
+ */ +function wasmServiceWorkerPlugin(): Plugin { + const swPath = path.resolve(__dirname, "src/service-worker/sw.ts"); + + function getEsbuild() { + const req = createRequire(pathToFileURL(__filename).href); + return req("esbuild") as { + transform: Function; + build: Function; + }; + } + + return { + name: "hadrian-wasm-sw", + configureServer(server) { + // Must run before Vite's SPA fallback, which would serve index.html + server.middlewares.use(async (req, res, next) => { + if (req.url !== "/sw.js") return next(); + const { build } = getEsbuild(); + const os = await import("node:os"); + const outfile = path.join(os.tmpdir(), "hadrian-sw-dev.js"); + try { + await build({ + entryPoints: [swPath], + outfile, + bundle: true, + format: "esm", + target: "es2022", + write: true, + // /wasm/hadrian.js is a runtime import served by the browser + external: ["/wasm/hadrian.js"], + }); + const fs = await import("node:fs/promises"); + const code = await fs.readFile(outfile, "utf-8"); + res.setHeader("Content-Type", "application/javascript; charset=utf-8"); + res.setHeader("Cache-Control", "no-store"); + res.end(code); + } catch (err) { + console.error("Failed to compile service worker:", err); + next(err); + } + }); + }, + async writeBundle() { + const { build } = getEsbuild(); + await build({ + entryPoints: [swPath], + outfile: path.resolve(__dirname, "dist/sw.js"), + bundle: true, + format: "esm", + target: "es2022", + sourcemap: true, + // /wasm/hadrian.js is a runtime import served by the browser + external: ["/wasm/hadrian.js"], + }); + }, + }; +} // More info at: https://storybook.js.org/docs/next/writing-tests/integrations/vitest-addon export default defineConfig({ plugins: [ react(), tailwindcss(), - VitePWA({ - selfDestroying: true, // Unregisters existing service workers - includeAssets: ["favicon.ico", "icons/*.png"], - manifest: { - name: "Hadrian Gateway", - short_name: "Hadrian", - description: "AI Gateway - Chat & Admin Dashboard", - theme_color: "#3b82f6", 
- background_color: "#0f172a", - display: "standalone", - start_url: "/", - icons: [ - { - src: "/icons/icon-72.png", - sizes: "72x72", - type: "image/png", - }, - { - src: "/icons/icon-96.png", - sizes: "96x96", - type: "image/png", - }, - { - src: "/icons/icon-128.png", - sizes: "128x128", - type: "image/png", - }, - { - src: "/icons/icon-144.png", - sizes: "144x144", - type: "image/png", - }, - { - src: "/icons/icon-152.png", - sizes: "152x152", - type: "image/png", - }, - { - src: "/icons/icon-192.png", - sizes: "192x192", - type: "image/png", - }, - { - src: "/icons/icon-384.png", - sizes: "384x384", - type: "image/png", - }, - { - src: "/icons/icon-512.png", - sizes: "512x512", - type: "image/png", - purpose: "any maskable", - }, - ], - }, - }), + // In WASM mode, compile and serve the service worker; otherwise use vite-plugin-pwa + ...(isWasmMode ? [wasmServiceWorkerPlugin()] : []), + ...(!isWasmMode + ? [ + VitePWA({ + selfDestroying: true, + includeAssets: ["favicon.ico", "icons/*.png"], + manifest: { + name: "Hadrian Gateway", + short_name: "Hadrian", + description: "AI Gateway - Chat & Admin Dashboard", + theme_color: "#3b82f6", + background_color: "#0f172a", + display: "standalone", + start_url: "/", + icons: [ + { + src: "/icons/icon-72.png", + sizes: "72x72", + type: "image/png", + }, + { + src: "/icons/icon-96.png", + sizes: "96x96", + type: "image/png", + }, + { + src: "/icons/icon-128.png", + sizes: "128x128", + type: "image/png", + }, + { + src: "/icons/icon-144.png", + sizes: "144x144", + type: "image/png", + }, + { + src: "/icons/icon-152.png", + sizes: "152x152", + type: "image/png", + }, + { + src: "/icons/icon-192.png", + sizes: "192x192", + type: "image/png", + }, + { + src: "/icons/icon-384.png", + sizes: "384x384", + type: "image/png", + }, + { + src: "/icons/icon-512.png", + sizes: "512x512", + type: "image/png", + purpose: "any maskable", + }, + ], + }, + }), + ] + : []), ], resolve: { alias: { @@ -79,20 +155,23 @@ export default 
defineConfig({ }, server: { port: 5173, - proxy: { - "/api/": { - target: "http://localhost:8080", - changeOrigin: true, - }, - "/admin/v1": { - target: "http://localhost:8080", - changeOrigin: true, - }, - "/auth": { - target: "http://localhost:8080", - changeOrigin: true, - }, - }, + // In WASM mode, the service worker handles API routing — no proxy needed + proxy: isWasmMode + ? undefined + : { + "/api/": { + target: "http://localhost:8080", + changeOrigin: true, + }, + "/admin/v1": { + target: "http://localhost:8080", + changeOrigin: true, + }, + "/auth": { + target: "http://localhost:8080", + changeOrigin: true, + }, + }, }, worker: { format: "es", @@ -114,8 +193,6 @@ export default defineConfig({ { extends: true, plugins: [ - // The plugin will run tests for the stories defined in your Storybook config - // See options at: https://storybook.js.org/docs/next/writing-tests/integrations/vitest-addon#storybooktest storybookTest({ configDir: path.join(dirname, ".storybook"), }),