diff --git a/.cursor/skills/onyx-cli/SKILL.md b/.cursor/skills/onyx-cli/SKILL.md new file mode 100644 index 00000000000..da7132d3206 --- /dev/null +++ b/.cursor/skills/onyx-cli/SKILL.md @@ -0,0 +1,186 @@ +--- +name: onyx-cli +description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx. +--- + +# Onyx CLI — Agent Tool + +Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` tool provides non-interactive commands to query the Onyx knowledge base and list available agents. + +## Prerequisites + +### 1. Check if installed + +```bash +which onyx-cli +``` + +### 2. Install (if needed) + +**Primary — pip:** + +```bash +pip install onyx-cli +``` + +**From source (Go):** + +```bash +cd cli && go build -o onyx-cli . && sudo mv onyx-cli /usr/local/bin/ +``` + +### 3. Check if configured + +```bash +onyx-cli validate-config +``` + +This checks that the config file exists, the API key is present, and the server connection works via `/api/me`. Exit code 0 on success, non-zero with a descriptive error on failure. + +If unconfigured, you have two options: + +**Option A — Interactive setup (requires user input):** + +```bash +onyx-cli configure +``` + +This prompts for the Onyx server URL and API key, tests the connection, and saves the config. + +**Option B — Environment variables (non-interactive, preferred for agents):** + +```bash +export ONYX_SERVER_URL="https://your-onyx-server.com" # default: https://cloud.onyx.app +export ONYX_API_KEY="your-api-key" +``` + +Environment variables override the config file. If these are set, no config file is needed.
+ +| Variable | Required | Description | +|----------|----------|-------------| +| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) | +| `ONYX_API_KEY` | Yes | API key for authentication | +| `ONYX_PERSONA_ID` | No | Default agent/persona ID | + +If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either: +- Run `onyx-cli configure` interactively, or +- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables + +## Commands + +### Validate configuration + +```bash +onyx-cli validate-config +``` + +Checks config file exists, API key is present, and tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up. + +### List available agents + +```bash +onyx-cli agents +``` + +Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output: + +```bash +onyx-cli agents --json +``` + +Use agent IDs with `ask --agent-id` to query a specific agent. + +### Basic query (plain text output) + +```bash +onyx-cli ask "What is our company's PTO policy?" +``` + +Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error. + +### JSON output (structured events) + +```bash +onyx-cli ask --json "What authentication methods do we support?" +``` + +Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads. + +Each line is a JSON object with this envelope: + +```json +{"type": "", "event": { ... 
}} +``` + +| Event Type | Description | +|------------|-------------| +| `message_delta` | Content token — concatenate all `content` fields for the full answer | +| `stop` | Stream complete | +| `error` | Error with `error` message field | +| `search_tool_start` | Onyx started searching documents | +| `citation_info` | Source citation — see shape below | + +`citation_info` event shape: + +```json +{ + "type": "citation_info", + "event": { + "citation_number": 1, + "document_id": "abc123def456", + "placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null} + } +} +``` + +`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases. + +### Specify an agent + +```bash +onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap" +``` + +Uses a specific Onyx agent/persona instead of the default. + +### All flags + +| Flag | Type | Description | +|------|------|-------------| +| `--agent-id` | int | Agent ID to use (overrides default) | +| `--json` | bool | Output raw NDJSON events instead of plain text | + +## Statelessness + +Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead. 
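A `--json` event stream can be reassembled into a plain answer plus a citation list with `jq`. A minimal sketch, using fabricated sample events in place of a live `onyx-cli ask --json` stream (the event shapes follow the table above; the content and document IDs are made up for illustration):

```shell
# Sample NDJSON events standing in for: onyx-cli ask --json "What is our PTO policy?"
# In real use, replace the here-string with the live command's output.
events='{"type":"search_tool_start","event":{}}
{"type":"message_delta","event":{"content":"PTO is "}}
{"type":"message_delta","event":{"content":"20 days."}}
{"type":"citation_info","event":{"citation_number":1,"document_id":"abc123def456"}}
{"type":"stop","event":{}}'

# Concatenate message_delta content fields into the full answer (-j: raw, no newlines).
answer=$(printf '%s\n' "$events" | jq -j 'select(.type=="message_delta") | .event.content')

# Collect citations as "[n] document_id" lines.
citations=$(printf '%s\n' "$events" | jq -r 'select(.type=="citation_info") | "[\(.event.citation_number)] \(.event.document_id)"')

echo "$answer"
echo "$citations"
```

Error events can be handled the same way, e.g. `jq -r 'select(.type=="error") | .event.error'` to surface a failure message.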
+ +## When to Use + +Use `onyx-cli ask` when: + +- The user asks about company-specific information (policies, docs, processes) +- You need to search internal knowledge bases or connected data sources +- The user references Onyx, asks you to "search Onyx", or wants to query their documents +- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources + +Do NOT use when: + +- The question is about general programming knowledge (use your own knowledge) +- The user is asking about code in the current repository (use grep/read tools) +- The user hasn't mentioned Onyx and the question doesn't require internal company data + +## Examples + +```bash +# Simple question +onyx-cli ask "What are the steps to deploy to production?" + +# Get structured output for parsing +onyx-cli ask --json "List all active API integrations" + +# Use a specialized agent +onyx-cli ask --agent-id 3 "What were the action items from last week's standup?" + +# Pipe the answer into another command +onyx-cli ask "What is the database schema for users?" 
| head -20 +``` diff --git a/.github/actions/build-integration-image/action.yml b/.github/actions/build-integration-image/action.yml index 254f1c67c85..c0e5d8146b9 100644 --- a/.github/actions/build-integration-image/action.yml +++ b/.github/actions/build-integration-image/action.yml @@ -54,6 +54,7 @@ runs: shell: bash env: RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }} + INTEGRATION_REPOSITORY: ${{ inputs.runs-on-ecr-cache }} TAG: nightly-llm-it-${{ inputs.run-id }} CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }} HEAD_SHA: ${{ inputs.github-sha }} diff --git a/.github/workflows/deployment.yml b/.github/workflows/deployment.yml index 01c135d527f..548d16cafe4 100644 --- a/.github/workflows/deployment.yml +++ b/.github/workflows/deployment.yml @@ -29,20 +29,32 @@ jobs: build-backend-craft: ${{ steps.check.outputs.build-backend-craft }} build-model-server: ${{ steps.check.outputs.build-model-server }} is-cloud-tag: ${{ steps.check.outputs.is-cloud-tag }} - is-stable: ${{ steps.check.outputs.is-stable }} is-beta: ${{ steps.check.outputs.is-beta }} - is-stable-standalone: ${{ steps.check.outputs.is-stable-standalone }} is-beta-standalone: ${{ steps.check.outputs.is-beta-standalone }} - is-craft-latest: ${{ steps.check.outputs.is-craft-latest }} + is-latest: ${{ steps.check.outputs.is-latest }} is-test-run: ${{ steps.check.outputs.is-test-run }} sanitized-tag: ${{ steps.check.outputs.sanitized-tag }} short-sha: ${{ steps.check.outputs.short-sha }} steps: + - name: Checkout (for git tags) + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 + with: + persist-credentials: false + fetch-depth: 0 + fetch-tags: true + + - name: Setup uv + uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7 + with: + version: "0.9.9" + enable-cache: false + - name: Check which components to build and version info id: check env: EVENT_NAME: ${{ github.event_name }} run: | + set -eo 
pipefail TAG="${GITHUB_REF_NAME}" # Sanitize tag name by replacing slashes with hyphens (for Docker tag compatibility) SANITIZED_TAG=$(echo "$TAG" | tr '/' '-') @@ -54,9 +66,8 @@ jobs: IS_VERSION_TAG=false IS_STABLE=false IS_BETA=false - IS_STABLE_STANDALONE=false IS_BETA_STANDALONE=false - IS_CRAFT_LATEST=false + IS_LATEST=false IS_PROD_TAG=false IS_TEST_RUN=false BUILD_DESKTOP=false @@ -67,9 +78,6 @@ jobs: BUILD_MODEL_SERVER=true # Determine tag type based on pattern matching (do regex checks once) - if [[ "$TAG" == craft-* ]]; then - IS_CRAFT_LATEST=true - fi if [[ "$TAG" == *cloud* ]]; then IS_CLOUD=true fi @@ -97,20 +105,28 @@ jobs: fi fi - # Craft-latest builds backend with Craft enabled - if [[ "$IS_CRAFT_LATEST" == "true" ]]; then - BUILD_BACKEND_CRAFT=true - BUILD_BACKEND=false - fi - # Standalone version checks (for backend/model-server - version excluding cloud tags) - if [[ "$IS_STABLE" == "true" ]] && [[ "$IS_CLOUD" != "true" ]]; then - IS_STABLE_STANDALONE=true - fi if [[ "$IS_BETA" == "true" ]] && [[ "$IS_CLOUD" != "true" ]]; then IS_BETA_STANDALONE=true fi + # Determine if this tag should get the "latest" Docker tag. + # Only the highest semver stable tag (vX.Y.Z exactly) gets "latest". + if [[ "$IS_STABLE" == "true" ]]; then + HIGHEST_STABLE=$(uv run --no-sync --with onyx-devtools ods latest-stable-tag) || { + echo "::error::Failed to determine highest stable tag via 'ods latest-stable-tag'" + exit 1 + } + if [[ "$TAG" == "$HIGHEST_STABLE" ]]; then + IS_LATEST=true + fi + fi + + # Build craft-latest backend alongside the regular latest. 
+ if [[ "$IS_LATEST" == "true" ]]; then + BUILD_BACKEND_CRAFT=true + fi + # Determine if this is a production tag # Production tags are: version tags (v1.2.3*) or nightly tags if [[ "$IS_VERSION_TAG" == "true" ]] || [[ "$IS_NIGHTLY" == "true" ]]; then @@ -129,11 +145,9 @@ jobs: echo "build-backend-craft=$BUILD_BACKEND_CRAFT" echo "build-model-server=$BUILD_MODEL_SERVER" echo "is-cloud-tag=$IS_CLOUD" - echo "is-stable=$IS_STABLE" echo "is-beta=$IS_BETA" - echo "is-stable-standalone=$IS_STABLE_STANDALONE" echo "is-beta-standalone=$IS_BETA_STANDALONE" - echo "is-craft-latest=$IS_CRAFT_LATEST" + echo "is-latest=$IS_LATEST" echo "is-test-run=$IS_TEST_RUN" echo "sanitized-tag=$SANITIZED_TAG" echo "short-sha=$SHORT_SHA" @@ -151,7 +165,7 @@ jobs: fetch-depth: 0 - name: Setup uv - uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7 with: version: "0.9.9" # NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable. @@ -182,9 +196,53 @@ jobs: title: "🚨 Version Tag Check Failed" ref-name: ${{ github.ref_name }} - build-desktop: + # Create GitHub release first, before desktop builds start. + # This ensures all desktop matrix jobs upload to the same release instead of + # racing to create duplicate releases. 
+ create-release: needs: determine-builds if: needs.determine-builds.outputs.build-desktop == 'true' + runs-on: ubuntu-slim + timeout-minutes: 10 + permissions: + contents: write + outputs: + release-id: ${{ steps.create-release.outputs.id }} + steps: + - name: Checkout + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 + with: + persist-credentials: false + + - name: Determine release tag + id: release-tag + env: + IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }} + SHORT_SHA: ${{ needs.determine-builds.outputs.short-sha }} + run: | + if [ "${IS_TEST_RUN}" == "true" ]; then + echo "tag=v0.0.0-dev+${SHORT_SHA}" >> "$GITHUB_OUTPUT" + else + echo "tag=${GITHUB_REF_NAME}" >> "$GITHUB_OUTPUT" + fi + + - name: Create GitHub Release + id: create-release + uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631 # ratchet:softprops/action-gh-release@v2 + with: + tag_name: ${{ steps.release-tag.outputs.tag }} + name: ${{ steps.release-tag.outputs.tag }} + body: "See the assets to download this version and install." + draft: true + prerelease: false + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + build-desktop: + needs: + - determine-builds + - create-release + if: needs.determine-builds.outputs.build-desktop == 'true' permissions: id-token: write contents: write @@ -208,12 +266,12 @@ jobs: steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6.0.2 with: - # NOTE: persist-credentials is needed for tauri-action to create GitHub releases. + # NOTE: persist-credentials is needed for tauri-action to upload assets to GitHub releases. 
persist-credentials: true # zizmor: ignore[artipacked] - name: Configure AWS credentials if: startsWith(matrix.platform, 'macos-') - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -353,11 +411,9 @@ jobs: APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }} APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }} with: - tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }} - releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }} - releaseBody: "See the assets to download this version and install." - releaseDraft: true - prerelease: false + # Use the release created by the create-release job to avoid race conditions + # when multiple matrix jobs try to create/update the same release simultaneously + releaseId: ${{ needs.create-release.outputs.release-id }} assetNamePattern: "[name]_[arch][ext]" args: ${{ matrix.args }} @@ -384,7 +440,7 @@ jobs: persist-credentials: false - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -426,8 +482,9 @@ jobs: ONYX_VERSION=${{ github.ref_name }} NODE_OPTIONS=--max-old-space-size=8192 cache-from: | - type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64 + type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge + type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest cache-to: | type=inline type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max @@ -457,7 
+514,7 @@ jobs: persist-credentials: false - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -499,8 +556,9 @@ jobs: ONYX_VERSION=${{ github.ref_name }} NODE_OPTIONS=--max-old-space-size=8192 cache-from: | - type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64 + type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge + type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest cache-to: | type=inline type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max @@ -525,7 +583,7 @@ jobs: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -556,7 +614,7 @@ jobs: latest=false tags: | type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('web-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }} - type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable == 'true' && 'latest' || '' }} + type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }} type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }} type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }} @@ -595,7 +653,7 @@ jobs: persist-credentials: false - name: Configure 
AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -646,8 +704,8 @@ jobs: NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true NODE_OPTIONS=--max-old-space-size=8192 cache-from: | - type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64 + type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest cache-to: | type=inline type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max @@ -677,7 +735,7 @@ jobs: persist-credentials: false - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -728,8 +786,8 @@ jobs: NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true NODE_OPTIONS=--max-old-space-size=8192 cache-from: | - type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64 + type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest cache-to: | type=inline type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max @@ -754,7 +812,7 @@ jobs: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -821,7 +879,7 @@ jobs: persist-credentials: false - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: 
aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -862,8 +920,9 @@ jobs: build-args: | ONYX_VERSION=${{ github.ref_name }} cache-from: | - type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64 + type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge + type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest cache-to: | type=inline type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max @@ -893,7 +952,7 @@ jobs: persist-credentials: false - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -934,8 +993,9 @@ jobs: build-args: | ONYX_VERSION=${{ github.ref_name }} cache-from: | - type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64 + type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge + type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest cache-to: | type=inline type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max @@ -960,7 +1020,7 @@ jobs: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -991,7 +1051,7 @@ jobs: latest=false tags: | type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('backend-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }} - type=raw,value=${{ 
needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }} + type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }} type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }} type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }} @@ -1030,7 +1090,7 @@ jobs: persist-credentials: false - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -1072,8 +1132,8 @@ jobs: ONYX_VERSION=${{ github.ref_name }} ENABLE_CRAFT=true cache-from: | - type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64 + type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest cache-to: | type=inline type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64,mode=max @@ -1103,7 +1163,7 @@ jobs: persist-credentials: false - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -1145,8 +1205,8 @@ jobs: ONYX_VERSION=${{ github.ref_name }} ENABLE_CRAFT=true cache-from: | - type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64 + type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest cache-to: | type=inline type=registry,ref=${{ env.RUNS_ON_ECR_CACHE 
}}:backend-craft-cache-arm64,mode=max @@ -1172,7 +1232,7 @@ jobs: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -1242,7 +1302,7 @@ jobs: persist-credentials: false - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -1287,8 +1347,9 @@ jobs: build-args: | ONYX_VERSION=${{ github.ref_name }} cache-from: | - type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64 + type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge + type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest cache-to: | type=inline type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max @@ -1321,7 +1382,7 @@ jobs: persist-credentials: false - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -1366,8 +1427,9 @@ jobs: build-args: | ONYX_VERSION=${{ github.ref_name }} cache-from: | - type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64 + type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge + type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest cache-to: | type=inline type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max @@ -1394,7 +1456,7 @@ 
jobs: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -1425,7 +1487,7 @@ jobs: latest=false tags: | type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('model-server-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }} - type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }} + type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }} type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }} type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }} @@ -1459,7 +1521,7 @@ jobs: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -1514,7 +1576,7 @@ jobs: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ 
-1574,7 +1636,7 @@ jobs: persist-credentials: false - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -1631,7 +1693,7 @@ jobs: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 diff --git a/.github/workflows/nightly-llm-provider-chat.yml b/.github/workflows/nightly-llm-provider-chat.yml index a6fbdea6569..a27c1aab798 100644 --- a/.github/workflows/nightly-llm-provider-chat.yml +++ b/.github/workflows/nightly-llm-provider-chat.yml @@ -15,6 +15,11 @@ permissions: jobs: provider-chat-test: uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml + secrets: + AWS_OIDC_ROLE_ARN: ${{ secrets.AWS_OIDC_ROLE_ARN }} + permissions: + contents: read + id-token: write with: openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }} anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }} @@ -25,16 +30,6 @@ jobs: ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }} openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }} strict: true - secrets: - openai_api_key: ${{ secrets.OPENAI_API_KEY }} - anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} - bedrock_api_key: ${{ secrets.BEDROCK_API_KEY }} - vertex_ai_custom_config_json: ${{ secrets.NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON }} - azure_api_key: ${{ secrets.AZURE_API_KEY }} - ollama_api_key: ${{ secrets.OLLAMA_API_KEY }} - openrouter_api_key: ${{ secrets.OPENROUTER_API_KEY }} - DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} - DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN 
}} notify-slack-on-failure: needs: [provider-chat-test] diff --git a/.github/workflows/post-merge-beta-cherry-pick.yml b/.github/workflows/post-merge-beta-cherry-pick.yml index 06993fa5a53..5e84e652d3c 100644 --- a/.github/workflows/post-merge-beta-cherry-pick.yml +++ b/.github/workflows/post-merge-beta-cherry-pick.yml @@ -1,65 +1,102 @@ name: Post-Merge Beta Cherry-Pick on: - push: - branches: - - main - + pull_request_target: + types: + - closed + +# SECURITY NOTE: +# This workflow intentionally uses pull_request_target so post-merge automation can +# use base-repo credentials. Do not checkout PR head refs in this workflow +# (e.g. github.event.pull_request.head.sha). Only trusted base refs are allowed. permissions: - contents: write - pull-requests: write + contents: read jobs: - cherry-pick-to-latest-release: + resolve-cherry-pick-request: + if: >- + github.event.pull_request.merged == true + && github.event.pull_request.base.ref == 'main' + && github.event.pull_request.head.repo.full_name == github.repository outputs: should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }} pr_number: ${{ steps.gate.outputs.pr_number }} - cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }} - cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }} + merge_commit_sha: ${{ steps.gate.outputs.merge_commit_sha }} + merged_by: ${{ steps.gate.outputs.merged_by }} + gate_error: ${{ steps.gate.outputs.gate_error }} runs-on: ubuntu-latest - timeout-minutes: 45 + timeout-minutes: 10 steps: - name: Resolve merged PR and checkbox state id: gate env: GH_TOKEN: ${{ github.token }} + PR_NUMBER: ${{ github.event.pull_request.number }} + # SECURITY: keep PR body in env/plain-text handling; avoid directly + # inlining github.event.pull_request.body into shell commands. 
+          PR_BODY: ${{ github.event.pull_request.body }}
+          MERGE_COMMIT_SHA: ${{ github.event.pull_request.merge_commit_sha }}
+          MERGED_BY: ${{ github.event.pull_request.merged_by.login }}
+          # GitHub team slug authorized to trigger cherry-picks (e.g. "core-eng").
+          # For private/secret teams the GITHUB_TOKEN may need the read:org scope;
+          # visible teams work with the default token.
+          ALLOWED_TEAM: "onyx-core-team"
         run: |
-          # For the commit that triggered this workflow (HEAD on main), fetch all
-          # associated PRs and keep only the PR that was actually merged into main
-          # with this exact merge commit SHA.
-          pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
-          match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
-          pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
-
-          if [ "${match_count}" -gt 1 ]; then
-            echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
-          fi
+          echo "pr_number=${PR_NUMBER}" >> "$GITHUB_OUTPUT"
+          echo "merged_by=${MERGED_BY}" >> "$GITHUB_OUTPUT"

-          if [ -z "$pr_number" ]; then
-            echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
+          if ! echo "${PR_BODY}" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
             echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
+            echo "Cherry-pick checkbox not checked for PR #${PR_NUMBER}. Skipping."
             exit 0
           fi

-          # Read the PR once so we can gate behavior and infer preferred actor.
-          pr_json="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}")"
-          pr_body="$(printf '%s' "$pr_json" | jq -r '.body // ""')"
-          merged_by="$(printf '%s' "$pr_json" | jq -r '.merged_by.login // ""')"
+          # Keep should_cherrypick output before any possible exit 1 below so
+          # notify-slack can still gate on this output even if this job fails.
+          echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
+          echo "Cherry-pick checkbox checked for PR #${PR_NUMBER}."

-          echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
-          echo "merged_by=$merged_by" >> "$GITHUB_OUTPUT"
+          if [ -z "${MERGE_COMMIT_SHA}" ] || [ "${MERGE_COMMIT_SHA}" = "null" ]; then
+            echo "gate_error=missing-merge-commit-sha" >> "$GITHUB_OUTPUT"
+            echo "::error::PR #${PR_NUMBER} requested cherry-pick, but merge_commit_sha is missing."
+            exit 1
+          fi

-          if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
-            echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
-            echo "Cherry-pick checkbox checked for PR #${pr_number}."
-            exit 0
+          echo "merge_commit_sha=${MERGE_COMMIT_SHA}" >> "$GITHUB_OUTPUT"
+
+          member_state_file="$(mktemp)"
+          member_err_file="$(mktemp)"
+          if ! gh api "orgs/${GITHUB_REPOSITORY_OWNER}/teams/${ALLOWED_TEAM}/memberships/${MERGED_BY}" --jq '.state' >"${member_state_file}" 2>"${member_err_file}"; then
+            api_err="$(tr '\n' ' ' < "${member_err_file}" | sed 's/[[:space:]]\+/ /g' | cut -c1-300)"
+            echo "gate_error=team-api-error" >> "$GITHUB_OUTPUT"
+            echo "::error::Team membership API call failed for ${MERGED_BY} in ${ALLOWED_TEAM}: ${api_err}"
+            exit 1
           fi
-          echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
-          echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
+          member_state="$(cat "${member_state_file}")"
+          if [ "${member_state}" != "active" ]; then
+            echo "gate_error=not-team-member" >> "$GITHUB_OUTPUT"
+            echo "::error::${MERGED_BY} is not an active member of team ${ALLOWED_TEAM} (state: ${member_state}). Failing cherry-pick gate."
+            exit 1
+          fi
+          exit 0
+
+  cherry-pick-to-latest-release:
+    needs:
+      - resolve-cherry-pick-request
+    if: needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && needs.resolve-cherry-pick-request.result == 'success'
+    permissions:
+      contents: write
+      pull-requests: write
+    outputs:
+      cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
+      cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
+    runs-on: ubuntu-latest
+    timeout-minutes: 45
+    steps:
       - name: Checkout repository
-        if: steps.gate.outputs.should_cherrypick == 'true'
+        # SECURITY: keep checkout pinned to trusted base branch; do not switch to PR head refs.
         uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
         with:
           fetch-depth: 0
@@ -67,31 +104,37 @@ jobs:
           ref: main

       - name: Install the latest version of uv
-        if: steps.gate.outputs.should_cherrypick == 'true'
-        uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
+        uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
         with:
           enable-cache: false
           version: "0.9.9"

       - name: Configure git identity
-        if: steps.gate.outputs.should_cherrypick == 'true'
         run: |
           git config user.name "github-actions[bot]"
           git config user.email "github-actions[bot]@users.noreply.github.com"

       - name: Create cherry-pick PR to latest release
         id: run_cherry_pick
-        if: steps.gate.outputs.should_cherrypick == 'true'
-        continue-on-error: true
         env:
           GH_TOKEN: ${{ github.token }}
           GITHUB_TOKEN: ${{ github.token }}
-          CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
+          CHERRY_PICK_ASSIGNEE: ${{ needs.resolve-cherry-pick-request.outputs.merged_by }}
+          MERGE_COMMIT_SHA: ${{ needs.resolve-cherry-pick-request.outputs.merge_commit_sha }}
         run: |
-          set -o pipefail
           output_file="$(mktemp)"
-          uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
-          exit_code="${PIPESTATUS[0]}"
+          set +e
+          uv run --no-sync --with onyx-devtools ods cherry-pick "${MERGE_COMMIT_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
+          pipe_statuses=("${PIPESTATUS[@]}")
+          exit_code="${pipe_statuses[0]}"
+          tee_exit="${pipe_statuses[1]:-0}"
+          set -e
+          if [ "${tee_exit}" -ne 0 ]; then
+            echo "status=failure" >> "$GITHUB_OUTPUT"
+            echo "reason=output-capture-failed" >> "$GITHUB_OUTPUT"
+            echo "::error::tee failed to capture cherry-pick output (exit ${tee_exit}); cannot classify result."
+            exit 1
+          fi

           if [ "${exit_code}" -eq 0 ]; then
             echo "status=success" >> "$GITHUB_OUTPUT"
@@ -113,7 +156,7 @@ jobs:
           } >> "$GITHUB_OUTPUT"

       - name: Mark workflow as failed if cherry-pick failed
-        if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
+        if: steps.run_cherry_pick.outputs.status == 'failure'
         env:
           CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
         run: |
@@ -122,8 +165,9 @@ jobs:

   notify-slack-on-cherry-pick-failure:
     needs:
+      - resolve-cherry-pick-request
       - cherry-pick-to-latest-release
-    if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
+    if: always() && needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && (needs.resolve-cherry-pick-request.result == 'failure' || needs.cherry-pick-to-latest-release.result == 'failure')
     runs-on: ubuntu-slim
     timeout-minutes: 10
     steps:
@@ -132,22 +176,49 @@ jobs:
         with:
           persist-credentials: false

+      - name: Fail if Slack webhook secret is missing
+        env:
+          CHERRY_PICK_PRS_WEBHOOK: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
+        run: |
+          if [ -z "${CHERRY_PICK_PRS_WEBHOOK}" ]; then
+            echo "::error::CHERRY_PICK_PRS_WEBHOOK is not configured."
+            exit 1
+          fi

       - name: Build cherry-pick failure summary
         id: failure-summary
         env:
-          SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
+          SOURCE_PR_NUMBER: ${{ needs.resolve-cherry-pick-request.outputs.pr_number }}
+          MERGE_COMMIT_SHA: ${{ needs.resolve-cherry-pick-request.outputs.merge_commit_sha }}
+          GATE_ERROR: ${{ needs.resolve-cherry-pick-request.outputs.gate_error }}
           CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
           CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
         run: |
           source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
           reason_text="cherry-pick command failed"
-          if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
+          if [ "${GATE_ERROR}" = "missing-merge-commit-sha" ]; then
+            reason_text="requested cherry-pick but merge commit SHA was missing"
+          elif [ "${GATE_ERROR}" = "team-api-error" ]; then
+            reason_text="team membership lookup failed while validating cherry-pick permissions"
+          elif [ "${GATE_ERROR}" = "not-team-member" ]; then
+            reason_text="merger is not an active member of the allowed team"
+          elif [ "${CHERRY_PICK_REASON}" = "output-capture-failed" ]; then
+            reason_text="failed to capture cherry-pick output for classification"
+          elif [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
             reason_text="merge conflict during cherry-pick"
           fi
           details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
-          failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
+          if [ -n "${GATE_ERROR}" ]; then
+            failed_job_label="resolve-cherry-pick-request"
+          else
+            failed_job_label="cherry-pick-to-latest-release"
+          fi
+          failed_jobs="• ${failed_job_label}\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
+          if [ -n "${MERGE_COMMIT_SHA}" ]; then
+            failed_jobs="${failed_jobs}\\n• merge SHA: ${MERGE_COMMIT_SHA}"
+          fi
           if [ -n "${details_excerpt}" ]; then
             failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
           fi
@@ -160,4 +231,4 @@ jobs:
           webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
           failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
           title: "🚨 Automated Cherry-Pick Failed"
-          ref-name: ${{ github.ref_name }}
+          ref-name: ${{ github.event.pull_request.base.ref }}
diff --git a/.github/workflows/pr-desktop-build.yml b/.github/workflows/pr-desktop-build.yml
index 9bf0b3f3b96..64918081aef 100644
--- a/.github/workflows/pr-desktop-build.yml
+++ b/.github/workflows/pr-desktop-build.yml
@@ -57,7 +57,7 @@ jobs:
           cache-dependency-path: ./desktop/package-lock.json

       - name: Setup Rust
-        uses: dtolnay/rust-toolchain@4be9e76fd7c4901c61fb841f559994984270fce7
+        uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9
         with:
           toolchain: stable
           targets: ${{ matrix.target }}
diff --git a/.github/workflows/pr-external-dependency-unit-tests.yml b/.github/workflows/pr-external-dependency-unit-tests.yml
index 673dbb5e200..f26aa69141d 100644
--- a/.github/workflows/pr-external-dependency-unit-tests.yml
+++ b/.github/workflows/pr-external-dependency-unit-tests.yml
@@ -160,7 +160,7 @@ jobs:
           cd deployment/docker_compose

           # Get list of running containers
-          containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.opensearch.yml ps -q)
+          containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)

           # Collect logs from each container
           for container in $containers; do
diff --git a/.github/workflows/pr-golang-tests.yml b/.github/workflows/pr-golang-tests.yml
new file mode 100644
index 00000000000..ae158482df3
--- /dev/null
+++ b/.github/workflows/pr-golang-tests.yml
@@ -0,0 +1,56 @@
+name: Golang Tests
+concurrency:
+  group: Golang-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
+  cancel-in-progress: true
+
+on:
+  merge_group:
+  pull_request:
+
branches: + - main + - "release/**" + push: + tags: + - "v*.*.*" + +permissions: {} + +env: + GO_VERSION: "1.26" + +jobs: + detect-modules: + runs-on: ubuntu-latest + timeout-minutes: 10 + outputs: + modules: ${{ steps.set-modules.outputs.modules }} + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 + with: + persist-credentials: false + - id: set-modules + run: echo "modules=$(find . -name 'go.mod' -exec dirname {} \; | jq -Rc '[.,inputs]')" >> "$GITHUB_OUTPUT" + + golang: + needs: detect-modules + runs-on: ubuntu-latest + timeout-minutes: 10 + strategy: + matrix: + modules: ${{ fromJSON(needs.detect-modules.outputs.modules) }} + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6 + with: + persist-credentials: false + - uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # zizmor: ignore[cache-poisoning] + with: + go-version: ${{ env.GO_VERSION }} + cache-dependency-path: "**/go.sum" + + - run: go mod tidy + working-directory: ${{ matrix.modules }} + - run: git diff --exit-code go.mod go.sum + working-directory: ${{ matrix.modules }} + + - run: go test ./... + working-directory: ${{ matrix.modules }} diff --git a/.github/workflows/pr-helm-chart-testing.yml b/.github/workflows/pr-helm-chart-testing.yml index 3bd8f53d6d1..86f56bbd885 100644 --- a/.github/workflows/pr-helm-chart-testing.yml +++ b/.github/workflows/pr-helm-chart-testing.yml @@ -71,7 +71,7 @@ jobs: - name: Create kind cluster if: steps.list-changed.outputs.changed == 'true' - uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0 + uses: helm/kind-action@ef37e7f390d99f746eb8b610417061a60e82a6cc # ratchet:helm/kind-action@v1.14.0 - name: Pre-install cluster status check if: steps.list-changed.outputs.changed == 'true' @@ -133,7 +133,7 @@ jobs: echo "=== Validating chart dependencies ===" cd deployment/helm/charts/onyx helm dependency update - helm lint . + helm lint . 
--set auth.userauth.values.user_auth_secret=placeholder

       - name: Run chart-testing (install) with enhanced monitoring
         timeout-minutes: 25
@@ -194,6 +194,7 @@ jobs:
             --set=vespa.enabled=false \
             --set=opensearch.enabled=true \
             --set=auth.opensearch.enabled=true \
+            --set=auth.userauth.values.user_auth_secret=test-secret \
             --set=slackbot.enabled=false \
             --set=postgresql.enabled=true \
             --set=postgresql.cluster.storage.storageClass=standard \
@@ -230,6 +231,10 @@ jobs:
         if: steps.list-changed.outputs.changed == 'true'
         run: |
           echo "=== Post-install verification ==="
+          if ! kubectl cluster-info >/dev/null 2>&1; then
+            echo "ERROR: Kubernetes cluster is not reachable after install"
+            exit 1
+          fi
           kubectl get pods --all-namespaces
           kubectl get services --all-namespaces
           # Only show issues if they exist
@@ -239,6 +244,10 @@ jobs:
         if: failure() && steps.list-changed.outputs.changed == 'true'
         run: |
           echo "=== Cleanup on failure ==="
+          if ! kubectl cluster-info >/dev/null 2>&1; then
+            echo "Skipping failure cleanup: Kubernetes cluster is not reachable"
+            exit 0
+          fi
           echo "=== Final cluster state ==="
           kubectl get pods --all-namespaces
           kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10
diff --git a/.github/workflows/pr-integration-tests.yml b/.github/workflows/pr-integration-tests.yml
index 6797f3d03d8..e9b8006c2c7 100644
--- a/.github/workflows/pr-integration-tests.yml
+++ b/.github/workflows/pr-integration-tests.yml
@@ -316,6 +316,7 @@ jobs:
          # Base config shared by both editions
          cat <<EOF > deployment/docker_compose/.env
          COMPOSE_PROFILES=s3-filestore
+          OPENSEARCH_FOR_ONYX_ENABLED=false
          AUTH_TYPE=basic
          POSTGRES_POOL_PRE_PING=true
          POSTGRES_USE_NULL_POOL=true
@@ -335,7 +336,6 @@ jobs:
          # TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
          LICENSE_ENFORCEMENT_ENABLED=false
          CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
-          USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
          EOF
          fi
@@ -419,6 +419,7 @@ jobs:
            -e POSTGRES_POOL_PRE_PING=true \
            -e POSTGRES_USE_NULL_POOL=true \
            -e VESPA_HOST=index \
+           -e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \
            -e REDIS_HOST=cache \
            -e API_SERVER_HOST=api_server \
            -e OPENAI_API_KEY=${OPENAI_API_KEY} \
@@ -471,13 +472,13 @@ jobs:
           path: ${{ github.workspace }}/docker-compose.log

   # ------------------------------------------------------------
-  no-vectordb-tests:
+  onyx-lite-tests:
     needs: [build-backend-image, build-integration-image]
     runs-on:
       [
         runs-on,
         runner=4cpu-linux-arm64,
-        "run-id=${{ github.run_id }}-no-vectordb-tests",
+        "run-id=${{ github.run_id }}-onyx-lite-tests",
         "extras=ecr-cache",
       ]
     timeout-minutes: 45
@@ -495,13 +496,12 @@ jobs:
           username: ${{ secrets.DOCKER_USERNAME }}
           password: ${{ secrets.DOCKER_TOKEN }}

-      - name: Create .env file for no-vectordb Docker Compose
+      - name: Create .env file for Onyx Lite Docker Compose
        env:
          ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
          RUN_ID: ${{ github.run_id }}
        run: |
          cat <<EOF > deployment/docker_compose/.env
-          COMPOSE_PROFILES=s3-filestore
          ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
          LICENSE_ENFORCEMENT_ENABLED=false
          AUTH_TYPE=basic
@@ -509,28 +509,23 @@ jobs:
          POSTGRES_USE_NULL_POOL=true
          REQUIRE_EMAIL_VERIFICATION=false
          DISABLE_TELEMETRY=true
-          DISABLE_VECTOR_DB=true
          ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
          INTEGRATION_TESTS_MODE=true
-          USE_LIGHTWEIGHT_BACKGROUND_WORKER=true
          EOF

-      # Start only the services needed for no-vectordb mode (no Vespa, no model servers)
-      - name: Start Docker containers (no-vectordb)
+      # Start only the services needed for Onyx Lite (Postgres + API server)
+      - name: Start Docker containers (onyx-lite)
        run: |
          cd deployment/docker_compose
-          docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml up \
+          docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up \
            relational_db \
-            cache \
-            minio \
            api_server \
-            background \
            -d
-        id: start_docker_no_vectordb
+        id: start_docker_onyx_lite

      - name: Wait for
services to be ready run: | - echo "Starting wait-for-service script (no-vectordb)..." + echo "Starting wait-for-service script (onyx-lite)..." start_time=$(date +%s) timeout=300 while true; do @@ -552,14 +547,14 @@ jobs: sleep 5 done - - name: Run No-VectorDB Integration Tests + - name: Run Onyx Lite Integration Tests uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3 with: timeout_minutes: 20 max_attempts: 3 retry_wait_seconds: 10 command: | - echo "Running no-vectordb integration tests..." + echo "Running onyx-lite integration tests..." docker run --rm --network onyx_default \ --name test-runner \ -e POSTGRES_HOST=relational_db \ @@ -570,39 +565,38 @@ jobs: -e DB_READONLY_PASSWORD=password \ -e POSTGRES_POOL_PRE_PING=true \ -e POSTGRES_USE_NULL_POOL=true \ - -e REDIS_HOST=cache \ -e API_SERVER_HOST=api_server \ -e OPENAI_API_KEY=${OPENAI_API_KEY} \ -e TEST_WEB_HOSTNAME=test-runner \ ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \ /app/tests/integration/tests/no_vectordb - - name: Dump API server logs (no-vectordb) + - name: Dump API server logs (onyx-lite) if: always() run: | cd deployment/docker_compose - docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \ - logs --no-color api_server > $GITHUB_WORKSPACE/api_server_no_vectordb.log || true + docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \ + logs --no-color api_server > $GITHUB_WORKSPACE/api_server_onyx_lite.log || true - - name: Dump all-container logs (no-vectordb) + - name: Dump all-container logs (onyx-lite) if: always() run: | cd deployment/docker_compose - docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \ - logs --no-color > $GITHUB_WORKSPACE/docker-compose-no-vectordb.log || true + docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \ + logs --no-color > 
$GITHUB_WORKSPACE/docker-compose-onyx-lite.log || true - - name: Upload logs (no-vectordb) + - name: Upload logs (onyx-lite) if: always() uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f with: - name: docker-all-logs-no-vectordb - path: ${{ github.workspace }}/docker-compose-no-vectordb.log + name: docker-all-logs-onyx-lite + path: ${{ github.workspace }}/docker-compose-onyx-lite.log - - name: Stop Docker containers (no-vectordb) + - name: Stop Docker containers (onyx-lite) if: always() run: | cd deployment/docker_compose - docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml down -v + docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml down -v multitenant-tests: needs: @@ -645,6 +639,7 @@ jobs: ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \ ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \ DEV_MODE=true \ + OPENSEARCH_FOR_ONYX_ENABLED=false \ docker compose -f docker-compose.multitenant-dev.yml up \ relational_db \ index \ @@ -699,6 +694,7 @@ jobs: -e POSTGRES_DB=postgres \ -e POSTGRES_USE_NULL_POOL=true \ -e VESPA_HOST=index \ + -e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \ -e REDIS_HOST=cache \ -e API_SERVER_HOST=api_server \ -e OPENAI_API_KEY=${OPENAI_API_KEY} \ @@ -744,7 +740,7 @@ jobs: # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here. 
runs-on: ubuntu-slim timeout-minutes: 45 - needs: [integration-tests, no-vectordb-tests, multitenant-tests] + needs: [integration-tests, onyx-lite-tests, multitenant-tests] if: ${{ always() }} steps: - name: Check job status diff --git a/.github/workflows/pr-jest-tests.yml b/.github/workflows/pr-jest-tests.yml index e7fa59d117b..d2eb465b988 100644 --- a/.github/workflows/pr-jest-tests.yml +++ b/.github/workflows/pr-jest-tests.yml @@ -31,7 +31,7 @@ jobs: uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4 with: node-version: 22 - cache: "npm" + cache: "npm" # zizmor: ignore[cache-poisoning] test-only workflow; no deploy artifacts cache-dependency-path: ./web/package-lock.json - name: Install node dependencies diff --git a/.github/workflows/pr-playwright-tests.yml b/.github/workflows/pr-playwright-tests.yml index 6abaad2a525..86bc00a510e 100644 --- a/.github/workflows/pr-playwright-tests.yml +++ b/.github/workflows/pr-playwright-tests.yml @@ -12,6 +12,9 @@ on: push: tags: - "v*.*.*" + # TODO: Remove this if we enable merge-queues for release branches. 
+ branches: + - "release/**" permissions: contents: read @@ -268,10 +271,11 @@ jobs: persist-credentials: false - name: Setup node + # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4 with: node-version: 22 - cache: "npm" + cache: "npm" # zizmor: ignore[cache-poisoning] cache-dependency-path: ./web/package-lock.json - name: Install node dependencies @@ -279,6 +283,7 @@ jobs: run: npm ci - name: Cache playwright cache + # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4 with: path: ~/.cache/ms-playwright @@ -459,14 +464,14 @@ jobs: # --- Visual Regression Diff --- - name: Configure AWS credentials if: always() - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 - name: Install the latest version of uv if: always() - uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7 with: enable-cache: false version: "0.9.9" @@ -590,6 +595,108 @@ jobs: name: docker-logs-${{ matrix.project }}-${{ github.run_id }} path: ${{ github.workspace }}/docker-compose.log + playwright-tests-lite: + needs: [build-web-image, build-backend-image] + name: Playwright Tests (lite) + runs-on: + - runs-on + - runner=4cpu-linux-arm64 + - "run-id=${{ github.run_id }}-playwright-tests-lite" + - "extras=ecr-cache" + timeout-minutes: 30 + steps: + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 + + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # 
ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
+
+      - name: Setup node
+        # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
+        uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
+        with:
+          node-version: 22
+          cache: "npm" # zizmor: ignore[cache-poisoning]
+          cache-dependency-path: ./web/package-lock.json
+
+      - name: Install node dependencies
+        working-directory: ./web
+        run: npm ci
+
+      - name: Cache playwright cache
+        # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
+        uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
+        with:
+          path: ~/.cache/ms-playwright
+          key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
+          restore-keys: |
+            ${{ runner.os }}-playwright-npm-
+
+      - name: Install playwright browsers
+        working-directory: ./web
+        run: npx playwright install --with-deps
+
+      - name: Create .env file for Docker Compose
+        env:
+          OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }}
+          ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
+          RUN_ID: ${{ github.run_id }}
+        run: |
+          cat <<EOF > deployment/docker_compose/.env
+          ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
+          LICENSE_ENFORCEMENT_ENABLED=false
+          AUTH_TYPE=basic
+          INTEGRATION_TESTS_MODE=true
+          GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
+          MOCK_LLM_RESPONSE=true
+          REQUIRE_EMAIL_VERIFICATION=false
+          DISABLE_TELEMETRY=true
+          ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID}
+          ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID}
+          EOF
+
+      # needed for pulling external images; otherwise we hit the "Unauthenticated users" limit
+      # https://docs.docker.com/docker-hub/usage/
+      - name: Login to Docker Hub
+        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: Start Docker containers (lite)
+        run: |
+          cd
deployment/docker_compose + docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up -d + id: start_docker + + - name: Run Playwright tests (lite) + working-directory: ./web + run: npx playwright test --project lite + + - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f + if: always() + with: + name: playwright-test-results-lite-${{ github.run_id }} + path: ./web/output/playwright/ + retention-days: 30 + + - name: Save Docker logs + if: success() || failure() + env: + WORKSPACE: ${{ github.workspace }} + run: | + cd deployment/docker_compose + docker compose logs > docker-compose.log + mv docker-compose.log ${WORKSPACE}/docker-compose.log + + - name: Upload logs + if: success() || failure() + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f + with: + name: docker-logs-lite-${{ github.run_id }} + path: ${{ github.workspace }}/docker-compose.log + # Post a single combined visual regression comment after all matrix jobs finish visual-regression-comment: needs: [playwright-tests] @@ -603,7 +710,7 @@ jobs: pull-requests: write steps: - name: Download visual diff summaries - uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 + uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 with: pattern: screenshot-diff-summary-* path: summaries/ @@ -686,7 +793,7 @@ jobs: # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here. 
runs-on: ubuntu-slim timeout-minutes: 45 - needs: [playwright-tests] + needs: [playwright-tests, playwright-tests-lite] if: ${{ always() }} steps: - name: Check job status diff --git a/.github/workflows/pr-python-checks.yml b/.github/workflows/pr-python-checks.yml index a9f95d985af..b1289f51b2d 100644 --- a/.github/workflows/pr-python-checks.yml +++ b/.github/workflows/pr-python-checks.yml @@ -8,7 +8,7 @@ on: pull_request: branches: - main - - 'release/**' + - "release/**" push: tags: - "v*.*.*" @@ -21,7 +21,13 @@ jobs: # See https://runs-on.com/runners/linux/ # Note: Mypy seems quite optimized for x64 compared to arm64. # Similarly, mypy is single-threaded and incremental, so 2cpu is sufficient. - runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-mypy-check", "extras=s3-cache"] + runs-on: + [ + runs-on, + runner=2cpu-linux-x64, + "run-id=${{ github.run_id }}-mypy-check", + "extras=s3-cache", + ] timeout-minutes: 45 steps: @@ -52,21 +58,14 @@ jobs: if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }} uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4 with: - path: backend/.mypy_cache - key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }} + path: .mypy_cache + key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }} restore-keys: | mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}- mypy-${{ runner.os }}- - name: Run MyPy - working-directory: ./backend env: MYPY_FORCE_COLOR: 1 TERM: xterm-256color run: mypy . 
- - - name: Run MyPy (tools/) - env: - MYPY_FORCE_COLOR: 1 - TERM: xterm-256color - run: mypy tools/ diff --git a/.github/workflows/pr-quality-checks.yml b/.github/workflows/pr-quality-checks.yml index ac9a9bd36f5..6bdaad1fb5b 100644 --- a/.github/workflows/pr-quality-checks.yml +++ b/.github/workflows/pr-quality-checks.yml @@ -28,7 +28,7 @@ jobs: with: python-version: "3.11" - name: Setup Terraform - uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # ratchet:hashicorp/setup-terraform@v3 + uses: hashicorp/setup-terraform@5e8dbf3c6d9deaf4193ca7a8fb23f2ac83bb6c85 # ratchet:hashicorp/setup-terraform@v4.0.0 - name: Setup node uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6 with: # zizmor: ignore[cache-poisoning] @@ -38,9 +38,9 @@ jobs: - name: Install node dependencies working-directory: ./web run: npm ci - - uses: j178/prek-action@9d6a3097e0c1865ecce00cfb89fe80f2ee91b547 # ratchet:j178/prek-action@v1 + - uses: j178/prek-action@0bb87d7f00b0c99306c8bcb8b8beba1eb581c037 # ratchet:j178/prek-action@v1 with: - prek-version: '0.2.21' + prek-version: '0.3.4' extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }} - name: Check Actions uses: giner/check-actions@28d366c7cbbe235f9624a88aa31a628167eee28c # ratchet:giner/check-actions@v1.0.1 diff --git a/.github/workflows/release-cli.yml b/.github/workflows/release-cli.yml new file mode 100644 index 00000000000..24d4b649735 --- /dev/null +++ b/.github/workflows/release-cli.yml @@ -0,0 +1,214 @@ +name: Release CLI + +on: + push: + tags: + - "cli/v*.*.*" + +jobs: + pypi: + runs-on: ubuntu-latest + environment: + name: release-cli + permissions: + 
id-token: write + timeout-minutes: 10 + strategy: + matrix: + os-arch: + - { goos: "linux", goarch: "amd64" } + - { goos: "linux", goarch: "arm64" } + - { goos: "windows", goarch: "amd64" } + - { goos: "windows", goarch: "arm64" } + - { goos: "darwin", goarch: "amd64" } + - { goos: "darwin", goarch: "arm64" } + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 + with: + persist-credentials: false + - uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7 + with: + enable-cache: false + version: "0.9.9" + - run: | + GOOS="${{ matrix.os-arch.goos }}" \ + GOARCH="${{ matrix.os-arch.goarch }}" \ + uv build --wheel + working-directory: cli + - run: uv publish + working-directory: cli + + docker-amd64: + runs-on: + - runs-on + - runner=2cpu-linux-x64 + - run-id=${{ github.run_id }}-cli-amd64 + - extras=ecr-cache + environment: deploy + permissions: + id-token: write + timeout-minutes: 30 + outputs: + digest: ${{ steps.build.outputs.digest }} + env: + REGISTRY_IMAGE: onyxdotapp/onyx-cli + steps: + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 + + - name: Checkout + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 + with: + persist-credentials: false + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0 + with: + role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} + aws-region: us-east-2 + + - name: Get AWS Secrets + uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10 + with: + secret-ids: | + DOCKER_USERNAME, deploy/docker-username + DOCKER_TOKEN, deploy/docker-token + parse-json-secrets: true + + - name: Set up Docker Buildx + uses: 
docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4 + + - name: Login to Docker Hub + uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4 + with: + username: ${{ env.DOCKER_USERNAME }} + password: ${{ env.DOCKER_TOKEN }} + + - name: Build and push AMD64 + id: build + uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7 + with: + context: ./cli + file: ./cli/Dockerfile + platforms: linux/amd64 + cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest + cache-to: type=inline + outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true + + docker-arm64: + runs-on: + - runs-on + - runner=2cpu-linux-arm64 + - run-id=${{ github.run_id }}-cli-arm64 + - extras=ecr-cache + environment: deploy + permissions: + id-token: write + timeout-minutes: 30 + outputs: + digest: ${{ steps.build.outputs.digest }} + env: + REGISTRY_IMAGE: onyxdotapp/onyx-cli + steps: + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 + + - name: Checkout + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 + with: + persist-credentials: false + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0 + with: + role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} + aws-region: us-east-2 + + - name: Get AWS Secrets + uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10 + with: + secret-ids: | + DOCKER_USERNAME, deploy/docker-username + DOCKER_TOKEN, deploy/docker-token + parse-json-secrets: true + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # 
ratchet:docker/setup-buildx-action@v4 + + - name: Login to Docker Hub + uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4 + with: + username: ${{ env.DOCKER_USERNAME }} + password: ${{ env.DOCKER_TOKEN }} + + - name: Build and push ARM64 + id: build + uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7 + with: + context: ./cli + file: ./cli/Dockerfile + platforms: linux/arm64 + cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest + cache-to: type=inline + outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true + + merge-docker: + needs: + - docker-amd64 + - docker-arm64 + runs-on: + - runs-on + - runner=2cpu-linux-x64 + - run-id=${{ github.run_id }}-cli-merge + environment: deploy + permissions: + id-token: write + timeout-minutes: 10 + env: + REGISTRY_IMAGE: onyxdotapp/onyx-cli + steps: + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0 + with: + role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} + aws-region: us-east-2 + + - name: Get AWS Secrets + uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10 + with: + secret-ids: | + DOCKER_USERNAME, deploy/docker-username + DOCKER_TOKEN, deploy/docker-token + parse-json-secrets: true + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4 + + - name: Login to Docker Hub + uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4 + with: + username: ${{ env.DOCKER_USERNAME }} + password: ${{ env.DOCKER_TOKEN }} + + 
- name: Create and push manifest
+        env:
+          AMD64_DIGEST: ${{ needs.docker-amd64.outputs.digest }}
+          ARM64_DIGEST: ${{ needs.docker-arm64.outputs.digest }}
+          TAG: ${{ github.ref_name }}
+        run: |
+          SANITIZED_TAG="${TAG#cli/}"
+          IMAGES=(
+            "${REGISTRY_IMAGE}@${AMD64_DIGEST}"
+            "${REGISTRY_IMAGE}@${ARM64_DIGEST}"
+          )
+
+          if [[ "$TAG" =~ ^cli/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
+            docker buildx imagetools create \
+              -t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
+              -t "${REGISTRY_IMAGE}:latest" \
+              "${IMAGES[@]}"
+          else
+            docker buildx imagetools create \
+              -t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
+              "${IMAGES[@]}"
+          fi
diff --git a/.github/workflows/release-devtools.yml b/.github/workflows/release-devtools.yml
index d883d2d206d..d1ebc35e4b5 100644
--- a/.github/workflows/release-devtools.yml
+++ b/.github/workflows/release-devtools.yml
@@ -22,13 +22,11 @@ jobs:
           - { goos: "windows", goarch: "arm64" }
           - { goos: "darwin", goarch: "amd64" }
           - { goos: "darwin", goarch: "arm64" }
-          - { goos: "", goarch: "" }
     steps:
       - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
         with:
           persist-credentials: false
-          fetch-depth: 0
-      - uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
+      - uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
         with:
           enable-cache: false
           version: "0.9.9"
diff --git a/.github/workflows/reusable-nightly-llm-provider-chat.yml b/.github/workflows/reusable-nightly-llm-provider-chat.yml
index 20417e2e88c..51213434e3e 100644
--- a/.github/workflows/reusable-nightly-llm-provider-chat.yml
+++ b/.github/workflows/reusable-nightly-llm-provider-chat.yml
@@ -49,27 +49,13 @@ on:
         default: true
         type: boolean
     secrets:
-      openai_api_key:
-        required: false
-      anthropic_api_key:
-        required: false
-      bedrock_api_key:
-        required: false
-      vertex_ai_custom_config_json:
-        required: false
-      azure_api_key:
-        required: false
-      ollama_api_key:
-        required: false
-
openrouter_api_key: - required: false - DOCKER_USERNAME: - required: true - DOCKER_TOKEN: + AWS_OIDC_ROLE_ARN: + description: "AWS role ARN for OIDC auth" required: true permissions: contents: read + id-token: write jobs: build-backend-image: @@ -81,6 +67,7 @@ jobs: "extras=ecr-cache", ] timeout-minutes: 45 + environment: ci-protected steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 @@ -89,6 +76,19 @@ jobs: with: persist-credentials: false + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 + with: + role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} + aws-region: us-east-2 + + - name: Get AWS Secrets + uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 + with: + secret-ids: | + DOCKER_USERNAME, test/docker-username + DOCKER_TOKEN, test/docker-token + - name: Build backend image uses: ./.github/actions/build-backend-image with: @@ -97,8 +97,8 @@ jobs: pr-number: ${{ github.event.pull_request.number }} github-sha: ${{ github.sha }} run-id: ${{ github.run_id }} - docker-username: ${{ secrets.DOCKER_USERNAME }} - docker-token: ${{ secrets.DOCKER_TOKEN }} + docker-username: ${{ env.DOCKER_USERNAME }} + docker-token: ${{ env.DOCKER_TOKEN }} docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }} build-model-server-image: @@ -110,6 +110,7 @@ jobs: "extras=ecr-cache", ] timeout-minutes: 45 + environment: ci-protected steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 @@ -118,6 +119,19 @@ jobs: with: persist-credentials: false + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 + with: + role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} + aws-region: us-east-2 + + - name: Get AWS Secrets + uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 + 
with: + secret-ids: | + DOCKER_USERNAME, test/docker-username + DOCKER_TOKEN, test/docker-token + - name: Build model server image uses: ./.github/actions/build-model-server-image with: @@ -126,8 +140,8 @@ jobs: pr-number: ${{ github.event.pull_request.number }} github-sha: ${{ github.sha }} run-id: ${{ github.run_id }} - docker-username: ${{ secrets.DOCKER_USERNAME }} - docker-token: ${{ secrets.DOCKER_TOKEN }} + docker-username: ${{ env.DOCKER_USERNAME }} + docker-token: ${{ env.DOCKER_TOKEN }} build-integration-image: runs-on: @@ -138,6 +152,7 @@ jobs: "extras=ecr-cache", ] timeout-minutes: 45 + environment: ci-protected steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 @@ -146,6 +161,19 @@ jobs: with: persist-credentials: false + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 + with: + role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} + aws-region: us-east-2 + + - name: Get AWS Secrets + uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 + with: + secret-ids: | + DOCKER_USERNAME, test/docker-username + DOCKER_TOKEN, test/docker-token + - name: Build integration image uses: ./.github/actions/build-integration-image with: @@ -154,8 +182,8 @@ jobs: pr-number: ${{ github.event.pull_request.number }} github-sha: ${{ github.sha }} run-id: ${{ github.run_id }} - docker-username: ${{ secrets.DOCKER_USERNAME }} - docker-token: ${{ secrets.DOCKER_TOKEN }} + docker-username: ${{ env.DOCKER_USERNAME }} + docker-token: ${{ env.DOCKER_TOKEN }} provider-chat-test: needs: @@ -170,56 +198,56 @@ jobs: include: - provider: openai models: ${{ inputs.openai_models }} - api_key_secret: openai_api_key - custom_config_secret: "" + api_key_env: OPENAI_API_KEY + custom_config_env: "" api_base: "" api_version: "" deployment_name: "" required: true - provider: anthropic models: ${{ inputs.anthropic_models }} - api_key_secret: 
anthropic_api_key - custom_config_secret: "" + api_key_env: ANTHROPIC_API_KEY + custom_config_env: "" api_base: "" api_version: "" deployment_name: "" required: true - provider: bedrock models: ${{ inputs.bedrock_models }} - api_key_secret: bedrock_api_key - custom_config_secret: "" + api_key_env: BEDROCK_API_KEY + custom_config_env: "" api_base: "" api_version: "" deployment_name: "" required: false - provider: vertex_ai models: ${{ inputs.vertex_ai_models }} - api_key_secret: "" - custom_config_secret: vertex_ai_custom_config_json + api_key_env: "" + custom_config_env: NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON api_base: "" api_version: "" deployment_name: "" required: false - provider: azure models: ${{ inputs.azure_models }} - api_key_secret: azure_api_key - custom_config_secret: "" + api_key_env: AZURE_API_KEY + custom_config_env: "" api_base: ${{ inputs.azure_api_base }} api_version: "2025-04-01-preview" deployment_name: "" required: false - provider: ollama_chat models: ${{ inputs.ollama_models }} - api_key_secret: ollama_api_key - custom_config_secret: "" + api_key_env: OLLAMA_API_KEY + custom_config_env: "" api_base: "https://ollama.com" api_version: "" deployment_name: "" required: false - provider: openrouter models: ${{ inputs.openrouter_models }} - api_key_secret: openrouter_api_key - custom_config_secret: "" + api_key_env: OPENROUTER_API_KEY + custom_config_env: "" api_base: "https://openrouter.ai/api/v1" api_version: "" deployment_name: "" @@ -230,6 +258,7 @@ jobs: - "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test" - extras=ecr-cache timeout-minutes: 45 + environment: ci-protected steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 @@ -238,21 +267,43 @@ jobs: with: persist-credentials: false + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 + with: + role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} + 
aws-region: us-east-2 + + - name: Get AWS Secrets + uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 + with: + # Keep JSON values unparsed so vertex custom config is passed as raw JSON. + parse-json-secrets: false + secret-ids: | + DOCKER_USERNAME, test/docker-username + DOCKER_TOKEN, test/docker-token + OPENAI_API_KEY, test/openai-api-key + ANTHROPIC_API_KEY, test/anthropic-api-key + BEDROCK_API_KEY, test/bedrock-api-key + NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON, test/nightly-llm-vertex-ai-custom-config-json + AZURE_API_KEY, test/azure-api-key + OLLAMA_API_KEY, test/ollama-api-key + OPENROUTER_API_KEY, test/openrouter-api-key + - name: Run nightly provider chat test uses: ./.github/actions/run-nightly-provider-chat-test with: provider: ${{ matrix.provider }} models: ${{ matrix.models }} - provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }} + provider-api-key: ${{ matrix.api_key_env && env[matrix.api_key_env] || '' }} strict: ${{ inputs.strict && 'true' || 'false' }} api-base: ${{ matrix.api_base }} api-version: ${{ matrix.api_version }} deployment-name: ${{ matrix.deployment_name }} - custom-config-json: ${{ matrix.custom_config_secret && secrets[matrix.custom_config_secret] || '' }} + custom-config-json: ${{ matrix.custom_config_env && env[matrix.custom_config_env] || '' }} runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }} run-id: ${{ github.run_id }} - docker-username: ${{ secrets.DOCKER_USERNAME }} - docker-token: ${{ secrets.DOCKER_TOKEN }} + docker-username: ${{ env.DOCKER_USERNAME }} + docker-token: ${{ env.DOCKER_TOKEN }} - name: Dump API server logs if: always() diff --git a/.github/workflows/sandbox-deployment.yml b/.github/workflows/sandbox-deployment.yml index 151addc2380..6add52b6894 100644 --- a/.github/workflows/sandbox-deployment.yml +++ b/.github/workflows/sandbox-deployment.yml @@ -110,7 +110,7 @@ jobs: persist-credentials: false - name: Configure AWS credentials - uses: 
aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -180,7 +180,7 @@ jobs: persist-credentials: false - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 @@ -244,7 +244,7 @@ jobs: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 diff --git a/.github/workflows/storybook-deploy.yml b/.github/workflows/storybook-deploy.yml new file mode 100644 index 00000000000..7195c031faf --- /dev/null +++ b/.github/workflows/storybook-deploy.yml @@ -0,0 +1,69 @@ +name: Storybook Deploy +env: + VERCEL_ORG_ID: ${{ secrets.VERCEL_ORG_ID }} + VERCEL_PROJECT_ID: prj_sG49mVsA25UsxIPhN2pmBJlikJZM + VERCEL_CLI: vercel@50.14.1 + VERCEL_TOKEN: ${{ secrets.VERCEL_TOKEN }} + +concurrency: + group: storybook-deploy-production + cancel-in-progress: true + +on: + workflow_dispatch: + push: + branches: + - main + paths: + - "web/lib/opal/**" + - "web/src/refresh-components/**" + - "web/.storybook/**" + - "web/package.json" + - "web/package-lock.json" +permissions: + contents: read +jobs: + Deploy-Storybook: + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4 + with: + persist-credentials: false + + - name: Setup node + uses: 
actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4 + with: + node-version: 22 + cache: "npm" + cache-dependency-path: ./web/package-lock.json + + - name: Install dependencies + working-directory: web + run: npm ci + + - name: Build Storybook + working-directory: web + run: npm run storybook:build + + - name: Deploy to Vercel (Production) + working-directory: web + run: npx --yes "$VERCEL_CLI" deploy storybook-static/ --prod --yes --token="$VERCEL_TOKEN" + + notify-slack-on-failure: + needs: Deploy-Storybook + if: always() && needs.Deploy-Storybook.result == 'failure' + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4 + with: + persist-credentials: false + sparse-checkout: .github/actions/slack-notify + + - name: Send Slack notification + uses: ./.github/actions/slack-notify + with: + webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }} + failed-jobs: "• Deploy-Storybook" + title: "🚨 Storybook Deploy Failed" diff --git a/.github/workflows/zizmor.yml b/.github/workflows/zizmor.yml index 32f550fbb87..86dbb4494c7 100644 --- a/.github/workflows/zizmor.yml +++ b/.github/workflows/zizmor.yml @@ -24,7 +24,7 @@ jobs: persist-credentials: false - name: Install the latest version of uv - uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7 with: enable-cache: false version: "0.9.9" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c996ce2dc07..3d37c042cfa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -119,10 +119,11 @@ repos: ] - repo: https://github.com/golangci/golangci-lint - rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2 + rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1 hooks: - id: golangci-lint - entry: bash 
-c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'" + language_version: "1.26.0" + entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'" - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. diff --git a/.vscode/env_template.txt b/.vscode/env_template.txt index cd398ab3ef5..3b19a3de58e 100644 --- a/.vscode/env_template.txt +++ b/.vscode/env_template.txt @@ -7,6 +7,9 @@ AUTH_TYPE=basic +# Recommended for basic auth - used for signing password reset and verification tokens +# Generate a secure value with: openssl rand -hex 32 +USER_AUTH_SECRET="" DEV_MODE=true diff --git a/.vscode/launch.json b/.vscode/launch.json index 41a8164c45f..6c17e103e31 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -40,19 +40,7 @@ } }, { - "name": "Celery (lightweight mode)", - "configurations": [ - "Celery primary", - "Celery background", - "Celery beat" - ], - "presentation": { - "group": "1" - }, - "stopAll": true - }, - { - "name": "Celery (standard mode)", + "name": "Celery", "configurations": [ "Celery primary", "Celery light", @@ -253,35 +241,6 @@ }, "consoleTitle": "Celery light Console" }, - { - "name": "Celery background", - "type": "debugpy", - "request": "launch", - "module": "celery", - "cwd": "${workspaceFolder}/backend", - "envFile": "${workspaceFolder}/.vscode/.env", - "env": { - "LOG_LEVEL": "INFO", - "PYTHONUNBUFFERED": "1", - "PYTHONPATH": "." 
- }, - "args": [ - "-A", - "onyx.background.celery.versioned_apps.background", - "worker", - "--pool=threads", - "--concurrency=20", - "--prefetch-multiplier=4", - "--loglevel=INFO", - "--hostname=background@%n", - "-Q", - "vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration" - ], - "presentation": { - "group": "2" - }, - "consoleTitle": "Celery background Console" - }, { "name": "Celery heavy", "type": "debugpy", @@ -526,21 +485,6 @@ "group": "3" } }, - { - "name": "Clear and Restart OpenSearch Container", - // Generic debugger type, required arg but has no bearing on bash. - "type": "node", - "request": "launch", - "runtimeExecutable": "bash", - "runtimeArgs": [ - "${workspaceFolder}/backend/scripts/restart_opensearch_container.sh" - ], - "cwd": "${workspaceFolder}", - "console": "integratedTerminal", - "presentation": { - "group": "3" - } - }, { "name": "Eval CLI", "type": "debugpy", diff --git a/AGENTS.md b/AGENTS.md index 5acea7ecfe5..54b939eb08b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -86,37 +86,6 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work - Monitoring tasks (every 5 minutes) - Cleanup tasks (hourly) -#### Worker Deployment Modes - -Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable: - -**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`): - -- Runs a single consolidated `background` worker that handles all background tasks: - - Light worker tasks (Vespa operations, permissions sync, deletion) - - Document processing (indexing pipeline) - - Document fetching (connector data retrieval) - - Pruning operations (from `heavy` worker) - - 
Knowledge graph processing (from `kg_processing` worker) - - Monitoring tasks (from `monitoring` worker) - - User file processing (from `user_file_processing` worker) -- Lower resource footprint (fewer worker processes) -- Suitable for smaller deployments or development environments -- Default concurrency: 20 threads (increased to handle combined workload) - -**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`): - -- Runs separate specialized workers as documented above (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing) -- Better isolation and scalability -- Can scale individual workers independently based on workload -- Suitable for production deployments with higher load - -The deployment mode affects: - -- **Backend**: Worker processes spawned by supervisord or dev scripts -- **Helm**: Which Kubernetes deployments are created -- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns - #### Key Features - **Thread-based Workers**: All workers use thread pools (not processes) for stability @@ -135,6 +104,10 @@ The deployment mode affects: - Always use `@shared_task` rather than `@celery_app` - Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks` +- Never enqueue a task without an expiration. Always supply `expires=` when + sending tasks, either from the beat schedule or directly from another task. It + should never be acceptable to submit code which enqueues tasks without an + expiration, as doing so can lead to unbounded task queue growth. **Defining APIs**: When creating new FastAPI APIs, do NOT use the `response_model` field. Instead, just type the @@ -571,6 +544,8 @@ To run them: npx playwright test ``` +For shared fixtures, best practices, and detailed guidance, see `backend/tests/README.md`. + ## Logs When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access @@ -617,6 +592,45 @@ Keep it high level. 
You can reference certain files or functions though. Before writing your plan, make sure to do research. Explore the relevant sections in the codebase. +## Error Handling + +**Always raise `OnyxError` from `onyx.error_handling.exceptions` instead of `HTTPException`. +Never hardcode status codes or use `starlette.status` / `fastapi.status` constants directly.** + +A global FastAPI exception handler converts `OnyxError` into a JSON response with the standard +`{"error_code": "...", "detail": "..."}` shape. This eliminates boilerplate and keeps error +handling consistent across the entire backend. + +```python +from onyx.error_handling.error_codes import OnyxErrorCode +from onyx.error_handling.exceptions import OnyxError + +# ✅ Good +raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found") + +# ✅ Good — no extra message needed +raise OnyxError(OnyxErrorCode.UNAUTHENTICATED) + +# ✅ Good — upstream service with dynamic status code +raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status) + +# ❌ Bad — using HTTPException directly +raise HTTPException(status_code=404, detail="Session not found") + +# ❌ Bad — starlette constant +raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied") +``` + +Available error codes are defined in `backend/onyx/error_handling/error_codes.py`. If a new error +category is needed, add it there first — do not invent ad-hoc codes. 
+ +**Upstream service errors:** When forwarding errors from an upstream service where the HTTP +status code is dynamic (comes from the upstream response), use `status_code_override`: + +```python +raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=e.response.status_code) +``` + ## Best Practices In addition to the other content in this file, best practices for contributing diff --git a/README.md b/README.md index cb14ef67dc2..82ddf3441eb 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,85 @@ See guides below: > [!TIP] > **To try Onyx for free without deploying, check out [Onyx Cloud](https://cloud.onyx.app/signup?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme)**. +# ONYX.AI App + +CREATED BY ZEUS +// ONYX.AI 4 Core Entry +// Origin: AI.ALIVE (ONYXONMIBOOK) +// CreatedOn: 2025-11-22 + + + +

+<!-- centered banner: "Open Source AI Platform" heading, with Discord and Documentation links -->
+
+---
+
+ONYX.AI is a concrete, production-focused AI assistant platform designed for secure, scalable, real-world deployments.
+
+---
+
+<!-- assistant-ui header image; links: Product · Documentation · Examples · Discord Community · Contact Sales -->
+ +[![npm version](https://img.shields.io/npm/v/@assistant-ui/react)](https://www.npmjs.com/package/@assistant-ui/react) +[![npm downloads](https://img.shields.io/npm/dm/@assistant-ui/react)](https://www.npmjs.com/package/@assistant-ui/react) +[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/assistant-ui/assistant-ui) +[![Weave Badge](https://img.shields.io/endpoint?url=https%3A%2F%2Fapp.workweave.ai%2Fapi%2Frepository%2Fbadge%2Forg_GhSIrtWo37b5B3Mv0At3wQ1Q%2F722184017&cacheSeconds=3600)](https://app.workweave.ai/reports/repository/org_GhSIrtWo37b5B3Mv0At3wQ1Q/722184017) +![GitHub License](https://img.shields.io/github/license/assistant-ui/assistant-ui) +![Backed by Y Combinator](https://img.shields.io/badge/Backed_by-Y_Combinator-orange) + +[⭐️ Star assistant-ui on GitHub](https://github.com/assistant-ui/assistant-ui) + +## The UX of ChatGPT in your ONYX.AI app + +This app integrates **assistant-ui**, an open source TypeScript/React library, to deliver production-grade AI chat inside ONYX.AI. 
+
+- Handles streaming, auto-scrolling, accessibility, and real-time updates for you
+- Fully composable primitives inspired by shadcn/ui and cmdk — customize every pixel
+- Works with your stack: AI SDK, LangGraph, Mastra, or any custom backend
+- Broad model support out of the box (OpenAI, Anthropic, Mistral, Perplexity, AWS Bedrock, Azure, Google Gemini, Hugging Face, Fireworks, Cohere, Replicate, Ollama) with easy extension to custom APIs
+
+## Why assistant-ui in ONYX.AI
+
+- **Fast to production**: battle-tested primitives, built-in streaming and attachments
+- **Designed for customization**: composable pieces instead of a monolithic widget
+- **Great DX**: sensible defaults, keyboard shortcuts, a11y, and strong TypeScript
+- **Enterprise-ready**: optional chat history and analytics via Assistant Cloud
+
+## Getting Started
+
+Run the following commands in your terminal in this repo:
+
+```bash
+# install deps
+pnpm install
+
+# dev server
+pnpm dev
+```

 ## 🔍 Other Notable Benefits
diff --git a/backend/Dockerfile b/backend/Dockerfile
index f6b298707d0..af48b7b761c 100644
--- a/backend/Dockerfile
+++ b/backend/Dockerfile
@@ -46,7 +46,9 @@ RUN apt-get update && \
         pkg-config \
         gcc \
         nano \
-        vim && \
+        vim \
+        libjemalloc2 \
+        && \
         rm -rf /var/lib/apt/lists/* && \
         apt-get clean
@@ -141,6 +143,7 @@ COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
 COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
 COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
 COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
+COPY --chown=onyx:onyx ./scripts/reencrypt_secrets.py /app/scripts/reencrypt_secrets.py
 RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh

 # Run Craft template setup at build time when ENABLE_CRAFT=true
@@ -164,6 +167,13 @@ ENV PYTHONPATH=/app
 ARG ONYX_VERSION=0.0.0-dev
 ENV ONYX_VERSION=${ONYX_VERSION}
+
+# Use jemalloc instead of glibc malloc to reduce memory fragmentation
+# in long-running Python processes (API server, Celery workers).
+# The soname is architecture-independent; the dynamic linker resolves
+# the correct path from standard library directories.
+# Placed after all RUN steps so build-time processes are unaffected.
+ENV LD_PRELOAD=libjemalloc.so.2
+
 # Default command which does nothing
 # This container is used by api server and background which specify their own CMD
 CMD ["tail", "-f", "/dev/null"]
diff --git a/backend/alembic/versions/2664261bfaab_add_cache_store_table.py b/backend/alembic/versions/2664261bfaab_add_cache_store_table.py
new file mode 100644
index 00000000000..90d4bd9b1d7
--- /dev/null
+++ b/backend/alembic/versions/2664261bfaab_add_cache_store_table.py
@@ -0,0 +1,37 @@
+"""add cache_store table
+
+Revision ID: 2664261bfaab
+Revises: 4a1e4b1c89d2
+Create Date: 2026-02-27 00:00:00.000000
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision = "2664261bfaab"
+down_revision = "4a1e4b1c89d2"
+branch_labels: None = None
+depends_on: None = None
+
+
+def upgrade() -> None:
+    op.create_table(
+        "cache_store",
+        sa.Column("key", sa.String(), nullable=False),
+        sa.Column("value", sa.LargeBinary(), nullable=True),
+        sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
+        sa.PrimaryKeyConstraint("key"),
+    )
+    op.create_index(
+        "ix_cache_store_expires",
+        "cache_store",
+        ["expires_at"],
+        postgresql_where=sa.text("expires_at IS NOT NULL"),
+    )
+
+
+def downgrade() -> None:
+    op.drop_index("ix_cache_store_expires", table_name="cache_store")
+    op.drop_table("cache_store")
diff --git a/backend/alembic/versions/27fb147a843f_add_timestamps_to_user_table.py b/backend/alembic/versions/27fb147a843f_add_timestamps_to_user_table.py
new file mode 100644
index 00000000000..1f7f0777598
--- /dev/null
+++ b/backend/alembic/versions/27fb147a843f_add_timestamps_to_user_table.py
@@ -0,0 +1,43 @@
+"""add timestamps to user table
+
+Revision ID: 27fb147a843f
+Revises: b5c4d7e8f9a1
+Create Date: 2026-03-08 17:18:40.828644
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = "27fb147a843f"
+down_revision = "b5c4d7e8f9a1"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.add_column(
+        "user",
+        sa.Column(
+            "created_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.func.now(),
+            nullable=False,
+        ),
+    )
+    op.add_column(
+        "user",
+        sa.Column(
+            "updated_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.func.now(),
+            nullable=False,
+        ),
+    )
+
+
+def downgrade() -> None:
+    op.drop_column("user", "updated_at")
+    op.drop_column("user", "created_at")
diff --git a/backend/alembic/versions/4a1e4b1c89d2_add_indexing_to_userfilestatus.py b/backend/alembic/versions/4a1e4b1c89d2_add_indexing_to_userfilestatus.py
new file mode 100644
index 00000000000..7ceb195b2e6
--- /dev/null
+++ b/backend/alembic/versions/4a1e4b1c89d2_add_indexing_to_userfilestatus.py
@@ -0,0 +1,51 @@
+"""Add INDEXING to UserFileStatus
+
+Revision ID: 4a1e4b1c89d2
+Revises: 6b3b4083c5aa
+Create Date: 2026-02-28 00:00:00.000000
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+revision = "4a1e4b1c89d2"
+down_revision = "6b3b4083c5aa"
+branch_labels = None
+depends_on = None
+
+TABLE = "user_file"
+COLUMN = "status"
+CONSTRAINT_NAME = "ck_user_file_status"
+
+OLD_VALUES = ("PROCESSING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
+NEW_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
+
+
+def _drop_status_check_constraint() -> None:
+    """Drop the existing CHECK constraint on user_file.status.
+
+    The constraint name is auto-generated by SQLAlchemy and unknown,
+    so we look it up via the inspector.
+ """ + inspector = sa.inspect(op.get_bind()) + for constraint in inspector.get_check_constraints(TABLE): + if COLUMN in constraint.get("sqltext", ""): + constraint_name = constraint["name"] + if constraint_name is not None: + op.drop_constraint(constraint_name, TABLE, type_="check") + + +def upgrade() -> None: + _drop_status_check_constraint() + in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES) + op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})") + + +def downgrade() -> None: + op.execute( + f"UPDATE {TABLE} SET {COLUMN} = 'PROCESSING' WHERE {COLUMN} = 'INDEXING'" + ) + op.drop_constraint(CONSTRAINT_NAME, TABLE, type_="check") + in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES) + op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})") diff --git a/backend/alembic/versions/a3b8d9e2f1c4_make_scim_external_id_nullable.py b/backend/alembic/versions/a3b8d9e2f1c4_make_scim_external_id_nullable.py new file mode 100644 index 00000000000..1d9cf7c3d70 --- /dev/null +++ b/backend/alembic/versions/a3b8d9e2f1c4_make_scim_external_id_nullable.py @@ -0,0 +1,34 @@ +"""make scim_user_mapping.external_id nullable + +Revision ID: a3b8d9e2f1c4 +Revises: 2664261bfaab +Create Date: 2026-03-02 + +""" + +from alembic import op + + +# revision identifiers, used by Alembic. 
+revision = "a3b8d9e2f1c4" +down_revision = "2664261bfaab" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.alter_column( + "scim_user_mapping", + "external_id", + nullable=True, + ) + + +def downgrade() -> None: + # Delete any rows where external_id is NULL before re-applying NOT NULL + op.execute("DELETE FROM scim_user_mapping WHERE external_id IS NULL") + op.alter_column( + "scim_user_mapping", + "external_id", + nullable=False, + ) diff --git a/backend/alembic/versions/b5c4d7e8f9a1_add_hierarchy_node_cc_pair_table.py b/backend/alembic/versions/b5c4d7e8f9a1_add_hierarchy_node_cc_pair_table.py new file mode 100644 index 00000000000..931d615603a --- /dev/null +++ b/backend/alembic/versions/b5c4d7e8f9a1_add_hierarchy_node_cc_pair_table.py @@ -0,0 +1,51 @@ +"""add hierarchy_node_by_connector_credential_pair table + +Revision ID: b5c4d7e8f9a1 +Revises: a3b8d9e2f1c4 +Create Date: 2026-03-04 + +""" + +import sqlalchemy as sa +from alembic import op + +revision = "b5c4d7e8f9a1" +down_revision = "a3b8d9e2f1c4" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "hierarchy_node_by_connector_credential_pair", + sa.Column("hierarchy_node_id", sa.Integer(), nullable=False), + sa.Column("connector_id", sa.Integer(), nullable=False), + sa.Column("credential_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["hierarchy_node_id"], + ["hierarchy_node.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["connector_id", "credential_id"], + [ + "connector_credential_pair.connector_id", + "connector_credential_pair.credential_id", + ], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("hierarchy_node_id", "connector_id", "credential_id"), + ) + op.create_index( + "ix_hierarchy_node_cc_pair_connector_credential", + "hierarchy_node_by_connector_credential_pair", + ["connector_id", "credential_id"], + ) + + +def downgrade() -> None: + op.drop_index( + 
"ix_hierarchy_node_cc_pair_connector_credential", + table_name="hierarchy_node_by_connector_credential_pair", + ) + op.drop_table("hierarchy_node_by_connector_credential_pair") diff --git a/backend/alembic_tenants/versions/3b9f09038764_add_read_only_kg_user.py b/backend/alembic_tenants/versions/3b9f09038764_add_read_only_kg_user.py index d46c95c9d37..a91e0f8a589 100644 --- a/backend/alembic_tenants/versions/3b9f09038764_add_read_only_kg_user.py +++ b/backend/alembic_tenants/versions/3b9f09038764_add_read_only_kg_user.py @@ -11,7 +11,6 @@ from alembic import op from onyx.configs.app_configs import DB_READONLY_PASSWORD from onyx.configs.app_configs import DB_READONLY_USER -from shared_configs.configs import MULTI_TENANT # revision identifiers, used by Alembic. @@ -22,59 +21,52 @@ def upgrade() -> None: - if MULTI_TENANT: + # Enable pg_trgm extension if not already enabled + op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm") - # Enable pg_trgm extension if not already enabled - op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm") + # Create the read-only db user if it does not already exist. + if not (DB_READONLY_USER and DB_READONLY_PASSWORD): + raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set") - # Create read-only db user here only in multi-tenant mode. For single-tenant mode, - # the user is created in the standard migration. 
- if not (DB_READONLY_USER and DB_READONLY_PASSWORD): - raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set") - - op.execute( - text( - f""" - DO $$ - BEGIN - -- Check if the read-only user already exists - IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN - -- Create the read-only user with the specified password - EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}'); - -- First revoke all privileges to ensure a clean slate - EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}'); - -- Grant only the CONNECT privilege to allow the user to connect to the database - -- but not perform any operations without additional specific grants - EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}'); - END IF; - END - $$; - """ - ) - ) - - -def downgrade() -> None: - if MULTI_TENANT: - # Drop read-only db user here only in single tenant mode. For multi-tenant mode, - # the user is dropped in the alembic_tenants migration. 
- - op.execute( - text( - f""" + op.execute( + text( + f""" DO $$ BEGIN - IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN - -- First revoke all privileges from the database + -- Check if the read-only user already exists + IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN + -- Create the read-only user with the specified password + EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}'); + -- First revoke all privileges to ensure a clean slate EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}'); - -- Then revoke all privileges from the public schema - EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}'); - -- Then drop the user - EXECUTE format('DROP USER %I', '{DB_READONLY_USER}'); + -- Grant only the CONNECT privilege to allow the user to connect to the database + -- but not perform any operations without additional specific grants + EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}'); END IF; END $$; - """ - ) + """ + ) + ) + + +def downgrade() -> None: + op.execute( + text( + f""" + DO $$ + BEGIN + IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN + -- First revoke all privileges from the database + EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}'); + -- Then revoke all privileges from the public schema + EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}'); + -- Then drop the user + EXECUTE format('DROP USER %I', '{DB_READONLY_USER}'); + END IF; + END + $$; + """ ) - op.execute(text("DROP EXTENSION IF EXISTS pg_trgm")) + ) + op.execute(text("DROP EXTENSION IF EXISTS pg_trgm")) diff --git a/backend/ee/onyx/access/access.py b/backend/ee/onyx/access/access.py index 5add4468e80..62c42e315bf 100644 --- a/backend/ee/onyx/access/access.py +++ 
b/backend/ee/onyx/access/access.py @@ -9,12 +9,15 @@ _get_access_for_documents as get_access_for_documents_without_groups, ) from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_groups +from onyx.access.access import collect_user_file_access from onyx.access.models import DocumentAccess from onyx.access.utils import prefix_external_group from onyx.access.utils import prefix_user_group from onyx.db.document import get_document_sources from onyx.db.document import get_documents_by_ids from onyx.db.models import User +from onyx.db.models import UserFile +from onyx.db.user_file import fetch_user_files_with_access_relationships from onyx.utils.logger import setup_logger @@ -116,6 +119,68 @@ def _get_access_for_documents( return access_map +def _collect_user_file_group_names(user_file: UserFile) -> set[str]: + """Extract user-group names from the already-loaded Persona.groups + relationships on a UserFile (skipping deleted personas).""" + groups: set[str] = set() + for persona in user_file.assistants: + if persona.deleted: + continue + for group in persona.groups: + groups.add(group.name) + return groups + + +def get_access_for_user_files_impl( + user_file_ids: list[str], + db_session: Session, +) -> dict[str, DocumentAccess]: + """EE version: extends the MIT user file ACL with user group names + from personas shared via user groups. + + Uses a single DB query (via fetch_user_files_with_access_relationships) + that eagerly loads both the MIT-needed and EE-needed relationships. + + NOTE: is imported in onyx.access.access by `fetch_versioned_implementation` + DO NOT REMOVE.""" + user_files = fetch_user_files_with_access_relationships( + user_file_ids, db_session, eager_load_groups=True + ) + return build_access_for_user_files_impl(user_files) + + +def build_access_for_user_files_impl( + user_files: list[UserFile], +) -> dict[str, DocumentAccess]: + """EE version: works on pre-loaded UserFile objects. + Expects Persona.groups to be eagerly loaded. 
+ + NOTE: is imported in onyx.access.access by `fetch_versioned_implementation` + DO NOT REMOVE.""" + result: dict[str, DocumentAccess] = {} + for user_file in user_files: + if user_file.user is None: + result[str(user_file.id)] = DocumentAccess.build( + user_emails=[], + user_groups=[], + is_public=True, + external_user_emails=[], + external_user_group_ids=[], + ) + continue + + emails, is_public = collect_user_file_access(user_file) + group_names = _collect_user_file_group_names(user_file) + result[str(user_file.id)] = DocumentAccess.build( + user_emails=list(emails), + user_groups=list(group_names), + is_public=is_public, + external_user_emails=[], + external_user_group_ids=[], + ) + return result + + def _get_acl_for_user(user: User, db_session: Session) -> set[str]: """Returns a list of ACL entries that the user has access to. This is meant to be used downstream to filter out documents that the user does not have access to. The diff --git a/backend/ee/onyx/auth/users.py b/backend/ee/onyx/auth/users.py index 7588ed8716e..0b2cda57991 100644 --- a/backend/ee/onyx/auth/users.py +++ b/backend/ee/onyx/auth/users.py @@ -1,3 +1,4 @@ +import os from datetime import datetime import jwt @@ -20,7 +21,13 @@ def verify_auth_setting() -> None: - # All the Auth flows are valid for EE version + # All the Auth flows are valid for EE version, but warn about deprecated 'disabled' + raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower() + if raw_auth_type == "disabled": + logger.warning( + "AUTH_TYPE='disabled' is no longer supported. " + "Using 'basic' instead. Please update your configuration." 
+ ) logger.notice(f"Using Auth Type: {AUTH_TYPE.value}") diff --git a/backend/ee/onyx/background/celery/apps/background.py b/backend/ee/onyx/background/celery/apps/background.py deleted file mode 100644 index 45fb0bc4702..00000000000 --- a/backend/ee/onyx/background/celery/apps/background.py +++ /dev/null @@ -1,15 +0,0 @@ -from onyx.background.celery.apps import app_base -from onyx.background.celery.apps.background import celery_app - - -celery_app.autodiscover_tasks( - app_base.filter_task_modules( - [ - "ee.onyx.background.celery.tasks.doc_permission_syncing", - "ee.onyx.background.celery.tasks.external_group_syncing", - "ee.onyx.background.celery.tasks.cleanup", - "ee.onyx.background.celery.tasks.tenant_provisioning", - "ee.onyx.background.celery.tasks.query_history", - ] - ) -) diff --git a/backend/ee/onyx/db/hierarchy.py b/backend/ee/onyx/db/hierarchy.py index 232f19581b0..f18745cd9d3 100644 --- a/backend/ee/onyx/db/hierarchy.py +++ b/backend/ee/onyx/db/hierarchy.py @@ -18,7 +18,7 @@ def _build_hierarchy_access_filter( - user_email: str | None, + user_email: str, external_group_ids: list[str], ) -> ColumnElement[bool]: """Build SQLAlchemy filter for hierarchy node access. 
@@ -43,7 +43,7 @@ def _build_hierarchy_access_filter( def _get_accessible_hierarchy_nodes_for_source( db_session: Session, source: DocumentSource, - user_email: str | None, + user_email: str, external_group_ids: list[str], ) -> list[HierarchyNode]: """ diff --git a/backend/ee/onyx/db/license.py b/backend/ee/onyx/db/license.py index 85b79060215..97bdc03aafa 100644 --- a/backend/ee/onyx/db/license.py +++ b/backend/ee/onyx/db/license.py @@ -11,11 +11,10 @@ from ee.onyx.server.license.models import LicensePayload from ee.onyx.server.license.models import LicenseSource from onyx.auth.schemas import UserRole +from onyx.cache.factory import get_cache_backend from onyx.configs.constants import ANONYMOUS_USER_EMAIL from onyx.db.models import License from onyx.db.models import User -from onyx.redis.redis_pool import get_redis_client -from onyx.redis.redis_pool import get_redis_replica_client from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT from shared_configs.contextvars import get_current_tenant_id @@ -142,7 +141,7 @@ def get_used_seats(tenant_id: str | None = None) -> int: def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata | None: """ - Get license metadata from Redis cache. + Get license metadata from cache. 
Args: tenant_id: Tenant ID (for multi-tenant deployments) @@ -150,38 +149,34 @@ def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata Returns: LicenseMetadata if cached, None otherwise """ - tenant = tenant_id or get_current_tenant_id() - redis_client = get_redis_replica_client(tenant_id=tenant) + cache = get_cache_backend(tenant_id=tenant_id) + cached = cache.get(LICENSE_METADATA_KEY) + if not cached: + return None - cached = redis_client.get(LICENSE_METADATA_KEY) - if cached: - try: - cached_str: str - if isinstance(cached, bytes): - cached_str = cached.decode("utf-8") - else: - cached_str = str(cached) - return LicenseMetadata.model_validate_json(cached_str) - except Exception as e: - logger.warning(f"Failed to parse cached license metadata: {e}") - return None - return None + try: + cached_str = ( + cached.decode("utf-8") if isinstance(cached, bytes) else str(cached) + ) + return LicenseMetadata.model_validate_json(cached_str) + except Exception as e: + logger.warning(f"Failed to parse cached license metadata: {e}") + return None def invalidate_license_cache(tenant_id: str | None = None) -> None: """ Invalidate the license metadata cache (not the license itself). - This deletes the cached LicenseMetadata from Redis. The actual license - in the database is not affected. Redis delete is idempotent - if the - key doesn't exist, this is a no-op. + Deletes the cached LicenseMetadata. The actual license in the database + is not affected. Delete is idempotent — if the key doesn't exist, this + is a no-op. 
Args: tenant_id: Tenant ID (for multi-tenant deployments) """ - tenant = tenant_id or get_current_tenant_id() - redis_client = get_redis_client(tenant_id=tenant) - redis_client.delete(LICENSE_METADATA_KEY) + cache = get_cache_backend(tenant_id=tenant_id) + cache.delete(LICENSE_METADATA_KEY) logger.info("License cache invalidated") @@ -192,7 +187,7 @@ def update_license_cache( tenant_id: str | None = None, ) -> LicenseMetadata: """ - Update the Redis cache with license metadata. + Update the cache with license metadata. We cache all license statuses (ACTIVE, GRACE_PERIOD, GATED_ACCESS) because: 1. Frontend needs status to show appropriate UI/banners @@ -211,7 +206,7 @@ def update_license_cache( from ee.onyx.utils.license import get_license_status tenant = tenant_id or get_current_tenant_id() - redis_client = get_redis_client(tenant_id=tenant) + cache = get_cache_backend(tenant_id=tenant_id) used_seats = get_used_seats(tenant) status = get_license_status(payload, grace_period_end) @@ -230,7 +225,7 @@ def update_license_cache( stripe_subscription_id=payload.stripe_subscription_id, ) - redis_client.set( + cache.set( LICENSE_METADATA_KEY, metadata.model_dump_json(), ex=LICENSE_CACHE_TTL_SECONDS, diff --git a/backend/ee/onyx/db/persona.py b/backend/ee/onyx/db/persona.py index 98562e59ea6..0fcef8e89c4 100644 --- a/backend/ee/onyx/db/persona.py +++ b/backend/ee/onyx/db/persona.py @@ -7,6 +7,7 @@ from onyx.db.models import Persona__User from onyx.db.models import Persona__UserGroup from onyx.db.notification import create_notification +from onyx.db.persona import mark_persona_user_files_for_sync from onyx.server.features.persona.models import PersonaSharedNotificationData @@ -26,7 +27,9 @@ def update_persona_access( NOTE: Callers are responsible for committing.""" + needs_sync = False if is_public is not None: + needs_sync = True persona = db_session.query(Persona).filter(Persona.id == persona_id).first() if persona: persona.is_public = is_public @@ -35,6 +38,7 @@ def 
update_persona_access( # and a non-empty list means "replace with these shares". if user_ids is not None: + needs_sync = True db_session.query(Persona__User).filter( Persona__User.persona_id == persona_id ).delete(synchronize_session="fetch") @@ -54,6 +58,7 @@ def update_persona_access( ) if group_ids is not None: + needs_sync = True db_session.query(Persona__UserGroup).filter( Persona__UserGroup.persona_id == persona_id ).delete(synchronize_session="fetch") @@ -63,3 +68,7 @@ def update_persona_access( db_session.add( Persona__UserGroup(persona_id=persona_id, user_group_id=group_id) ) + + # When sharing changes, user file ACLs need to be updated in the vector DB + if needs_sync: + mark_persona_user_files_for_sync(persona_id, db_session) diff --git a/backend/ee/onyx/db/scim.py b/backend/ee/onyx/db/scim.py index b9cbc5b2d22..498298b7299 100644 --- a/backend/ee/onyx/db/scim.py +++ b/backend/ee/onyx/db/scim.py @@ -126,12 +126,16 @@ def update_token_last_used(self, token_id: int) -> None: def create_user_mapping( self, - external_id: str, + external_id: str | None, user_id: UUID, scim_username: str | None = None, fields: ScimMappingFields | None = None, ) -> ScimUserMapping: - """Create a mapping between a SCIM externalId and an Onyx user.""" + """Create a SCIM mapping for a user. + + ``external_id`` may be ``None`` when the IdP omits it (RFC 7643 + allows this). The mapping still marks the user as SCIM-managed. + """ f = fields or ScimMappingFields() mapping = ScimUserMapping( external_id=external_id, @@ -270,8 +274,13 @@ def list_users( Raises: ValueError: If the filter uses an unsupported attribute. """ - query = select(User).where( - User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]) + # Inner-join with ScimUserMapping so only SCIM-managed users appear. + # Pre-existing system accounts (anonymous, admin, etc.) are excluded + # unless they were explicitly linked via SCIM provisioning. 
+ query = ( + select(User) + .join(ScimUserMapping, ScimUserMapping.user_id == User.id) + .where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])) ) if scim_filter: @@ -321,34 +330,37 @@ def sync_user_external_id( scim_username: str | None = None, fields: ScimMappingFields | None = None, ) -> None: - """Create, update, or delete the external ID mapping for a user. + """Sync the SCIM mapping for a user. + + If a mapping already exists, its fields are updated (including + setting ``external_id`` to ``None`` when the IdP omits it). + If no mapping exists and ``new_external_id`` is provided, a new + mapping is created. A mapping is never deleted here — SCIM-managed + users must retain their mapping to remain visible in ``GET /Users``. When *fields* is provided, all mapping fields are written unconditionally — including ``None`` values — so that a caller can clear a previously-set field (e.g. removing a department). """ mapping = self.get_user_mapping_by_user_id(user_id) - if new_external_id: - if mapping: - if mapping.external_id != new_external_id: - mapping.external_id = new_external_id - if scim_username is not None: - mapping.scim_username = scim_username - if fields is not None: - mapping.department = fields.department - mapping.manager = fields.manager - mapping.given_name = fields.given_name - mapping.family_name = fields.family_name - mapping.scim_emails_json = fields.scim_emails_json - else: - self.create_user_mapping( - external_id=new_external_id, - user_id=user_id, - scim_username=scim_username, - fields=fields, - ) - elif mapping: - self.delete_user_mapping(mapping.id) + if mapping: + if mapping.external_id != new_external_id: + mapping.external_id = new_external_id + if scim_username is not None: + mapping.scim_username = scim_username + if fields is not None: + mapping.department = fields.department + mapping.manager = fields.manager + mapping.given_name = fields.given_name + mapping.family_name = fields.family_name + 
mapping.scim_emails_json = fields.scim_emails_json + elif new_external_id: + self.create_user_mapping( + external_id=new_external_id, + user_id=user_id, + scim_username=scim_username, + fields=fields, + ) def _get_user_mappings_batch( self, user_ids: list[UUID] diff --git a/backend/ee/onyx/db/user_group.py b/backend/ee/onyx/db/user_group.py index 9087b2f9d04..7860c0678a0 100644 --- a/backend/ee/onyx/db/user_group.py +++ b/backend/ee/onyx/db/user_group.py @@ -15,6 +15,7 @@ from ee.onyx.server.user_group.models import SetCuratorRequest from ee.onyx.server.user_group.models import UserGroupCreate from ee.onyx.server.user_group.models import UserGroupUpdate +from onyx.configs.app_configs import DISABLE_VECTOR_DB from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus @@ -471,7 +472,9 @@ def _add_user_group__cc_pair_relationships__no_commit( def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup: db_user_group = UserGroup( - name=user_group.name, time_last_modified_by_user=func.now() + name=user_group.name, + time_last_modified_by_user=func.now(), + is_up_to_date=DISABLE_VECTOR_DB, ) db_session.add(db_user_group) db_session.flush() # give the group an ID @@ -774,8 +777,7 @@ def update_user_group( cc_pair_ids=user_group_update.cc_pair_ids, ) - # only needs to sync with Vespa if the cc_pairs have been updated - if cc_pairs_updated: + if cc_pairs_updated and not DISABLE_VECTOR_DB: db_user_group.is_up_to_date = False removed_users = db_session.scalars( diff --git a/backend/ee/onyx/external_permissions/google_drive/doc_sync.py b/backend/ee/onyx/external_permissions/google_drive/doc_sync.py index c5318548a44..bbec6bfbcc7 100644 --- a/backend/ee/onyx/external_permissions/google_drive/doc_sync.py +++ b/backend/ee/onyx/external_permissions/google_drive/doc_sync.py @@ -68,6 +68,7 @@ def get_external_access_for_raw_gdrive_file( 
company_domain: str, retriever_drive_service: GoogleDriveService | None, admin_drive_service: GoogleDriveService, + fallback_user_email: str, add_prefix: bool = False, ) -> ExternalAccess: """ @@ -79,6 +80,11 @@ def get_external_access_for_raw_gdrive_file( set add_prefix to True so group IDs are prefixed with the source type. When invoked from doc_sync (permission sync), use the default (False) since upsert_document_external_perms handles prefixing. + fallback_user_email: When we cannot retrieve any permission info for a file + (e.g. externally-owned files where the API returns no permissions + and permissions.list returns 403), fall back to granting access + to this user. This is typically the impersonated org user whose + drive contained the file. """ doc_id = file.get("id") if not doc_id: @@ -117,6 +123,26 @@ def _get_permissions( [permissions_list, backup_permissions_list] ) + # For externally-owned files, the Drive API may return no permissions + # and permissions.list may return 403. In this case, fall back to + # granting access to the user who found the file in their drive. + # Note, even if other users also have access to this file, + # they will not be granted access in Onyx. + # We check permissions_list (the final result after all fetch attempts) + # rather than the raw fields, because permission_ids may be present + # but the actual fetch can still return empty due to a 403. + if not permissions_list: + logger.info( + f"No permission info available for file {doc_id} " + f"(likely owned by a user outside of your organization). 
" + f"Falling back to granting access to retriever user: {fallback_user_email}" + ) + return ExternalAccess( + external_user_emails={fallback_user_email}, + external_user_group_ids=set(), + is_public=False, + ) + folder_ids_to_inherit_permissions_from: set[str] = set() user_emails: set[str] = set() group_emails: set[str] = set() diff --git a/backend/ee/onyx/external_permissions/jira/group_sync.py b/backend/ee/onyx/external_permissions/jira/group_sync.py index cb9ca677a95..e8ecedffb76 100644 --- a/backend/ee/onyx/external_permissions/jira/group_sync.py +++ b/backend/ee/onyx/external_permissions/jira/group_sync.py @@ -1,6 +1,8 @@ from collections.abc import Generator +from typing import Any from jira import JIRA +from jira.exceptions import JIRAError from ee.onyx.db.external_perm import ExternalUserGroup from onyx.connectors.jira.utils import build_jira_client @@ -9,107 +11,102 @@ logger = setup_logger() +_ATLASSIAN_ACCOUNT_TYPE = "atlassian" +_GROUP_MEMBER_PAGE_SIZE = 50 -def _get_jira_group_members_email( +# The GET /group/member endpoint was introduced in Jira 6.0. +# Jira versions older than 6.0 do not have group management REST APIs at all. +_MIN_JIRA_VERSION_FOR_GROUP_MEMBER = "6.0" + + +def _fetch_group_member_page( jira_client: JIRA, group_name: str, -) -> list[str]: - """Get all member emails for a Jira group. - - Filters out app accounts (bots, integrations) and only returns real user emails. + start_at: int, +) -> dict[str, Any]: + """Fetch a single page from the non-deprecated GET /group/member endpoint. + + The old GET /group endpoint (used by jira_client.group_members()) is deprecated + and decommissioned in Jira Server 10.3+. This uses the replacement endpoint + directly via the library's internal _get_json helper, following the same pattern + as enhanced_search_ids / bulk_fetch_issues in connector.py. 
+
+    There is an open PR to the library to switch to this endpoint since last year:
+    https://github.com/pycontribs/jira/pull/2356
+    so once it is merged and released, we can switch to using the library function.
     """
-    emails: list[str] = []
-    try:
-        # group_members returns an OrderedDict of account_id -> member_info
-        members = jira_client.group_members(group=group_name)
-
-        if not members:
-            logger.warning(f"No members found for group {group_name}")
-            return emails
-
-        for account_id, member_info in members.items():
-            # member_info is a dict with keys like 'fullname', 'email', 'active'
-            email = member_info.get("email")
-
-            # Skip "hidden" emails - these are typically app accounts
-            if email and email != "hidden":
-                emails.append(email)
-            else:
-                # For cloud, we might need to fetch user details separately
-                try:
-                    user = jira_client.user(id=account_id)
-
-                    # Skip app accounts (bots, integrations, etc.)
-                    if hasattr(user, "accountType") and user.accountType == "app":
-                        logger.info(
-                            f"Skipping app account {account_id} for group {group_name}"
-                        )
-                        continue
-
-                    if hasattr(user, "emailAddress") and user.emailAddress:
-                        emails.append(user.emailAddress)
-                    else:
-                        logger.warning(f"User {account_id} has no email address")
-                except Exception as e:
-                    logger.warning(
-                        f"Could not fetch email for user {account_id} in group {group_name}: {e}"
-                    )
-
-    except Exception as e:
-        logger.error(f"Error fetching members for group {group_name}: {e}")
-
-    return emails
-
-
-def _build_group_member_email_map(
+    try:
+        return jira_client._get_json(
+            "group/member",
+            params={
+                "groupname": group_name,
+                "includeInactiveUsers": "false",
+                "startAt": start_at,
+                "maxResults": _GROUP_MEMBER_PAGE_SIZE,
+            },
+        )
+    except JIRAError as e:
+        if e.status_code == 404:
+            raise RuntimeError(
+                f"GET /group/member returned 404 for group '{group_name}'. "
+                f"This endpoint requires Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}+. 
" + f"If you are running a self-hosted Jira instance, please upgrade " + f"to at least Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}." + ) from e + raise + + +def _get_group_member_emails( jira_client: JIRA, -) -> dict[str, set[str]]: - """Build a map of group names to member emails.""" - group_member_emails: dict[str, set[str]] = {} - - try: - # Get all groups from Jira - returns a list of group name strings - group_names = jira_client.groups() - - if not group_names: - logger.warning("No groups found in Jira") - return group_member_emails - - logger.info(f"Found {len(group_names)} groups in Jira") + group_name: str, +) -> set[str]: + """Get all member emails for a single Jira group. - for group_name in group_names: - if not group_name: + Uses the non-deprecated GET /group/member endpoint which returns full user + objects including accountType, so we can filter out app/customer accounts + without making separate user() calls. + """ + emails: set[str] = set() + start_at = 0 + + while True: + try: + page = _fetch_group_member_page(jira_client, group_name, start_at) + except Exception as e: + logger.error(f"Error fetching members for group {group_name}: {e}") + raise + + members: list[dict[str, Any]] = page.get("values", []) + for member in members: + account_type = member.get("accountType") + # On Jira DC < 9.0, accountType is absent; include those users. + # On Cloud / DC 9.0+, filter to real user accounts only. 
+            if account_type is not None and account_type != _ATLASSIAN_ACCOUNT_TYPE:
+                continue
-            member_emails = _get_jira_group_members_email(
-                jira_client=jira_client,
-                group_name=group_name,
-            )
-
-            if member_emails:
-                group_member_emails[group_name] = set(member_emails)
-                logger.debug(
-                    f"Found {len(member_emails)} members for group {group_name}"
-                )
+            email = member.get("emailAddress")
+            if email:
+                emails.add(email)
+            else:
-                logger.debug(f"No members found for group {group_name}")
+                logger.warning(
+                    f"Atlassian user {member.get('accountId', 'unknown')} "
+                    f"in group {group_name} has no visible email address"
+                )
-    except Exception as e:
-        logger.error(f"Error building group member email map: {e}")
+        if page.get("isLast", True) or not members:
+            break
+        start_at += len(members)
-    return group_member_emails
+    return emails
 
 
 def jira_group_sync(
     tenant_id: str,  # noqa: ARG001
     cc_pair: ConnectorCredentialPair,
 ) -> Generator[ExternalUserGroup, None, None]:
-    """
-    Sync Jira groups and their members.
+    """Sync Jira groups and their members, yielding one group at a time.
 
-    This function fetches all groups from Jira and yields ExternalUserGroup
-    objects containing the group ID and member emails.
+    Streams group-by-group rather than accumulating all groups in memory.
""" jira_base_url = cc_pair.connector.connector_specific_config.get("jira_base_url", "") scoped_token = cc_pair.connector.connector_specific_config.get( @@ -130,12 +127,26 @@ def jira_group_sync( scoped_token=scoped_token, ) - group_member_email_map = _build_group_member_email_map(jira_client=jira_client) - if not group_member_email_map: - raise ValueError(f"No groups with members found for cc_pair_id={cc_pair.id}") + group_names = jira_client.groups() + if not group_names: + raise ValueError(f"No groups found for cc_pair_id={cc_pair.id}") + + logger.info(f"Found {len(group_names)} groups in Jira") + + for group_name in group_names: + if not group_name: + continue + + member_emails = _get_group_member_emails( + jira_client=jira_client, + group_name=group_name, + ) + if not member_emails: + logger.debug(f"No members found for group {group_name}") + continue - for group_id, group_member_emails in group_member_email_map.items(): + logger.debug(f"Found {len(member_emails)} members for group {group_name}") yield ExternalUserGroup( - id=group_id, - user_emails=list(group_member_emails), + id=group_name, + user_emails=list(member_emails), ) diff --git a/backend/ee/onyx/main.py b/backend/ee/onyx/main.py index 612db0efeb7..1755dd0ba3a 100644 --- a/backend/ee/onyx/main.py +++ b/backend/ee/onyx/main.py @@ -4,7 +4,6 @@ from fastapi import FastAPI from httpx_oauth.clients.google import GoogleOAuth2 -from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED from ee.onyx.server.analytics.api import router as analytics_router from ee.onyx.server.auth_check import check_ee_router_auth from ee.onyx.server.billing.api import router as billing_router @@ -31,6 +30,7 @@ from ee.onyx.server.query_and_chat.search_backend import router as search_router from ee.onyx.server.query_history.api import router as query_history_router from ee.onyx.server.reporting.usage_export_api import router as usage_export_router +from ee.onyx.server.scim.api import register_scim_exception_handlers 
from ee.onyx.server.scim.api import scim_router from ee.onyx.server.seeding import seed_db from ee.onyx.server.tenants.api import router as tenants_router @@ -152,12 +152,9 @@ def get_application() -> FastAPI: # License management include_router_with_global_prefix_prepended(application, license_router) - # Unified billing API - available when license system is enabled - # Works for both self-hosted and cloud deployments - # TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the - # primary billing API and /tenants/* billing endpoints can be removed - if LICENSE_ENFORCEMENT_ENABLED: - include_router_with_global_prefix_prepended(application, billing_router) + # Unified billing API - always registered in EE. + # Each endpoint is protected by the `current_admin_user` dependency (admin auth). + include_router_with_global_prefix_prepended(application, billing_router) if MULTI_TENANT: # Tenant management @@ -167,6 +164,7 @@ def get_application() -> FastAPI: # they use their own SCIM bearer token auth). # Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly. 
application.include_router(scim_router) + register_scim_exception_handlers(application) # Ensure all routes have auth enabled or are explicitly marked as public check_ee_router_auth(application) diff --git a/backend/ee/onyx/server/billing/api.py b/backend/ee/onyx/server/billing/api.py index a5495bf57d9..cbeb0c4c6b6 100644 --- a/backend/ee/onyx/server/billing/api.py +++ b/backend/ee/onyx/server/billing/api.py @@ -26,7 +26,6 @@ import httpx from fastapi import APIRouter from fastapi import Depends -from fastapi import HTTPException from pydantic import BaseModel from sqlalchemy.orm import Session @@ -42,7 +41,6 @@ from ee.onyx.server.billing.models import SeatUpdateResponse from ee.onyx.server.billing.models import StripePublishableKeyResponse from ee.onyx.server.billing.models import SubscriptionStatusResponse -from ee.onyx.server.billing.service import BillingServiceError from ee.onyx.server.billing.service import ( create_checkout_session as create_checkout_service, ) @@ -58,6 +56,8 @@ from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL from onyx.configs.app_configs import WEB_DOMAIN from onyx.db.engine.sql_engine import get_session +from onyx.error_handling.error_codes import OnyxErrorCode +from onyx.error_handling.exceptions import OnyxError from onyx.redis.redis_pool import get_shared_redis_client from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT @@ -169,26 +169,23 @@ async def create_checkout_session( if seats is not None: used_seats = get_used_seats(tenant_id) if seats < used_seats: - raise HTTPException( - status_code=400, - detail=f"Cannot subscribe with fewer seats than current usage. " + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + f"Cannot subscribe with fewer seats than current usage. 
" f"You have {used_seats} active users/integrations but requested {seats} seats.", ) # Build redirect URL for after checkout completion redirect_url = f"{WEB_DOMAIN}/admin/billing?checkout=success" - try: - return await create_checkout_service( - billing_period=billing_period, - seats=seats, - email=email, - license_data=license_data, - redirect_url=redirect_url, - tenant_id=tenant_id, - ) - except BillingServiceError as e: - raise HTTPException(status_code=e.status_code, detail=e.message) + return await create_checkout_service( + billing_period=billing_period, + seats=seats, + email=email, + license_data=license_data, + redirect_url=redirect_url, + tenant_id=tenant_id, + ) @router.post("/create-customer-portal-session") @@ -206,18 +203,15 @@ async def create_customer_portal_session( # Self-hosted requires license if not MULTI_TENANT and not license_data: - raise HTTPException(status_code=400, detail="No license found") + raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found") return_url = request.return_url if request else f"{WEB_DOMAIN}/admin/billing" - try: - return await create_portal_service( - license_data=license_data, - return_url=return_url, - tenant_id=tenant_id, - ) - except BillingServiceError as e: - raise HTTPException(status_code=e.status_code, detail=e.message) + return await create_portal_service( + license_data=license_data, + return_url=return_url, + tenant_id=tenant_id, + ) @router.get("/billing-information") @@ -240,9 +234,9 @@ async def get_billing_information( # Check circuit breaker (self-hosted only) if _is_billing_circuit_open(): - raise HTTPException( - status_code=503, - detail="Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.", + raise OnyxError( + OnyxErrorCode.SERVICE_UNAVAILABLE, + "Stripe connection temporarily disabled. 
Click 'Connect to Stripe' to retry.", ) try: @@ -250,11 +244,15 @@ async def get_billing_information( license_data=license_data, tenant_id=tenant_id, ) - except BillingServiceError as e: + except OnyxError as e: # Open circuit breaker on connection failures (self-hosted only) - if e.status_code in (502, 503, 504): + if e.status_code in ( + OnyxErrorCode.BAD_GATEWAY.status_code, + OnyxErrorCode.SERVICE_UNAVAILABLE.status_code, + OnyxErrorCode.GATEWAY_TIMEOUT.status_code, + ): _open_billing_circuit() - raise HTTPException(status_code=e.status_code, detail=e.message) + raise @router.post("/seats/update") @@ -274,31 +272,25 @@ async def update_seats( # Self-hosted requires license if not MULTI_TENANT and not license_data: - raise HTTPException(status_code=400, detail="No license found") + raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found") # Validate that new seat count is not less than current used seats used_seats = get_used_seats(tenant_id) if request.new_seat_count < used_seats: - raise HTTPException( - status_code=400, - detail=f"Cannot reduce seats below current usage. " + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + f"Cannot reduce seats below current usage. " f"You have {used_seats} active users/integrations but requested {request.new_seat_count} seats.", ) - try: - result = await update_seat_service( - new_seat_count=request.new_seat_count, - license_data=license_data, - tenant_id=tenant_id, - ) - - # Note: Don't store license here - the control plane may still be processing - # the subscription update. The frontend should call /license/claim after a - # short delay to get the freshly generated license. - - return result - except BillingServiceError as e: - raise HTTPException(status_code=e.status_code, detail=e.message) + # Note: Don't store license here - the control plane may still be processing + # the subscription update. The frontend should call /license/claim after a + # short delay to get the freshly generated license. 
+ return await update_seat_service( + new_seat_count=request.new_seat_count, + license_data=license_data, + tenant_id=tenant_id, + ) @router.get("/stripe-publishable-key") @@ -329,18 +321,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse: if STRIPE_PUBLISHABLE_KEY_OVERRIDE: key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip() if not key.startswith("pk_"): - raise HTTPException( - status_code=500, - detail="Invalid Stripe publishable key format", + raise OnyxError( + OnyxErrorCode.INTERNAL_ERROR, + "Invalid Stripe publishable key format", ) _stripe_publishable_key_cache = key return StripePublishableKeyResponse(publishable_key=key) # Fall back to S3 bucket if not STRIPE_PUBLISHABLE_KEY_URL: - raise HTTPException( - status_code=500, - detail="Stripe publishable key is not configured", + raise OnyxError( + OnyxErrorCode.INTERNAL_ERROR, + "Stripe publishable key is not configured", ) try: @@ -351,17 +343,17 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse: # Validate key format if not key.startswith("pk_"): - raise HTTPException( - status_code=500, - detail="Invalid Stripe publishable key format", + raise OnyxError( + OnyxErrorCode.INTERNAL_ERROR, + "Invalid Stripe publishable key format", ) _stripe_publishable_key_cache = key return StripePublishableKeyResponse(publishable_key=key) except httpx.HTTPError: - raise HTTPException( - status_code=500, - detail="Failed to fetch Stripe publishable key", + raise OnyxError( + OnyxErrorCode.INTERNAL_ERROR, + "Failed to fetch Stripe publishable key", ) diff --git a/backend/ee/onyx/server/billing/service.py b/backend/ee/onyx/server/billing/service.py index 041a975d0af..941e96cc7e6 100644 --- a/backend/ee/onyx/server/billing/service.py +++ b/backend/ee/onyx/server/billing/service.py @@ -22,6 +22,8 @@ from ee.onyx.server.billing.models import SubscriptionStatusResponse from ee.onyx.server.tenants.access import generate_data_plane_token from onyx.configs.app_configs import 
CONTROL_PLANE_API_BASE_URL +from onyx.error_handling.error_codes import OnyxErrorCode +from onyx.error_handling.exceptions import OnyxError from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT @@ -31,15 +33,6 @@ _REQUEST_TIMEOUT = 30.0 -class BillingServiceError(Exception): - """Exception raised for billing service errors.""" - - def __init__(self, message: str, status_code: int = 500): - self.message = message - self.status_code = status_code - super().__init__(self.message) - - def _get_proxy_headers(license_data: str | None) -> dict[str, str]: """Build headers for proxy requests (self-hosted). @@ -101,7 +94,7 @@ async def _make_billing_request( Response JSON as dict Raises: - BillingServiceError: If request fails + OnyxError: If request fails """ base_url = _get_base_url() @@ -128,11 +121,17 @@ async def _make_billing_request( except Exception: pass logger.error(f"{error_message}: {e.response.status_code} - {detail}") - raise BillingServiceError(detail, e.response.status_code) + raise OnyxError( + OnyxErrorCode.BAD_GATEWAY, + detail, + status_code_override=e.response.status_code, + ) except httpx.RequestError: logger.exception("Failed to connect to billing service") - raise BillingServiceError("Failed to connect to billing service", 502) + raise OnyxError( + OnyxErrorCode.BAD_GATEWAY, "Failed to connect to billing service" + ) async def create_checkout_session( diff --git a/backend/ee/onyx/server/enterprise_settings/api.py b/backend/ee/onyx/server/enterprise_settings/api.py index b938e07bd85..73b367bea71 100644 --- a/backend/ee/onyx/server/enterprise_settings/api.py +++ b/backend/ee/onyx/server/enterprise_settings/api.py @@ -223,6 +223,15 @@ def get_active_scim_token( token = dal.get_active_token() if not token: raise HTTPException(status_code=404, detail="No active SCIM token") + + # Derive the IdP domain from the first synced user as a heuristic. 
+ idp_domain: str | None = None + mappings, _total = dal.list_user_mappings(start_index=1, count=1) + if mappings: + user = dal.get_user(mappings[0].user_id) + if user and "@" in user.email: + idp_domain = user.email.rsplit("@", 1)[1] + return ScimTokenResponse( id=token.id, name=token.name, @@ -230,6 +239,7 @@ def get_active_scim_token( is_active=token.is_active, created_at=token.created_at, last_used_at=token.last_used_at, + idp_domain=idp_domain, ) diff --git a/backend/ee/onyx/server/license/api.py b/backend/ee/onyx/server/license/api.py index 416b13cb5a3..5c16f4e3192 100644 --- a/backend/ee/onyx/server/license/api.py +++ b/backend/ee/onyx/server/license/api.py @@ -14,7 +14,6 @@ from fastapi import APIRouter from fastapi import Depends from fastapi import File -from fastapi import HTTPException from fastapi import UploadFile from sqlalchemy.orm import Session @@ -35,6 +34,8 @@ from ee.onyx.utils.license import verify_license_signature from onyx.auth.users import User from onyx.db.engine.sql_engine import get_session +from onyx.error_handling.error_codes import OnyxErrorCode +from onyx.error_handling.exceptions import OnyxError from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT @@ -127,9 +128,9 @@ async def claim_license( 2. Without session_id: Re-claim using existing license for auth """ if MULTI_TENANT: - raise HTTPException( - status_code=400, - detail="License claiming is only available for self-hosted deployments", + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "License claiming is only available for self-hosted deployments", ) try: @@ -146,15 +147,16 @@ async def claim_license( # Re-claim using existing license for auth metadata = get_license_metadata(db_session) if not metadata or not metadata.tenant_id: - raise HTTPException( - status_code=400, - detail="No license found. Provide session_id after checkout.", + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "No license found. 
Provide session_id after checkout.", ) license_row = get_license(db_session) if not license_row or not license_row.license_data: - raise HTTPException( - status_code=400, detail="No license found in database" + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "No license found in database", ) url = f"{CLOUD_DATA_PLANE_URL}/proxy/license/{metadata.tenant_id}" @@ -173,7 +175,7 @@ async def claim_license( license_data = data.get("license") if not license_data: - raise HTTPException(status_code=404, detail="No license in response") + raise OnyxError(OnyxErrorCode.NOT_FOUND, "No license in response") # Verify signature before persisting payload = verify_license_signature(license_data) @@ -199,12 +201,14 @@ async def claim_license( detail = error_data.get("detail", detail) except Exception: pass - raise HTTPException(status_code=status_code, detail=detail) + raise OnyxError( + OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=status_code + ) except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) + raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e)) except requests.RequestException: - raise HTTPException( - status_code=502, detail="Failed to connect to license server" + raise OnyxError( + OnyxErrorCode.BAD_GATEWAY, "Failed to connect to license server" ) @@ -221,9 +225,9 @@ async def upload_license( The license file must be cryptographically signed by Onyx. 
""" if MULTI_TENANT: - raise HTTPException( - status_code=400, - detail="License upload is only available for self-hosted deployments", + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "License upload is only available for self-hosted deployments", ) try: @@ -234,14 +238,14 @@ async def upload_license( # Remove any stray whitespace/newlines from user input license_data = license_data.strip() except UnicodeDecodeError: - raise HTTPException(status_code=400, detail="Invalid license file format") + raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Invalid license file format") # Verify cryptographic signature - this is the only validation needed # The license's tenant_id identifies the customer in control plane, not locally try: payload = verify_license_signature(license_data) except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) + raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e)) # Persist to DB and update cache upsert_license(db_session, license_data) @@ -297,9 +301,9 @@ async def delete_license( Admin only - removes license from database and invalidates cache. 
""" if MULTI_TENANT: - raise HTTPException( - status_code=400, - detail="License deletion is only available for self-hosted deployments", + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "License deletion is only available for self-hosted deployments", ) try: diff --git a/backend/ee/onyx/server/middleware/license_enforcement.py b/backend/ee/onyx/server/middleware/license_enforcement.py index 133f598233b..03a73c4367e 100644 --- a/backend/ee/onyx/server/middleware/license_enforcement.py +++ b/backend/ee/onyx/server/middleware/license_enforcement.py @@ -46,7 +46,6 @@ from fastapi import Request from fastapi import Response from fastapi.responses import JSONResponse -from redis.exceptions import RedisError from sqlalchemy.exc import SQLAlchemyError from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED @@ -56,6 +55,7 @@ ) from ee.onyx.db.license import get_cached_license_metadata from ee.onyx.db.license import refresh_license_cache +from onyx.cache.interface import CACHE_TRANSIENT_ERRORS from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.server.settings.models import ApplicationStatus from shared_configs.contextvars import get_current_tenant_id @@ -164,9 +164,9 @@ async def enforce_license( "[license_enforcement] No license, allowing community features" ) is_gated = False - except RedisError as e: + except CACHE_TRANSIENT_ERRORS as e: logger.warning(f"Failed to check license metadata: {e}") - # Fail open - don't block users due to Redis connectivity issues + # Fail open - don't block users due to cache connectivity issues is_gated = False if is_gated: diff --git a/backend/ee/onyx/server/scim/api.py b/backend/ee/onyx/server/scim/api.py index 3f9fc55f050..ef8cce40fbf 100644 --- a/backend/ee/onyx/server/scim/api.py +++ b/backend/ee/onyx/server/scim/api.py @@ -15,7 +15,9 @@ from fastapi import APIRouter from fastapi import Depends +from fastapi import FastAPI from fastapi import Query +from fastapi import Request from fastapi 
import Response from fastapi.responses import JSONResponse from fastapi_users.password import PasswordHelper @@ -24,6 +26,7 @@ from sqlalchemy.orm import Session from ee.onyx.db.scim import ScimDAL +from ee.onyx.server.scim.auth import ScimAuthError from ee.onyx.server.scim.auth import verify_scim_token from ee.onyx.server.scim.filtering import parse_scim_filter from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA @@ -77,6 +80,22 @@ class ScimJSONResponse(JSONResponse): _pw_helper = PasswordHelper() +def register_scim_exception_handlers(app: FastAPI) -> None: + """Register SCIM-specific exception handlers on the FastAPI app. + + Call this after ``app.include_router(scim_router)`` so that auth + failures from ``verify_scim_token`` return RFC 7644 §3.12 error + envelopes (with ``schemas`` and ``status`` fields) instead of + FastAPI's default ``{"detail": "..."}`` format. + """ + + @app.exception_handler(ScimAuthError) + async def _handle_scim_auth_error( + _request: Request, exc: ScimAuthError + ) -> ScimJSONResponse: + return _scim_error_response(exc.status_code, exc.detail) + + def _get_provider( _token: ScimToken = Depends(verify_scim_token), ) -> ScimProvider: @@ -404,21 +423,63 @@ def create_user( email = user_resource.userName.strip() - # externalId is how the IdP correlates this user on subsequent requests. - # Without it, the IdP can't find the user and will try to re-create, - # hitting a 409 conflict — so we require it up front. - if not user_resource.externalId: - return _scim_error_response(400, "externalId is required") + # Check for existing user — if they exist but aren't SCIM-managed yet, + # link them to the IdP rather than rejecting with 409. 
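`register_scim_exception_handlers` above converts `ScimAuthError` into an RFC 7644 §3.12 envelope instead of FastAPI's default `{"detail": ...}` body. A dependency-free sketch of the exception class and the envelope it produces (the `app.exception_handler` wiring is shown in the hunk; `scim_error_body` is an illustrative helper, not the module's `_scim_error_response`):

```python
SCIM_ERROR_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:Error"


class ScimAuthError(Exception):
    """Carries status + detail so a handler can build a SCIM error body."""

    def __init__(self, status_code: int, detail: str) -> None:
        self.status_code = status_code
        self.detail = detail
        super().__init__(detail)


def scim_error_body(status_code: int, detail: str) -> dict[str, object]:
    # RFC 7644 requires the error schema URN and a *string* status field.
    return {
        "schemas": [SCIM_ERROR_SCHEMA],
        "status": str(status_code),
        "detail": detail,
    }
```

IdPs like Okta validate this envelope shape, which is why plain `HTTPException` responses are no longer sufficient here.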
+ external_id: str | None = user_resource.externalId + scim_username: str = user_resource.userName.strip() + fields: ScimMappingFields = _fields_from_resource(user_resource) + + existing_user = dal.get_user_by_email(email) + if existing_user: + existing_mapping = dal.get_user_mapping_by_user_id(existing_user.id) + if existing_mapping: + return _scim_error_response(409, f"User with email {email} already exists") + + # Adopt pre-existing user into SCIM management. + # Reactivating a deactivated user consumes a seat, so enforce the + # seat limit the same way replace_user does. + if user_resource.active and not existing_user.is_active: + seat_error = _check_seat_availability(dal) + if seat_error: + return _scim_error_response(403, seat_error) - # Enforce seat limit + personal_name = _scim_name_to_str(user_resource.name) + dal.update_user( + existing_user, + is_active=user_resource.active, + **({"personal_name": personal_name} if personal_name else {}), + ) + + try: + dal.create_user_mapping( + external_id=external_id, + user_id=existing_user.id, + scim_username=scim_username, + fields=fields, + ) + dal.commit() + except IntegrityError: + dal.rollback() + return _scim_error_response( + 409, f"User with email {email} already has a SCIM mapping" + ) + + return _scim_resource_response( + provider.build_user_resource( + existing_user, + external_id, + scim_username=scim_username, + fields=fields, + ), + status_code=201, + ) + + # Only enforce seat limit for net-new users — adopting a pre-existing + # user doesn't consume a new seat. 
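The rewritten `create_user` now branches three ways: reject with 409 only when the email already has a SCIM mapping, adopt a pre-existing non-SCIM user, or create a net-new user (the only case that enforces the seat limit up front). A pure-logic sketch of that decision table, with hypothetical names:

```python
def plan_scim_create(email_exists: bool, already_scim_managed: bool) -> str:
    """Decide how a SCIM create-user request should be handled.

    Mirrors the diff's branching: 409 only when the user is already
    SCIM-managed; pre-existing users are adopted instead of rejected.
    """
    if email_exists and already_scim_managed:
        return "conflict-409"
    if email_exists:
        # Adoption links the existing user to the IdP; a seat check is
        # only needed when reactivating a deactivated user.
        return "adopt-existing"
    # Net-new users consume a seat, so the limit is enforced up front.
    return "create-new"
```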
seat_error = _check_seat_availability(dal) if seat_error: return _scim_error_response(403, seat_error) - # Check for existing user - if dal.get_user_by_email(email): - return _scim_error_response(409, f"User with email {email} already exists") - # Create user with a random password (SCIM users authenticate via IdP) personal_name = _scim_name_to_str(user_resource.name) user = User( @@ -436,18 +497,21 @@ def create_user( dal.rollback() return _scim_error_response(409, f"User with email {email} already exists") - # Create SCIM mapping (externalId is validated above, always present) - external_id = user_resource.externalId - scim_username = user_resource.userName.strip() - fields = _fields_from_resource(user_resource) - dal.create_user_mapping( - external_id=external_id, - user_id=user.id, - scim_username=scim_username, - fields=fields, - ) - - dal.commit() + # Always create a SCIM mapping so that the user is marked as + # SCIM-managed. externalId may be None (RFC 7643 says it's optional). + try: + dal.create_user_mapping( + external_id=external_id, + user_id=user.id, + scim_username=scim_username, + fields=fields, + ) + dal.commit() + except IntegrityError: + dal.rollback() + return _scim_error_response( + 409, f"User with email {email} already has a SCIM mapping" + ) return _scim_resource_response( provider.build_user_resource( diff --git a/backend/ee/onyx/server/scim/auth.py b/backend/ee/onyx/server/scim/auth.py index d05a1bd140b..e8965815053 100644 --- a/backend/ee/onyx/server/scim/auth.py +++ b/backend/ee/onyx/server/scim/auth.py @@ -19,7 +19,6 @@ import secrets from fastapi import Depends -from fastapi import HTTPException from fastapi import Request from sqlalchemy.orm import Session @@ -28,6 +27,21 @@ from onyx.db.engine.sql_engine import get_session from onyx.db.models import ScimToken + +class ScimAuthError(Exception): + """Raised when SCIM bearer token authentication fails. 
+ + Unlike HTTPException, this carries the status and detail so the SCIM + exception handler can wrap them in an RFC 7644 §3.12 error envelope + with ``schemas`` and ``status`` fields. + """ + + def __init__(self, status_code: int, detail: str) -> None: + self.status_code = status_code + self.detail = detail + super().__init__(detail) + + SCIM_TOKEN_PREFIX = "onyx_scim_" SCIM_TOKEN_LENGTH = 48 @@ -82,23 +96,14 @@ def verify_scim_token( """ hashed = _get_hashed_scim_token_from_request(request) if not hashed: - raise HTTPException( - status_code=401, - detail="Missing or invalid SCIM bearer token", - ) + raise ScimAuthError(401, "Missing or invalid SCIM bearer token") token = dal.get_token_by_hash(hashed) if not token: - raise HTTPException( - status_code=401, - detail="Invalid SCIM bearer token", - ) + raise ScimAuthError(401, "Invalid SCIM bearer token") if not token.is_active: - raise HTTPException( - status_code=401, - detail="SCIM token has been revoked", - ) + raise ScimAuthError(401, "SCIM token has been revoked") return token diff --git a/backend/ee/onyx/server/scim/models.py b/backend/ee/onyx/server/scim/models.py index 64221e417f0..441d60fbe05 100644 --- a/backend/ee/onyx/server/scim/models.py +++ b/backend/ee/onyx/server/scim/models.py @@ -365,6 +365,7 @@ class ScimTokenResponse(BaseModel): is_active: bool created_at: datetime last_used_at: datetime | None = None + idp_domain: str | None = None class ScimTokenCreatedResponse(ScimTokenResponse): diff --git a/backend/ee/onyx/server/scim/providers/base.py b/backend/ee/onyx/server/scim/providers/base.py index 5dc5fac3049..af660b32e53 100644 --- a/backend/ee/onyx/server/scim/providers/base.py +++ b/backend/ee/onyx/server/scim/providers/base.py @@ -153,26 +153,31 @@ def build_scim_name( self, user: User, fields: ScimMappingFields, - ) -> ScimName | None: + ) -> ScimName: """Build SCIM name components for the response. 
Round-trips stored ``given_name``/``family_name`` when available (so the IdP gets back what it sent). Falls back to splitting ``personal_name`` for users provisioned before we stored components. + Always returns a ScimName — Okta's spec tests expect ``name`` + (with ``givenName``/``familyName``) on every user resource. Providers may override for custom behavior. """ if fields.given_name is not None or fields.family_name is not None: return ScimName( - givenName=fields.given_name, - familyName=fields.family_name, - formatted=user.personal_name, + givenName=fields.given_name or "", + familyName=fields.family_name or "", + formatted=user.personal_name or "", ) if not user.personal_name: - return None + # Derive a reasonable name from the email so that SCIM spec tests + # see non-empty givenName / familyName for every user resource. + local = user.email.split("@")[0] if user.email else "" + return ScimName(givenName=local, familyName="", formatted=local) parts = user.personal_name.split(" ", 1) return ScimName( givenName=parts[0], - familyName=parts[1] if len(parts) > 1 else None, + familyName=parts[1] if len(parts) > 1 else "", formatted=user.personal_name, ) diff --git a/backend/ee/onyx/server/seeding.py b/backend/ee/onyx/server/seeding.py index 1539db9c63f..04652f74a68 100644 --- a/backend/ee/onyx/server/seeding.py +++ b/backend/ee/onyx/server/seeding.py @@ -26,6 +26,7 @@ from onyx.db.persona import upsert_persona from onyx.server.features.persona.models import PersonaUpsertRequest from onyx.server.manage.llm.models import LLMProviderUpsertRequest +from onyx.server.manage.llm.models import LLMProviderView from onyx.server.settings.models import Settings from onyx.server.settings.store import store_settings as store_base_settings from onyx.utils.logger import setup_logger @@ -125,10 +126,16 @@ def _seed_llms( existing = fetch_existing_llm_provider(name=request.name, db_session=db_session) if existing: request.id = existing.id - seeded_providers = [ - 
upsert_llm_provider(llm_upsert_request, db_session) - for llm_upsert_request in llm_upsert_requests - ] + seeded_providers: list[LLMProviderView] = [] + for llm_upsert_request in llm_upsert_requests: + try: + seeded_providers.append(upsert_llm_provider(llm_upsert_request, db_session)) + except ValueError as e: + logger.warning( + "Failed to upsert LLM provider '%s' during seeding: %s", + llm_upsert_request.name, + e, + ) default_provider = next( (p for p in seeded_providers if p.model_configurations), None diff --git a/backend/ee/onyx/server/settings/api.py b/backend/ee/onyx/server/settings/api.py index 35ce9cb1802..a79eb546291 100644 --- a/backend/ee/onyx/server/settings/api.py +++ b/backend/ee/onyx/server/settings/api.py @@ -6,6 +6,7 @@ from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED from ee.onyx.db.license import get_cached_license_metadata from ee.onyx.db.license import refresh_license_cache +from onyx.cache.interface import CACHE_TRANSIENT_ERRORS from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.server.settings.models import ApplicationStatus @@ -125,7 +126,7 @@ def apply_license_status_to_settings(settings: Settings) -> Settings: # syncing) means indexed data may need protection. 
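The `_seed_llms` change above swaps a list comprehension for a loop so one failing provider no longer aborts seeding. The pattern in isolation (function and argument names are illustrative):

```python
import logging
from collections.abc import Callable

logger = logging.getLogger(__name__)


def seed_all(names: list[str], upsert: Callable[[str], str]) -> list[str]:
    """Upsert each entry independently so one bad entry can't abort seeding."""
    seeded: list[str] = []
    for name in names:
        try:
            seeded.append(upsert(name))
        except ValueError as e:
            # Log and continue, mirroring the per-provider handling above.
            logger.warning("Failed to upsert %r during seeding: %s", name, e)
    return seeded
```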
settings.application_status = _BLOCKING_STATUS settings.ee_features_enabled = False - except RedisError as e: + except CACHE_TRANSIENT_ERRORS as e: logger.warning(f"Failed to check license metadata for settings: {e}") # Fail closed - disable EE features if we can't verify license settings.ee_features_enabled = False diff --git a/backend/ee/onyx/server/tenants/billing_api.py b/backend/ee/onyx/server/tenants/billing_api.py index c357215681c..63018462261 100644 --- a/backend/ee/onyx/server/tenants/billing_api.py +++ b/backend/ee/onyx/server/tenants/billing_api.py @@ -21,7 +21,6 @@ import httpx from fastapi import APIRouter from fastapi import Depends -from fastapi import HTTPException from ee.onyx.auth.users import current_admin_user from ee.onyx.server.tenants.access import control_plane_dep @@ -43,6 +42,8 @@ from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL from onyx.configs.app_configs import WEB_DOMAIN +from onyx.error_handling.error_codes import OnyxErrorCode +from onyx.error_handling.exceptions import OnyxError from onyx.utils.logger import setup_logger from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.contextvars import get_current_tenant_id @@ -116,9 +117,14 @@ async def create_customer_portal_session( try: portal_url = fetch_customer_portal_session(tenant_id, return_url) return {"stripe_customer_portal_url": portal_url} - except Exception as e: + except OnyxError: + raise + except Exception: logger.exception("Failed to create customer portal session") - raise HTTPException(status_code=500, detail=str(e)) + raise OnyxError( + OnyxErrorCode.INTERNAL_ERROR, + "Failed to create customer portal session", + ) @router.post("/create-checkout-session") @@ -134,9 +140,14 @@ async def create_checkout_session( try: checkout_url = fetch_stripe_checkout_session(tenant_id, billing_period, seats) return {"stripe_checkout_url": checkout_url} - except 
Exception as e: + except OnyxError: + raise + except Exception: logger.exception("Failed to create checkout session") - raise HTTPException(status_code=500, detail=str(e)) + raise OnyxError( + OnyxErrorCode.INTERNAL_ERROR, + "Failed to create checkout session", + ) @router.post("/create-subscription-session") @@ -147,15 +158,20 @@ async def create_subscription_session( try: tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() if not tenant_id: - raise HTTPException(status_code=400, detail="Tenant ID not found") + raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Tenant ID not found") billing_period = request.billing_period if request else "monthly" session_id = fetch_stripe_checkout_session(tenant_id, billing_period) return SubscriptionSessionResponse(sessionId=session_id) - except Exception as e: + except OnyxError: + raise + except Exception: logger.exception("Failed to create subscription session") - raise HTTPException(status_code=500, detail=str(e)) + raise OnyxError( + OnyxErrorCode.INTERNAL_ERROR, + "Failed to create subscription session", + ) @router.get("/stripe-publishable-key") @@ -186,18 +202,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse: if STRIPE_PUBLISHABLE_KEY_OVERRIDE: key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip() if not key.startswith("pk_"): - raise HTTPException( - status_code=500, - detail="Invalid Stripe publishable key format", + raise OnyxError( + OnyxErrorCode.INTERNAL_ERROR, + "Invalid Stripe publishable key format", ) _stripe_publishable_key_cache = key return StripePublishableKeyResponse(publishable_key=key) # Fall back to S3 bucket if not STRIPE_PUBLISHABLE_KEY_URL: - raise HTTPException( - status_code=500, - detail="Stripe publishable key is not configured", + raise OnyxError( + OnyxErrorCode.INTERNAL_ERROR, + "Stripe publishable key is not configured", ) try: @@ -208,15 +224,15 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse: # Validate key format if not key.startswith("pk_"): - raise 
HTTPException( - status_code=500, - detail="Invalid Stripe publishable key format", + raise OnyxError( + OnyxErrorCode.INTERNAL_ERROR, + "Invalid Stripe publishable key format", ) _stripe_publishable_key_cache = key return StripePublishableKeyResponse(publishable_key=key) except httpx.HTTPError: - raise HTTPException( - status_code=500, - detail="Failed to fetch Stripe publishable key", + raise OnyxError( + OnyxErrorCode.INTERNAL_ERROR, + "Failed to fetch Stripe publishable key", ) diff --git a/backend/ee/onyx/server/user_group/api.py b/backend/ee/onyx/server/user_group/api.py index b56a65c303c..9c2709955c5 100644 --- a/backend/ee/onyx/server/user_group/api.py +++ b/backend/ee/onyx/server/user_group/api.py @@ -5,6 +5,8 @@ from sqlalchemy.orm import Session from ee.onyx.db.user_group import add_users_to_user_group +from ee.onyx.db.user_group import delete_user_group as db_delete_user_group +from ee.onyx.db.user_group import fetch_user_group from ee.onyx.db.user_group import fetch_user_groups from ee.onyx.db.user_group import fetch_user_groups_for_user from ee.onyx.db.user_group import insert_user_group @@ -20,6 +22,7 @@ from onyx.auth.users import current_admin_user from onyx.auth.users import current_curator_or_admin_user from onyx.auth.users import current_user +from onyx.configs.app_configs import DISABLE_VECTOR_DB from onyx.configs.constants import PUBLIC_API_TAGS from onyx.db.engine.sql_engine import get_session from onyx.db.models import User @@ -153,3 +156,8 @@ def delete_user_group( prepare_user_group_for_deletion(db_session, user_group_id) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) + + if DISABLE_VECTOR_DB: + user_group = fetch_user_group(db_session, user_group_id) + if user_group: + db_delete_user_group(db_session, user_group) diff --git a/backend/ee/onyx/utils/encryption.py b/backend/ee/onyx/utils/encryption.py index 85be05df31b..79b75f7441b 100644 --- a/backend/ee/onyx/utils/encryption.py +++ 
b/backend/ee/onyx/utils/encryption.py @@ -14,67 +14,91 @@ logger = setup_logger() -@lru_cache(maxsize=1) +@lru_cache(maxsize=2) def _get_trimmed_key(key: str) -> bytes: encoded_key = key.encode() key_length = len(encoded_key) if key_length < 16: raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short") - elif key_length > 32: - key = key[:32] - elif key_length not in (16, 24, 32): - valid_lengths = [16, 24, 32] - key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))] - return encoded_key + # Trim to the largest valid AES key size that fits + valid_lengths = [32, 24, 16] + for size in valid_lengths: + if key_length >= size: + return encoded_key[:size] + raise AssertionError("unreachable") -def _encrypt_string(input_str: str) -> bytes: - if not ENCRYPTION_KEY_SECRET: + +def _encrypt_string(input_str: str, key: str | None = None) -> bytes: + effective_key = key if key is not None else ENCRYPTION_KEY_SECRET + if not effective_key: return input_str.encode() - key = _get_trimmed_key(ENCRYPTION_KEY_SECRET) + trimmed = _get_trimmed_key(effective_key) iv = urandom(16) padder = padding.PKCS7(algorithms.AES.block_size).padder() padded_data = padder.update(input_str.encode()) + padder.finalize() - cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) + cipher = Cipher(algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend()) encryptor = cipher.encryptor() encrypted_data = encryptor.update(padded_data) + encryptor.finalize() return iv + encrypted_data -def _decrypt_bytes(input_bytes: bytes) -> str: - if not ENCRYPTION_KEY_SECRET: +def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str: + effective_key = key if key is not None else ENCRYPTION_KEY_SECRET + if not effective_key: return input_bytes.decode() - key = _get_trimmed_key(ENCRYPTION_KEY_SECRET) - iv = input_bytes[:16] - encrypted_data = input_bytes[16:] - - cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) - decryptor = 
cipher.decryptor() - decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize() - - unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() - decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize() - - return decrypted_data.decode() - - -def encrypt_string_to_bytes(input_str: str) -> bytes: + trimmed = _get_trimmed_key(effective_key) + try: + iv = input_bytes[:16] + encrypted_data = input_bytes[16:] + + cipher = Cipher( + algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend() + ) + decryptor = cipher.decryptor() + decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize() + + unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() + decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize() + + return decrypted_data.decode() + except (ValueError, UnicodeDecodeError): + if key is not None: + # Explicit key was provided — don't fall back silently + raise + # Read path: attempt raw UTF-8 decode as a fallback for legacy data. + # Does NOT handle data encrypted with a different key — that + # ciphertext is not valid UTF-8 and will raise below. + logger.warning( + "AES decryption failed — falling back to raw decode. " + "Run the re-encrypt secrets script to rotate to the current key." + ) + try: + return input_bytes.decode() + except UnicodeDecodeError: + raise ValueError( + "Data is not valid UTF-8 — likely encrypted with a different key. " + "Run the re-encrypt secrets script to rotate to the current key." 
+ ) from None + + +def encrypt_string_to_bytes(input_str: str, key: str | None = None) -> bytes: versioned_encryption_fn = fetch_versioned_implementation( "onyx.utils.encryption", "_encrypt_string" ) - return versioned_encryption_fn(input_str) + return versioned_encryption_fn(input_str, key=key) -def decrypt_bytes_to_string(input_bytes: bytes) -> str: +def decrypt_bytes_to_string(input_bytes: bytes, key: str | None = None) -> str: versioned_decryption_fn = fetch_versioned_implementation( "onyx.utils.encryption", "_decrypt_bytes" ) - return versioned_decryption_fn(input_bytes) + return versioned_decryption_fn(input_bytes, key=key) def test_encryption() -> None: diff --git a/backend/onyx/access/access.py b/backend/onyx/access/access.py index 6871bfe8ae1..49d93ac6d19 100644 --- a/backend/onyx/access/access.py +++ b/backend/onyx/access/access.py @@ -1,7 +1,6 @@ from collections.abc import Callable from typing import cast -from sqlalchemy.orm import joinedload from sqlalchemy.orm import Session from onyx.access.models import DocumentAccess @@ -12,6 +11,7 @@ from onyx.db.document import get_access_info_for_documents from onyx.db.models import User from onyx.db.models import UserFile +from onyx.db.user_file import fetch_user_files_with_access_relationships from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop from onyx.utils.variable_functionality import fetch_versioned_implementation @@ -132,19 +132,61 @@ def get_access_for_user_files( user_file_ids: list[str], db_session: Session, ) -> dict[str, DocumentAccess]: - user_files = ( - db_session.query(UserFile) - .options(joinedload(UserFile.user)) # Eager load the user relationship - .filter(UserFile.id.in_(user_file_ids)) - .all() + versioned_fn = fetch_versioned_implementation( + "onyx.access.access", "get_access_for_user_files_impl" ) - return { - str(user_file.id): DocumentAccess.build( - user_emails=[user_file.user.email] if user_file.user else [], + return versioned_fn(user_file_ids, 
db_session) + + +def get_access_for_user_files_impl( + user_file_ids: list[str], + db_session: Session, +) -> dict[str, DocumentAccess]: + user_files = fetch_user_files_with_access_relationships(user_file_ids, db_session) + return build_access_for_user_files_impl(user_files) + + +def build_access_for_user_files( + user_files: list[UserFile], +) -> dict[str, DocumentAccess]: + """Compute access from pre-loaded UserFile objects (with relationships). + Callers must ensure UserFile.user, Persona.users, and Persona.user are + eagerly loaded (and Persona.groups for the EE path).""" + versioned_fn = fetch_versioned_implementation( + "onyx.access.access", "build_access_for_user_files_impl" + ) + return versioned_fn(user_files) + + +def build_access_for_user_files_impl( + user_files: list[UserFile], +) -> dict[str, DocumentAccess]: + result: dict[str, DocumentAccess] = {} + for user_file in user_files: + emails, is_public = collect_user_file_access(user_file) + result[str(user_file.id)] = DocumentAccess.build( + user_emails=list(emails), user_groups=[], - is_public=True if user_file.user is None else False, + is_public=is_public, external_user_emails=[], external_user_group_ids=[], ) - for user_file in user_files - } + return result + + +def collect_user_file_access(user_file: UserFile) -> tuple[set[str], bool]: + """Collect all user emails that should have access to this user file. + Includes the owner plus any users who have access via shared personas. 
+ Returns (emails, is_public).""" + emails: set[str] = {user_file.user.email} + is_public = False + for persona in user_file.assistants: + if persona.deleted: + continue + if persona.is_public: + is_public = True + if persona.user_id is not None and persona.user: + emails.add(persona.user.email) + for shared_user in persona.users: + emails.add(shared_user.email) + return emails, is_public diff --git a/backend/onyx/auth/users.py b/backend/onyx/auth/users.py index 28a48487652..638b74f2822 100644 --- a/backend/onyx/auth/users.py +++ b/backend/onyx/auth/users.py @@ -1,4 +1,7 @@ +import base64 +import hashlib import json +import os import random import secrets import string @@ -28,6 +31,7 @@ from fastapi import Request from fastapi import Response from fastapi import status +from fastapi.responses import JSONResponse from fastapi.responses import RedirectResponse from fastapi.security import OAuth2PasswordRequestForm from fastapi_users import BaseUserManager @@ -54,6 +58,7 @@ from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback from httpx_oauth.oauth2 import BaseOAuth2 +from httpx_oauth.oauth2 import GetAccessTokenError from httpx_oauth.oauth2 import OAuth2Token from pydantic import BaseModel from sqlalchemy import nulls_last @@ -119,8 +124,11 @@ from onyx.db.models import User from onyx.db.pat import fetch_user_for_pat from onyx.db.users import get_user_by_email +from onyx.error_handling.error_codes import OnyxErrorCode +from onyx.error_handling.exceptions import log_onyx_error +from onyx.error_handling.exceptions import onyx_error_to_json_response +from onyx.error_handling.exceptions import OnyxError from onyx.redis.redis_pool import get_async_redis_connection -from onyx.redis.redis_pool import get_redis_client from onyx.server.settings.store import load_settings from onyx.server.utils import BasicAuthenticationError from onyx.utils.logger import setup_logger @@ -146,10 +154,22 @@ def 
is_user_admin(user: User) -> bool:
 def verify_auth_setting() -> None:
-    if AUTH_TYPE == AuthType.CLOUD:
+    """Validate AUTH_TYPE on startup and warn about deprecated values.
+
+    This only runs on app startup, not during migrations/scripts.
+    """
+    raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
+
+    if raw_auth_type == "cloud":
         raise ValueError(
-            f"{AUTH_TYPE.value} is not a valid auth type for self-hosted deployments."
+            "'cloud' is not a valid auth type for self-hosted deployments."
         )
+    if raw_auth_type == "disabled":
+        logger.warning(
+            "AUTH_TYPE='disabled' is no longer supported. "
+            "Using 'basic' instead. Please update your configuration."
+        )
+    logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
@@ -201,13 +221,14 @@ def user_needs_to_be_verified() -> bool:
 def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
-    redis_client = get_redis_client(tenant_id=tenant_id)
-    value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
+    from onyx.cache.factory import get_cache_backend
+
+    cache = get_cache_backend(tenant_id=tenant_id)
+    value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
     if value is None:
         return False
-    assert isinstance(value, bytes)
     return int(value.decode("utf-8")) == 1
@@ -725,11 +746,19 @@ async def oauth_callback(
         if user_by_session:
             user = user_by_session
+            # If the user is inactive, check seat availability before
+            # upgrading role — otherwise they'd become an inactive BASIC
+            # user who still can't log in.
+ if not user.is_active: + with get_session_with_current_tenant() as sync_db: + enforce_seat_limit(sync_db) + await self.user_db.update( user, { "is_verified": is_verified_by_default, "role": UserRole.BASIC, + **({"is_active": True} if not user.is_active else {}), }, ) @@ -1600,6 +1629,7 @@ def get_default_admin_user_emails_() -> list[str]: STATE_TOKEN_LIFETIME_SECONDS = 3600 CSRF_TOKEN_KEY = "csrftoken" CSRF_TOKEN_COOKIE_NAME = "fastapiusersoauthcsrf" +PKCE_COOKIE_NAME_PREFIX = "fastapiusersoauthpkce" class OAuth2AuthorizeResponse(BaseModel): @@ -1620,6 +1650,21 @@ def generate_csrf_token() -> str: return secrets.token_urlsafe(32) +def _base64url_encode(data: bytes) -> str: + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +def generate_pkce_pair() -> tuple[str, str]: + verifier = secrets.token_urlsafe(64) + challenge = _base64url_encode(hashlib.sha256(verifier.encode("ascii")).digest()) + return verifier, challenge + + +def get_pkce_cookie_name(state: str) -> str: + state_hash = hashlib.sha256(state.encode("utf-8")).hexdigest() + return f"{PKCE_COOKIE_NAME_PREFIX}_{state_hash}" + + # refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91 def create_onyx_oauth_router( oauth_client: BaseOAuth2, @@ -1628,6 +1673,7 @@ def create_onyx_oauth_router( redirect_url: Optional[str] = None, associate_by_email: bool = False, is_verified_by_default: bool = False, + enable_pkce: bool = False, ) -> APIRouter: return get_oauth_router( oauth_client, @@ -1637,6 +1683,7 @@ def create_onyx_oauth_router( redirect_url, associate_by_email, is_verified_by_default, + enable_pkce=enable_pkce, ) @@ -1655,6 +1702,7 @@ def get_oauth_router( csrf_token_cookie_secure: Optional[bool] = None, csrf_token_cookie_httponly: bool = True, csrf_token_cookie_samesite: Optional[Literal["lax", "strict", "none"]] = "lax", + enable_pkce: bool = False, ) -> APIRouter: """Generate a router with the OAuth 
routes.""" router = APIRouter() @@ -1671,6 +1719,13 @@ def get_oauth_router( route_name=callback_route_name, ) + async def null_access_token_state() -> tuple[OAuth2Token, Optional[str]] | None: + return None + + access_token_state_dependency = ( + oauth2_authorize_callback if not enable_pkce else null_access_token_state + ) + if csrf_token_cookie_secure is None: csrf_token_cookie_secure = WEB_DOMAIN.startswith("https") @@ -1704,13 +1759,26 @@ async def authorize( CSRF_TOKEN_KEY: csrf_token, } state = generate_state_token(state_data, state_secret) - - # Get the basic authorization URL - authorization_url = await oauth_client.get_authorization_url( - authorize_redirect_url, - state, - scopes, - ) + pkce_cookie: tuple[str, str] | None = None + + if enable_pkce: + code_verifier, code_challenge = generate_pkce_pair() + pkce_cookie_name = get_pkce_cookie_name(state) + pkce_cookie = (pkce_cookie_name, code_verifier) + authorization_url = await oauth_client.get_authorization_url( + authorize_redirect_url, + state, + scopes, + code_challenge=code_challenge, + code_challenge_method="S256", + ) + else: + # Get the basic authorization URL + authorization_url = await oauth_client.get_authorization_url( + authorize_redirect_url, + state, + scopes, + ) # For Google OAuth, add parameters to request refresh tokens if oauth_client.name == "google": @@ -1718,11 +1786,15 @@ async def authorize( authorization_url, {"access_type": "offline", "prompt": "consent"} ) - if redirect: - redirect_response = RedirectResponse(authorization_url, status_code=302) - redirect_response.set_cookie( - key=csrf_token_cookie_name, - value=csrf_token, + def set_oauth_cookie( + target_response: Response, + *, + key: str, + value: str, + ) -> None: + target_response.set_cookie( + key=key, + value=value, max_age=STATE_TOKEN_LIFETIME_SECONDS, path=csrf_token_cookie_path, domain=csrf_token_cookie_domain, @@ -1730,18 +1802,28 @@ async def authorize( httponly=csrf_token_cookie_httponly, 
samesite=csrf_token_cookie_samesite, ) - return redirect_response - response.set_cookie( + response_with_cookies: Response + if redirect: + response_with_cookies = RedirectResponse(authorization_url, status_code=302) + else: + response_with_cookies = response + + set_oauth_cookie( + response_with_cookies, key=csrf_token_cookie_name, value=csrf_token, - max_age=STATE_TOKEN_LIFETIME_SECONDS, - path=csrf_token_cookie_path, - domain=csrf_token_cookie_domain, - secure=csrf_token_cookie_secure, - httponly=csrf_token_cookie_httponly, - samesite=csrf_token_cookie_samesite, ) + if pkce_cookie is not None: + pkce_cookie_name, code_verifier = pkce_cookie + set_oauth_cookie( + response_with_cookies, + key=pkce_cookie_name, + value=code_verifier, + ) + + if redirect: + return response_with_cookies return OAuth2AuthorizeResponse(authorization_url=authorization_url) @@ -1772,119 +1854,242 @@ async def authorize( ) async def callback( request: Request, - access_token_state: Tuple[OAuth2Token, str] = Depends( - oauth2_authorize_callback + access_token_state: Tuple[OAuth2Token, Optional[str]] | None = Depends( + access_token_state_dependency ), + code: Optional[str] = None, + state: Optional[str] = None, + error: Optional[str] = None, user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), - ) -> RedirectResponse: - token, state = access_token_state - account_id, account_email = await oauth_client.get_id_email( - token["access_token"] - ) + ) -> Response: + pkce_cookie_name: str | None = None - if account_email is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL, - ) + def delete_pkce_cookie(response: Response) -> None: + if enable_pkce and pkce_cookie_name: + response.delete_cookie( + key=pkce_cookie_name, + path=csrf_token_cookie_path, + domain=csrf_token_cookie_domain, + secure=csrf_token_cookie_secure, + 
httponly=csrf_token_cookie_httponly, + samesite=csrf_token_cookie_samesite, + ) - try: - state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE]) - except jwt.DecodeError: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=getattr( - ErrorCode, "ACCESS_TOKEN_DECODE_ERROR", "ACCESS_TOKEN_DECODE_ERROR" - ), - ) - except jwt.ExpiredSignatureError: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=getattr( - ErrorCode, - "ACCESS_TOKEN_ALREADY_EXPIRED", - "ACCESS_TOKEN_ALREADY_EXPIRED", - ), - ) + def build_error_response(exc: OnyxError) -> JSONResponse: + log_onyx_error(exc) + error_response = onyx_error_to_json_response(exc) + delete_pkce_cookie(error_response) + return error_response - cookie_csrf_token = request.cookies.get(csrf_token_cookie_name) - state_csrf_token = state_data.get(CSRF_TOKEN_KEY) - if ( - not cookie_csrf_token - or not state_csrf_token - or not secrets.compare_digest(cookie_csrf_token, state_csrf_token) - ): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=getattr(ErrorCode, "OAUTH_INVALID_STATE", "OAUTH_INVALID_STATE"), - ) + def decode_and_validate_state(state_value: str) -> Dict[str, str]: + try: + state_data = decode_jwt( + state_value, state_secret, [STATE_TOKEN_AUDIENCE] + ) + except jwt.DecodeError: + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + getattr( + ErrorCode, + "ACCESS_TOKEN_DECODE_ERROR", + "ACCESS_TOKEN_DECODE_ERROR", + ), + ) + except jwt.ExpiredSignatureError: + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + getattr( + ErrorCode, + "ACCESS_TOKEN_ALREADY_EXPIRED", + "ACCESS_TOKEN_ALREADY_EXPIRED", + ), + ) + except jwt.PyJWTError: + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + getattr( + ErrorCode, + "ACCESS_TOKEN_DECODE_ERROR", + "ACCESS_TOKEN_DECODE_ERROR", + ), + ) - next_url = state_data.get("next_url", "/") - referral_source = state_data.get("referral_source", None) - try: - tenant_id = 
fetch_ee_implementation_or_noop( - "onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None - )(account_email) - except exceptions.UserNotExists: - tenant_id = None + cookie_csrf_token = request.cookies.get(csrf_token_cookie_name) + state_csrf_token = state_data.get(CSRF_TOKEN_KEY) + if ( + not cookie_csrf_token + or not state_csrf_token + or not secrets.compare_digest(cookie_csrf_token, state_csrf_token) + ): + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + getattr(ErrorCode, "OAUTH_INVALID_STATE", "OAUTH_INVALID_STATE"), + ) - request.state.referral_source = referral_source + return state_data - # Proceed to authenticate or create the user - try: - user = await user_manager.oauth_callback( - oauth_client.name, - token["access_token"], - account_id, - account_email, - token.get("expires_at"), - token.get("refresh_token"), - request, - associate_by_email=associate_by_email, - is_verified_by_default=is_verified_by_default, - ) - except UserAlreadyExists: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ErrorCode.OAUTH_USER_ALREADY_EXISTS, - ) + token: OAuth2Token + state_data: Dict[str, str] - if not user.is_active: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ErrorCode.LOGIN_BAD_CREDENTIALS, - ) + # `code`, `state`, and `error` are read directly only in the PKCE path. + # In the non-PKCE path, `oauth2_authorize_callback` consumes them. 
+ if enable_pkce: + if state is not None: + pkce_cookie_name = get_pkce_cookie_name(state) + + if error is not None: + return build_error_response( + OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "Authorization request failed or was denied", + ) + ) + if code is None: + return build_error_response( + OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "Missing authorization code in OAuth callback", + ) + ) + if state is None: + return build_error_response( + OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "Missing state parameter in OAuth callback", + ) + ) + + state_value = state - # Login user - response = await backend.login(strategy, user) - await user_manager.on_after_login(user, request, response) + if redirect_url is not None: + callback_redirect_url = redirect_url + else: + callback_path = request.app.url_path_for(callback_route_name) + callback_redirect_url = f"{WEB_DOMAIN}{callback_path}" + + code_verifier = request.cookies.get(cast(str, pkce_cookie_name)) + if not code_verifier: + return build_error_response( + OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "Missing PKCE verifier cookie in OAuth callback", + ) + ) + + try: + state_data = decode_and_validate_state(state_value) + except OnyxError as e: + return build_error_response(e) - # Prepare redirect response - if tenant_id is None: - # Use URL utility to add parameters - redirect_url = add_url_params(next_url, {"new_team": "true"}) - redirect_response = RedirectResponse(redirect_url, status_code=302) + try: + token = await oauth_client.get_access_token( + code, callback_redirect_url, code_verifier + ) + except GetAccessTokenError: + return build_error_response( + OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "Authorization code exchange failed", + ) + ) else: - # No parameters to add - redirect_response = RedirectResponse(next_url, status_code=302) - - # Copy headers from auth response to redirect response, with special handling for Set-Cookie - for header_name, header_value in response.headers.items(): 
- # FastAPI can have multiple Set-Cookie headers as a list - if header_name.lower() == "set-cookie" and isinstance(header_value, list): - for cookie_value in header_value: - redirect_response.headers.append(header_name, cookie_value) + if access_token_state is None: + raise OnyxError( + OnyxErrorCode.INTERNAL_ERROR, "Missing OAuth callback state" + ) + token, callback_state = access_token_state + if callback_state is None: + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "Missing state parameter in OAuth callback", + ) + state_data = decode_and_validate_state(callback_state) + + async def complete_login_flow( + token: OAuth2Token, state_data: Dict[str, str] + ) -> RedirectResponse: + account_id, account_email = await oauth_client.get_id_email( + token["access_token"] + ) + + if account_email is None: + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL, + ) + + next_url = state_data.get("next_url", "/") + referral_source = state_data.get("referral_source", None) + try: + tenant_id = fetch_ee_implementation_or_noop( + "onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None + )(account_email) + except exceptions.UserNotExists: + tenant_id = None + + request.state.referral_source = referral_source + + # Proceed to authenticate or create the user + try: + user = await user_manager.oauth_callback( + oauth_client.name, + token["access_token"], + account_id, + account_email, + token.get("expires_at"), + token.get("refresh_token"), + request, + associate_by_email=associate_by_email, + is_verified_by_default=is_verified_by_default, + ) + except UserAlreadyExists: + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + ErrorCode.OAUTH_USER_ALREADY_EXISTS, + ) + + if not user.is_active: + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + ErrorCode.LOGIN_BAD_CREDENTIALS, + ) + + # Login user + response = await backend.login(strategy, user) + await user_manager.on_after_login(user, request, response) + + # Prepare 
redirect response + if tenant_id is None: + # Use URL utility to add parameters + redirect_destination = add_url_params(next_url, {"new_team": "true"}) + redirect_response = RedirectResponse( + redirect_destination, status_code=302 + ) else: + # No parameters to add + redirect_response = RedirectResponse(next_url, status_code=302) + + # Copy headers from auth response to redirect response, with special handling for Set-Cookie + for header_name, header_value in response.headers.items(): + header_name_lower = header_name.lower() + if header_name_lower == "set-cookie": + redirect_response.headers.append(header_name, header_value) + continue + if header_name_lower in {"location", "content-length"}: + continue redirect_response.headers[header_name] = header_value - if hasattr(response, "body"): - redirect_response.body = response.body - if hasattr(response, "status_code"): - redirect_response.status_code = response.status_code - if hasattr(response, "media_type"): - redirect_response.media_type = response.media_type + return redirect_response + + if enable_pkce: + try: + redirect_response = await complete_login_flow(token, state_data) + except OnyxError as e: + return build_error_response(e) + delete_pkce_cookie(redirect_response) + return redirect_response - return redirect_response + return await complete_login_flow(token, state_data) return router diff --git a/backend/onyx/background/celery/apps/background.py b/backend/onyx/background/celery/apps/background.py deleted file mode 100644 index 137d14bdb3e..00000000000 --- a/backend/onyx/background/celery/apps/background.py +++ /dev/null @@ -1,142 +0,0 @@ -from typing import Any -from typing import cast - -from celery import Celery -from celery import signals -from celery import Task -from celery.apps.worker import Worker -from celery.signals import celeryd_init -from celery.signals import worker_init -from celery.signals import worker_process_init -from celery.signals import worker_ready -from celery.signals import 
worker_shutdown - -import onyx.background.celery.apps.app_base as app_base -from onyx.background.celery.celery_utils import httpx_init_vespa_pool -from onyx.configs.app_configs import MANAGED_VESPA -from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH -from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH -from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME -from onyx.db.engine.sql_engine import SqlEngine -from onyx.utils.logger import setup_logger -from shared_configs.configs import MULTI_TENANT - - -logger = setup_logger() - -celery_app = Celery(__name__) -celery_app.config_from_object("onyx.background.celery.configs.background") -celery_app.Task = app_base.TenantAwareTask # type: ignore [misc] - - -@signals.task_prerun.connect -def on_task_prerun( - sender: Any | None = None, - task_id: str | None = None, - task: Task | None = None, - args: tuple | None = None, - kwargs: dict | None = None, - **kwds: Any, -) -> None: - app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds) - - -@signals.task_postrun.connect -def on_task_postrun( - sender: Any | None = None, - task_id: str | None = None, - task: Task | None = None, - args: tuple | None = None, - kwargs: dict | None = None, - retval: Any | None = None, - state: str | None = None, - **kwds: Any, -) -> None: - app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds) - - -@celeryd_init.connect -def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None: - app_base.on_celeryd_init(sender, conf, **kwargs) - - -@worker_init.connect -def on_worker_init(sender: Worker, **kwargs: Any) -> None: - EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits - - logger.info("worker_init signal received for consolidated background worker.") - - SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME) - pool_size = cast(int, sender.concurrency) # type: ignore - SqlEngine.init_engine(pool_size=pool_size, 
max_overflow=EXTRA_CONCURRENCY) - - # Initialize Vespa httpx pool (needed for light worker tasks) - if MANAGED_VESPA: - httpx_init_vespa_pool( - sender.concurrency + EXTRA_CONCURRENCY, # type: ignore - ssl_cert=VESPA_CLOUD_CERT_PATH, - ssl_key=VESPA_CLOUD_KEY_PATH, - ) - else: - httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore - - app_base.wait_for_redis(sender, **kwargs) - app_base.wait_for_db(sender, **kwargs) - app_base.wait_for_vespa_or_shutdown(sender, **kwargs) - - # Less startup checks in multi-tenant case - if MULTI_TENANT: - return - - app_base.on_secondary_worker_init(sender, **kwargs) - - -@worker_ready.connect -def on_worker_ready(sender: Any, **kwargs: Any) -> None: - app_base.on_worker_ready(sender, **kwargs) - - -@worker_shutdown.connect -def on_worker_shutdown(sender: Any, **kwargs: Any) -> None: - app_base.on_worker_shutdown(sender, **kwargs) - - -@worker_process_init.connect -def init_worker(**kwargs: Any) -> None: # noqa: ARG001 - SqlEngine.reset_engine() - - -@signals.setup_logging.connect -def on_setup_logging( - loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any -) -> None: - app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs) - - -base_bootsteps = app_base.get_bootsteps() -for bootstep in base_bootsteps: - celery_app.steps["worker"].add(bootstep) - -celery_app.autodiscover_tasks( - app_base.filter_task_modules( - [ - # Original background worker tasks - "onyx.background.celery.tasks.pruning", - "onyx.background.celery.tasks.monitoring", - "onyx.background.celery.tasks.user_file_processing", - "onyx.background.celery.tasks.llm_model_update", - # Light worker tasks - "onyx.background.celery.tasks.shared", - "onyx.background.celery.tasks.vespa", - "onyx.background.celery.tasks.connector_deletion", - "onyx.background.celery.tasks.doc_permission_syncing", - "onyx.background.celery.tasks.opensearch_migration", - # Docprocessing worker tasks - 
"onyx.background.celery.tasks.docprocessing", - # Docfetching worker tasks - "onyx.background.celery.tasks.docfetching", - # Sandbox cleanup tasks (isolated in build feature) - "onyx.server.features.build.sandbox.tasks", - ] - ) -) diff --git a/backend/onyx/background/celery/celery_utils.py b/backend/onyx/background/celery/celery_utils.py index 4e2e9fd2f03..977a9f0daa0 100644 --- a/backend/onyx/background/celery/celery_utils.py +++ b/backend/onyx/background/celery/celery_utils.py @@ -39,9 +39,13 @@ class SlimConnectorExtractionResult(BaseModel): - """Result of extracting document IDs and hierarchy nodes from a connector.""" + """Result of extracting document IDs and hierarchy nodes from a connector. - doc_ids: set[str] + raw_id_to_parent maps document ID → parent_hierarchy_raw_node_id (or None). + Use raw_id_to_parent.keys() wherever the old set of IDs was needed. + """ + + raw_id_to_parent: dict[str, str | None] hierarchy_nodes: list[HierarchyNode] @@ -93,30 +97,34 @@ def _get_failure_id(failure: ConnectorFailure) -> str | None: return None +class BatchResult(BaseModel): + raw_id_to_parent: dict[str, str | None] + hierarchy_nodes: list[HierarchyNode] + + def _extract_from_batch( doc_list: Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure], -) -> tuple[set[str], list[HierarchyNode]]: - """Separate a batch into document IDs and hierarchy nodes. +) -> BatchResult: + """Separate a batch into document IDs (with parent mapping) and hierarchy nodes. ConnectorFailure items have their failed document/entity IDs added to the - ID set so that failed-to-retrieve documents are not accidentally pruned. + ID dict so that failed-to-retrieve documents are not accidentally pruned. 
""" - ids: set[str] = set() + ids: dict[str, str | None] = {} hierarchy_nodes: list[HierarchyNode] = [] for item in doc_list: if isinstance(item, HierarchyNode): hierarchy_nodes.append(item) - ids.add(item.raw_node_id) elif isinstance(item, ConnectorFailure): failed_id = _get_failure_id(item) if failed_id: - ids.add(failed_id) + ids[failed_id] = None logger.warning( f"Failed to retrieve document {failed_id}: " f"{item.failure_message}" ) else: - ids.add(item.id) - return ids, hierarchy_nodes + ids[item.id] = item.parent_hierarchy_raw_node_id + return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes) def extract_ids_from_runnable_connector( @@ -132,7 +140,7 @@ def extract_ids_from_runnable_connector( Optionally, a callback can be passed to handle the length of each document batch. """ - all_connector_doc_ids: set[str] = set() + all_raw_id_to_parent: dict[str, str | None] = {} all_hierarchy_nodes: list[HierarchyNode] = [] # Sequence (covariant) lets all the specific list[...] 
iterator types unify here @@ -177,15 +185,18 @@ def extract_ids_from_runnable_connector( "extract_ids_from_runnable_connector: Stop signal detected" ) - batch_ids, batch_nodes = _extract_from_batch(doc_list) - all_connector_doc_ids.update(doc_batch_processing_func(batch_ids)) + batch_result = _extract_from_batch(doc_list) + batch_ids = batch_result.raw_id_to_parent + batch_nodes = batch_result.hierarchy_nodes + doc_batch_processing_func(batch_ids) + all_raw_id_to_parent.update(batch_ids) all_hierarchy_nodes.extend(batch_nodes) if callback: callback.progress("extract_ids_from_runnable_connector", len(batch_ids)) return SlimConnectorExtractionResult( - doc_ids=all_connector_doc_ids, + raw_id_to_parent=all_raw_id_to_parent, hierarchy_nodes=all_hierarchy_nodes, ) diff --git a/backend/onyx/background/celery/configs/background.py b/backend/onyx/background/celery/configs/background.py deleted file mode 100644 index 64350c8f2b7..00000000000 --- a/backend/onyx/background/celery/configs/background.py +++ /dev/null @@ -1,23 +0,0 @@ -import onyx.background.celery.configs.base as shared_config -from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY - -broker_url = shared_config.broker_url -broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup -broker_pool_limit = shared_config.broker_pool_limit -broker_transport_options = shared_config.broker_transport_options - -redis_socket_keepalive = shared_config.redis_socket_keepalive -redis_retry_on_timeout = shared_config.redis_retry_on_timeout -redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval - -result_backend = shared_config.result_backend -result_expires = shared_config.result_expires # 86400 seconds is the default - -task_default_priority = shared_config.task_default_priority -task_acks_late = shared_config.task_acks_late - -worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY -worker_pool = "threads" -# Increased from 1 to 4 to handle 
fast light worker tasks more efficiently -# This allows the worker to prefetch multiple tasks per thread -worker_prefetch_multiplier = 4 diff --git a/backend/onyx/background/celery/tasks/hierarchyfetching/tasks.py b/backend/onyx/background/celery/tasks/hierarchyfetching/tasks.py index c445cc1a917..50b57560310 100644 --- a/backend/onyx/background/celery/tasks/hierarchyfetching/tasks.py +++ b/backend/onyx/background/celery/tasks/hierarchyfetching/tasks.py @@ -40,6 +40,7 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus +from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries from onyx.db.hierarchy import upsert_hierarchy_nodes_batch from onyx.db.models import ConnectorCredentialPair from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch @@ -289,6 +290,14 @@ def _process_batch() -> int: is_connector_public=is_connector_public, ) + upsert_hierarchy_node_cc_pair_entries( + db_session=db_session, + hierarchy_node_ids=[n.id for n in upserted_nodes], + connector_id=cc_pair.connector_id, + credential_id=cc_pair.credential_id, + commit=True, + ) + # Cache in Redis for fast ancestor resolution cache_entries = [ HierarchyNodeCacheEntry.from_db_model(node) for node in upserted_nodes diff --git a/backend/onyx/background/celery/tasks/opensearch_migration/constants.py b/backend/onyx/background/celery/tasks/opensearch_migration/constants.py index 1d4b2136a5a..b730b9a417f 100644 --- a/backend/onyx/background/celery/tasks/opensearch_migration/constants.py +++ b/backend/onyx/background/celery/tasks/opensearch_migration/constants.py @@ -11,6 +11,9 @@ # lock after its cleanup which happens at most after its soft timeout. # Constants corresponding to migrate_documents_from_vespa_to_opensearch_task. +from onyx.configs.app_configs import OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE + + MIGRATION_TASK_SOFT_TIME_LIMIT_S = 60 * 5 # 5 minutes. 
MIGRATION_TASK_TIME_LIMIT_S = 60 * 6 # 6 minutes. # The maximum time the lock can be held for. Will automatically be released @@ -44,7 +47,7 @@ # WARNING: Do not change these values without knowing what changes also need to # be made to OpenSearchTenantMigrationRecord. -GET_VESPA_CHUNKS_PAGE_SIZE = 500 +GET_VESPA_CHUNKS_PAGE_SIZE = OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE GET_VESPA_CHUNKS_SLICE_COUNT = 4 # String used to indicate in the vespa_visit_continuation_token mapping that the diff --git a/backend/onyx/background/celery/tasks/opensearch_migration/tasks.py b/backend/onyx/background/celery/tasks/opensearch_migration/tasks.py index 503af958b97..79807a1aca7 100644 --- a/backend/onyx/background/celery/tasks/opensearch_migration/tasks.py +++ b/backend/onyx/background/celery/tasks/opensearch_migration/tasks.py @@ -30,6 +30,7 @@ transform_vespa_chunks_to_opensearch_chunks, ) from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX +from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisLocks from onyx.db.engine.sql_engine import get_session_with_current_tenant @@ -47,6 +48,7 @@ from onyx.document_index.opensearch.opensearch_document_index import ( OpenSearchDocumentIndex, ) +from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex from onyx.indexing.models import IndexingSetting from onyx.redis.redis_pool import get_redis_client @@ -146,7 +148,12 @@ def migrate_chunks_from_vespa_to_opensearch_task( task_logger.error(err_str) return False - with get_session_with_current_tenant() as db_session: + with ( + get_session_with_current_tenant() as db_session, + get_vespa_http_client( + timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S + ) as vespa_client, + ): try_insert_opensearch_tenant_migration_record_with_commit(db_session) search_settings = 
get_current_search_settings(db_session) tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT) @@ -161,6 +168,7 @@ def migrate_chunks_from_vespa_to_opensearch_task( index_name=search_settings.index_name, tenant_state=tenant_state, large_chunks_enabled=False, + httpx_client=vespa_client, ) sanitized_doc_start_time = time.monotonic() diff --git a/backend/onyx/background/celery/tasks/pruning/tasks.py b/backend/onyx/background/celery/tasks/pruning/tasks.py index fadd7841ae5..baa0627754d 100644 --- a/backend/onyx/background/celery/tasks/pruning/tasks.py +++ b/backend/onyx/background/celery/tasks/pruning/tasks.py @@ -29,6 +29,7 @@ from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX +from onyx.configs.constants import DocumentSource from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask @@ -47,8 +48,15 @@ from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.enums import SyncStatus from onyx.db.enums import SyncType +from onyx.db.hierarchy import delete_orphaned_hierarchy_nodes +from onyx.db.hierarchy import link_hierarchy_nodes_to_documents +from onyx.db.hierarchy import remove_stale_hierarchy_node_cc_pair_entries +from onyx.db.hierarchy import reparent_orphaned_hierarchy_nodes +from onyx.db.hierarchy import update_document_parent_hierarchy_nodes +from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries from onyx.db.hierarchy import upsert_hierarchy_nodes_batch from onyx.db.models import ConnectorCredentialPair +from onyx.db.models import HierarchyNode as DBHierarchyNode from onyx.db.sync_record import insert_sync_record from onyx.db.sync_record import update_sync_record_status from onyx.db.tag import delete_orphan_tags__no_commit @@ -57,6 +65,9 @@ from 
onyx.redis.redis_connector_prune import RedisConnectorPrunePayload from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch from onyx.redis.redis_hierarchy import ensure_source_node_exists +from onyx.redis.redis_hierarchy import evict_hierarchy_nodes_from_cache +from onyx.redis.redis_hierarchy import get_node_id_from_raw_id +from onyx.redis.redis_hierarchy import get_source_node_id_from_cache from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry from onyx.redis.redis_pool import get_redis_client from onyx.redis.redis_pool import get_redis_replica_client @@ -113,6 +124,38 @@ def progress(self, tag: str, amount: int) -> None: super().progress(tag, amount) +def _resolve_and_update_document_parents( + db_session: Session, + redis_client: Redis, + source: DocumentSource, + raw_id_to_parent: dict[str, str | None], +) -> None: + """Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id for + each document and bulk-update the DB. Mirrors the resolution logic in + run_docfetching.py.""" + source_node_id = get_source_node_id_from_cache(redis_client, db_session, source) + + resolved: dict[str, int | None] = {} + for doc_id, raw_parent_id in raw_id_to_parent.items(): + if raw_parent_id is None: + continue + node_id, found = get_node_id_from_raw_id(redis_client, source, raw_parent_id) + resolved[doc_id] = node_id if found else source_node_id + + if not resolved: + return + + update_document_parent_hierarchy_nodes( + db_session=db_session, + doc_parent_map=resolved, + commit=True, + ) + task_logger.info( + f"Pruning: resolved and updated parent hierarchy for " + f"{len(resolved)} documents (source={source.value})" + ) + + """Jobs / utils for kicking off pruning tasks.""" @@ -535,33 +578,42 @@ def connector_pruning_generator_task( extraction_result = extract_ids_from_runnable_connector( runnable_connector, callback ) - all_connector_doc_ids = extraction_result.doc_ids + all_connector_doc_ids = extraction_result.raw_id_to_parent # Process hierarchy 
nodes (same as docfetching): # upsert to Postgres and cache in Redis + source = cc_pair.connector.source + redis_client = get_redis_client(tenant_id=tenant_id) + + ensure_source_node_exists(redis_client, db_session, source) + + upserted_nodes: list[DBHierarchyNode] = [] if extraction_result.hierarchy_nodes: is_connector_public = cc_pair.access_type == AccessType.PUBLIC - redis_client = get_redis_client(tenant_id=tenant_id) - ensure_source_node_exists( - redis_client, db_session, cc_pair.connector.source - ) - upserted_nodes = upsert_hierarchy_nodes_batch( db_session=db_session, nodes=extraction_result.hierarchy_nodes, - source=cc_pair.connector.source, + source=source, commit=True, is_connector_public=is_connector_public, ) + upsert_hierarchy_node_cc_pair_entries( + db_session=db_session, + hierarchy_node_ids=[n.id for n in upserted_nodes], + connector_id=connector_id, + credential_id=credential_id, + commit=True, + ) + cache_entries = [ HierarchyNodeCacheEntry.from_db_model(node) for node in upserted_nodes ] cache_hierarchy_nodes_batch( redis_client=redis_client, - source=cc_pair.connector.source, + source=source, entries=cache_entries, ) @@ -570,6 +622,25 @@ def connector_pruning_generator_task( f"hierarchy nodes for cc_pair={cc_pair_id}" ) + # Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id + # and bulk-update documents, mirroring the docfetching resolution + _resolve_and_update_document_parents( + db_session=db_session, + redis_client=redis_client, + source=source, + raw_id_to_parent=all_connector_doc_ids, + ) + + # Link hierarchy nodes to documents for sources where pages can be + # both hierarchy nodes AND documents (e.g. 
Notion, Confluence) + all_doc_id_list = list(all_connector_doc_ids.keys()) + link_hierarchy_nodes_to_documents( + db_session=db_session, + document_ids=all_doc_id_list, + source=source, + commit=True, + ) + # a list of docs in our local index all_indexed_document_ids = { doc.id @@ -581,7 +652,9 @@ def connector_pruning_generator_task( } # generate list of docs to remove (no longer in the source) - doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids) + doc_ids_to_remove = list( + all_indexed_document_ids - all_connector_doc_ids.keys() + ) task_logger.info( "Pruning set collected: " @@ -605,6 +678,43 @@ def connector_pruning_generator_task( ) redis_connector.prune.generator_complete = tasks_generated + + # --- Hierarchy node pruning --- + live_node_ids = {n.id for n in upserted_nodes} + stale_removed = remove_stale_hierarchy_node_cc_pair_entries( + db_session=db_session, + connector_id=connector_id, + credential_id=credential_id, + live_hierarchy_node_ids=live_node_ids, + commit=True, + ) + deleted_raw_ids = delete_orphaned_hierarchy_nodes( + db_session=db_session, + source=source, + commit=True, + ) + reparented_nodes = reparent_orphaned_hierarchy_nodes( + db_session=db_session, + source=source, + commit=True, + ) + if deleted_raw_ids: + evict_hierarchy_nodes_from_cache(redis_client, source, deleted_raw_ids) + if reparented_nodes: + reparented_cache_entries = [ + HierarchyNodeCacheEntry.from_db_model(node) + for node in reparented_nodes + ] + cache_hierarchy_nodes_batch( + redis_client, source, reparented_cache_entries + ) + if stale_removed or deleted_raw_ids or reparented_nodes: + task_logger.info( + f"Hierarchy node pruning: cc_pair={cc_pair_id} " + f"stale_entries_removed={stale_removed} " + f"nodes_deleted={len(deleted_raw_ids)} " + f"nodes_reparented={len(reparented_nodes)}" + ) except Exception as e: task_logger.exception( f"Pruning exceptioned: cc_pair={cc_pair_id} " diff --git 
a/backend/onyx/background/celery/tasks/user_file_processing/tasks.py b/backend/onyx/background/celery/tasks/user_file_processing/tasks.py index 6b1b3290eef..6208f884817 100644 --- a/backend/onyx/background/celery/tasks/user_file_processing/tasks.py +++ b/backend/onyx/background/celery/tasks/user_file_processing/tasks.py @@ -12,9 +12,9 @@ from redis.lock import Lock as RedisLock from retry import retry from sqlalchemy import select -from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session +from onyx.access.access import build_access_for_user_files from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.celery_redis import celery_get_queue_length from onyx.background.celery.celery_utils import httpx_init_vespa_pool @@ -43,7 +43,9 @@ from onyx.db.models import UserFile from onyx.db.search_settings import get_active_search_settings from onyx.db.search_settings import get_active_search_settings_list +from onyx.db.user_file import fetch_user_files_with_access_relationships from onyx.document_index.factory import get_all_document_indices +from onyx.document_index.interfaces import VespaDocumentFields from onyx.document_index.interfaces import VespaDocumentUserFields from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT from onyx.file_store.file_store import get_default_file_store @@ -54,6 +56,7 @@ from onyx.indexing.embedder import DefaultIndexingEmbedder from onyx.indexing.indexing_pipeline import run_indexing_pipeline from onyx.redis.redis_pool import get_redis_client +from onyx.utils.variable_functionality import global_version def _as_uuid(value: str | UUID) -> UUID: @@ -414,7 +417,7 @@ def _process_user_file_with_indexing( raise RuntimeError(f"Indexing pipeline failed for user file {user_file_id}") -def _process_user_file_impl( +def process_user_file_impl( *, user_file_id: str, tenant_id: str, redis_locking: bool ) -> None: """Core implementation for processing a single user file. 
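The pruning changes above replace the plain `set[str]` of connector document ids with a `raw_id_to_parent` mapping, and the pruning-set computation becomes `all_indexed_document_ids - all_connector_doc_ids.keys()`. This works because `dict.keys()` returns a set-like view that supports set algebra directly. A standalone sketch with hypothetical ids (not taken from the diff):

```python
# Documents currently in the local index (hypothetical ids).
all_indexed_document_ids = {"doc-a", "doc-b", "doc-c"}

# Documents still present at the source, now a mapping of
# raw document id -> parent hierarchy raw node id (or None).
raw_id_to_parent = {"doc-a": None, "doc-b": "folder-1"}

# dict.keys() is a set-like view, so set difference works
# without materializing a separate set of ids.
doc_ids_to_remove = all_indexed_document_ids - raw_id_to_parent.keys()
```

The same view-based difference is what lets the diff drop the old `set` while keeping the prune computation unchanged.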
@@ -423,7 +426,7 @@ def _process_user_file_impl( queued-key guard (Celery path). When redis_locking=False, skips all Redis operations (BackgroundTask path). """ - task_logger.info(f"_process_user_file_impl - Starting id={user_file_id}") + task_logger.info(f"process_user_file_impl - Starting id={user_file_id}") start = time.monotonic() file_lock: RedisLock | None = None @@ -436,7 +439,7 @@ def _process_user_file_impl( ) if file_lock is not None and not file_lock.acquire(blocking=False): task_logger.info( - f"_process_user_file_impl - Lock held, skipping user_file_id={user_file_id}" + f"process_user_file_impl - Lock held, skipping user_file_id={user_file_id}" ) return @@ -446,13 +449,16 @@ def _process_user_file_impl( uf = db_session.get(UserFile, _as_uuid(user_file_id)) if not uf: task_logger.warning( - f"_process_user_file_impl - UserFile not found id={user_file_id}" + f"process_user_file_impl - UserFile not found id={user_file_id}" ) return - if uf.status != UserFileStatus.PROCESSING: + if uf.status not in ( + UserFileStatus.PROCESSING, + UserFileStatus.INDEXING, + ): task_logger.info( - f"_process_user_file_impl - Skipping id={user_file_id} status={uf.status}" + f"process_user_file_impl - Skipping id={user_file_id} status={uf.status}" ) return @@ -489,7 +495,7 @@ def _process_user_file_impl( except Exception as e: task_logger.exception( - f"_process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}" + f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}" ) current_user_file = db_session.get(UserFile, _as_uuid(user_file_id)) if ( @@ -503,7 +509,7 @@ def _process_user_file_impl( elapsed = time.monotonic() - start task_logger.info( - f"_process_user_file_impl - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s" + f"process_user_file_impl - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s" ) except Exception as e: with get_session_with_current_tenant() 
as db_session: @@ -515,8 +521,9 @@ def _process_user_file_impl( db_session.commit() task_logger.exception( - f"_process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}" + f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}" ) + raise finally: if file_lock is not None and file_lock.owned(): file_lock.release() @@ -530,7 +537,7 @@ def _process_user_file_impl( def process_single_user_file( self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001 ) -> None: - _process_user_file_impl( + process_user_file_impl( user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True ) @@ -585,7 +592,7 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None: return None -def _delete_user_file_impl( +def delete_user_file_impl( *, user_file_id: str, tenant_id: str, redis_locking: bool ) -> None: """Core implementation for deleting a single user file. @@ -593,7 +600,7 @@ def _delete_user_file_impl( When redis_locking=True, acquires a per-file Redis lock (Celery path). When redis_locking=False, skips Redis operations (BackgroundTask path). 
""" - task_logger.info(f"_delete_user_file_impl - Starting id={user_file_id}") + task_logger.info(f"delete_user_file_impl - Starting id={user_file_id}") file_lock: RedisLock | None = None if redis_locking: @@ -604,7 +611,7 @@ def _delete_user_file_impl( ) if file_lock is not None and not file_lock.acquire(blocking=False): task_logger.info( - f"_delete_user_file_impl - Lock held, skipping user_file_id={user_file_id}" + f"delete_user_file_impl - Lock held, skipping user_file_id={user_file_id}" ) return @@ -613,7 +620,7 @@ def _delete_user_file_impl( user_file = db_session.get(UserFile, _as_uuid(user_file_id)) if not user_file: task_logger.info( - f"_delete_user_file_impl - User file not found id={user_file_id}" + f"delete_user_file_impl - User file not found id={user_file_id}" ) return @@ -662,16 +669,17 @@ def _delete_user_file_impl( ) except Exception as e: task_logger.exception( - f"_delete_user_file_impl - Error deleting file id={user_file.id} - {e.__class__.__name__}" + f"delete_user_file_impl - Error deleting file id={user_file.id} - {e.__class__.__name__}" ) db_session.delete(user_file) db_session.commit() - task_logger.info(f"_delete_user_file_impl - Completed id={user_file_id}") + task_logger.info(f"delete_user_file_impl - Completed id={user_file_id}") except Exception as e: task_logger.exception( - f"_delete_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}" + f"delete_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}" ) + raise finally: if file_lock is not None and file_lock.owned(): file_lock.release() @@ -685,7 +693,7 @@ def _delete_user_file_impl( def process_single_user_file_delete( self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001 ) -> None: - _delete_user_file_impl( + delete_user_file_impl( user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True ) @@ -759,7 +767,7 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None: return None -def 
_project_sync_user_file_impl( +def project_sync_user_file_impl( *, user_file_id: str, tenant_id: str, redis_locking: bool ) -> None: """Core implementation for syncing a user file's project/persona metadata. @@ -768,7 +776,7 @@ def _project_sync_user_file_impl( queued-key guard (Celery path). When redis_locking=False, skips Redis operations (BackgroundTask path). """ - task_logger.info(f"_project_sync_user_file_impl - Starting id={user_file_id}") + task_logger.info(f"project_sync_user_file_impl - Starting id={user_file_id}") file_lock: RedisLock | None = None if redis_locking: @@ -780,20 +788,21 @@ def _project_sync_user_file_impl( ) if file_lock is not None and not file_lock.acquire(blocking=False): task_logger.info( - f"_project_sync_user_file_impl - Lock held, skipping user_file_id={user_file_id}" + f"project_sync_user_file_impl - Lock held, skipping user_file_id={user_file_id}" ) return try: with get_session_with_current_tenant() as db_session: - user_file = db_session.execute( - select(UserFile) - .where(UserFile.id == _as_uuid(user_file_id)) - .options(selectinload(UserFile.assistants)) - ).scalar_one_or_none() + user_files = fetch_user_files_with_access_relationships( + [user_file_id], + db_session, + eager_load_groups=global_version.is_ee_version(), + ) + user_file = user_files[0] if user_files else None if not user_file: task_logger.info( - f"_project_sync_user_file_impl - User file not found id={user_file_id}" + f"project_sync_user_file_impl - User file not found id={user_file_id}" ) return @@ -818,12 +827,21 @@ def _project_sync_user_file_impl( project_ids = [project.id for project in user_file.projects] persona_ids = [p.id for p in user_file.assistants if not p.deleted] + + file_id_str = str(user_file.id) + access_map = build_access_for_user_files([user_file]) + access = access_map.get(file_id_str) + for retry_document_index in retry_document_indices: retry_document_index.update_single( - doc_id=str(user_file.id), + doc_id=file_id_str, 
tenant_id=tenant_id, chunk_count=user_file.chunk_count, - fields=None, + fields=( + VespaDocumentFields(access=access) + if access is not None + else None + ), user_fields=VespaDocumentUserFields( user_projects=project_ids, personas=persona_ids, @@ -831,7 +849,7 @@ def _project_sync_user_file_impl( ) task_logger.info( - f"_project_sync_user_file_impl - User file id={user_file_id}" + f"project_sync_user_file_impl - User file id={user_file_id}" ) user_file.needs_project_sync = False @@ -844,8 +862,9 @@ def _project_sync_user_file_impl( except Exception as e: task_logger.exception( - f"_project_sync_user_file_impl - Error syncing project for file id={user_file_id} - {e.__class__.__name__}" + f"project_sync_user_file_impl - Error syncing project for file id={user_file_id} - {e.__class__.__name__}" ) + raise finally: if file_lock is not None and file_lock.owned(): file_lock.release() @@ -859,6 +878,6 @@ def _project_sync_user_file_impl( def process_single_user_file_project_sync( self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001 ) -> None: - _project_sync_user_file_impl( + project_sync_user_file_impl( user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True ) diff --git a/backend/onyx/background/celery/versioned_apps/background.py b/backend/onyx/background/celery/versioned_apps/background.py deleted file mode 100644 index 2d060068958..00000000000 --- a/backend/onyx/background/celery/versioned_apps/background.py +++ /dev/null @@ -1,10 +0,0 @@ -from celery import Celery - -from onyx.utils.variable_functionality import fetch_versioned_implementation -from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable - -set_is_ee_based_on_env_variable() -app: Celery = fetch_versioned_implementation( - "onyx.background.celery.apps.background", - "celery_app", -) diff --git a/backend/onyx/background/indexing/run_docfetching.py b/backend/onyx/background/indexing/run_docfetching.py index 52dc393cbeb..975d444a2b7 100644 --- 
a/backend/onyx/background/indexing/run_docfetching.py +++ b/backend/onyx/background/indexing/run_docfetching.py @@ -45,6 +45,7 @@ from onyx.db.enums import IndexingStatus from onyx.db.enums import IndexModelStatus from onyx.db.enums import ProcessingMode +from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries from onyx.db.hierarchy import upsert_hierarchy_nodes_batch from onyx.db.index_attempt import create_index_attempt_error from onyx.db.index_attempt import get_index_attempt @@ -58,8 +59,6 @@ from onyx.file_store.document_batch_storage import get_document_batch_storage from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.indexing.indexing_pipeline import index_doc_batch_prepare -from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres -from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch from onyx.redis.redis_hierarchy import ensure_source_node_exists from onyx.redis.redis_hierarchy import get_node_id_from_raw_id @@ -71,6 +70,8 @@ ) from onyx.utils.logger import setup_logger from onyx.utils.middleware import make_randomized_onyx_request_id +from onyx.utils.postgres_sanitization import sanitize_document_for_postgres +from onyx.utils.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres from onyx.utils.variable_functionality import global_version from shared_configs.configs import MULTI_TENANT from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR @@ -587,6 +588,14 @@ def connector_document_extraction( is_connector_public=is_connector_public, ) + upsert_hierarchy_node_cc_pair_entries( + db_session=db_session, + hierarchy_node_ids=[n.id for n in upserted_nodes], + connector_id=db_connector.id, + credential_id=db_credential.id, + commit=True, + ) + # Cache in Redis for fast ancestor resolution during doc processing redis_client = get_redis_client(tenant_id=tenant_id) 
cache_entries = [ diff --git a/backend/onyx/background/periodic_poller.py b/backend/onyx/background/periodic_poller.py new file mode 100644 index 00000000000..59f0eadce4d --- /dev/null +++ b/backend/onyx/background/periodic_poller.py @@ -0,0 +1,307 @@ +"""Periodic poller for NO_VECTOR_DB deployments. + +Replaces Celery Beat and background workers with a lightweight daemon thread +that runs from the API server process. Two responsibilities: + +1. Recovery polling (every 30 s): re-processes user files stuck in + PROCESSING / DELETING / needs_sync states via the drain loops defined + in ``task_utils.py``. + +2. Periodic task execution (configurable intervals): runs LLM model updates + and scheduled evals at their configured cadences, with Postgres advisory + lock deduplication across multiple API server instances. +""" + +import threading +import time +from collections.abc import Callable +from dataclasses import dataclass +from dataclasses import field + +from onyx.utils.logger import setup_logger + +logger = setup_logger() + +RECOVERY_INTERVAL_SECONDS = 30 +PERIODIC_TASK_LOCK_BASE = 20_000 +PERIODIC_TASK_KV_PREFIX = "periodic_poller:last_claimed:" + + +# ------------------------------------------------------------------ +# Periodic task definitions +# ------------------------------------------------------------------ + + +_NEVER_RAN: float = -1e18 + + +@dataclass +class _PeriodicTaskDef: + name: str + interval_seconds: float + lock_id: int + run_fn: Callable[[], None] + last_run_at: float = field(default=_NEVER_RAN) + + +def _run_auto_llm_update() -> None: + from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL + + if not AUTO_LLM_CONFIG_URL: + return + + from onyx.db.engine.sql_engine import get_session_with_current_tenant + from onyx.llm.well_known_providers.auto_update_service import ( + sync_llm_models_from_github, + ) + + with get_session_with_current_tenant() as db_session: + sync_llm_models_from_github(db_session) + + +def _run_cache_cleanup() -> None: + 
from onyx.cache.postgres_backend import cleanup_expired_cache_entries + + cleanup_expired_cache_entries() + + +def _run_scheduled_eval() -> None: + from onyx.configs.app_configs import BRAINTRUST_API_KEY + from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES + from onyx.configs.app_configs import SCHEDULED_EVAL_PERMISSIONS_EMAIL + from onyx.configs.app_configs import SCHEDULED_EVAL_PROJECT + + if not all( + [ + BRAINTRUST_API_KEY, + SCHEDULED_EVAL_PROJECT, + SCHEDULED_EVAL_DATASET_NAMES, + SCHEDULED_EVAL_PERMISSIONS_EMAIL, + ] + ): + return + + from datetime import datetime + from datetime import timezone + + from onyx.evals.eval import run_eval + from onyx.evals.models import EvalConfigurationOptions + + run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d") + for dataset_name in SCHEDULED_EVAL_DATASET_NAMES: + try: + run_eval( + configuration=EvalConfigurationOptions( + search_permissions_email=SCHEDULED_EVAL_PERMISSIONS_EMAIL, + dataset_name=dataset_name, + no_send_logs=False, + braintrust_project=SCHEDULED_EVAL_PROJECT, + experiment_name=f"{dataset_name} - {run_timestamp}", + ), + remote_dataset_name=dataset_name, + ) + except Exception: + logger.exception( + f"Periodic poller - Failed scheduled eval for dataset {dataset_name}" + ) + + +_CACHE_CLEANUP_INTERVAL_SECONDS = 300 + + +def _build_periodic_tasks() -> list[_PeriodicTaskDef]: + from onyx.cache.interface import CacheBackendType + from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL + from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS + from onyx.configs.app_configs import CACHE_BACKEND + from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES + + tasks: list[_PeriodicTaskDef] = [] + if CACHE_BACKEND == CacheBackendType.POSTGRES: + tasks.append( + _PeriodicTaskDef( + name="cache-cleanup", + interval_seconds=_CACHE_CLEANUP_INTERVAL_SECONDS, + lock_id=PERIODIC_TASK_LOCK_BASE + 2, + run_fn=_run_cache_cleanup, + ) + ) + if AUTO_LLM_CONFIG_URL: + 
tasks.append( + _PeriodicTaskDef( + name="auto-llm-update", + interval_seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS, + lock_id=PERIODIC_TASK_LOCK_BASE, + run_fn=_run_auto_llm_update, + ) + ) + if SCHEDULED_EVAL_DATASET_NAMES: + tasks.append( + _PeriodicTaskDef( + name="scheduled-eval", + interval_seconds=7 * 24 * 3600, + lock_id=PERIODIC_TASK_LOCK_BASE + 1, + run_fn=_run_scheduled_eval, + ) + ) + return tasks + + +# ------------------------------------------------------------------ +# Periodic task runner with advisory-lock-guarded claim +# ------------------------------------------------------------------ + + +def _try_claim_task(task_def: _PeriodicTaskDef) -> bool: + """Atomically check whether *task_def* should run and record a claim. + + Uses a transaction-scoped advisory lock for atomicity combined with a + ``KVStore`` timestamp for cross-instance dedup. The DB session is held + only for this brief claim transaction, not during task execution. + """ + from datetime import datetime + from datetime import timezone + + from sqlalchemy import text + + from onyx.db.engine.sql_engine import get_session_with_current_tenant + from onyx.db.models import KVStore + + kv_key = PERIODIC_TASK_KV_PREFIX + task_def.name + + with get_session_with_current_tenant() as db_session: + acquired = db_session.execute( + text("SELECT pg_try_advisory_xact_lock(:id)"), + {"id": task_def.lock_id}, + ).scalar() + if not acquired: + return False + + row = db_session.query(KVStore).filter_by(key=kv_key).first() + if row and row.value is not None: + last_claimed = datetime.fromisoformat(str(row.value)) + elapsed = (datetime.now(timezone.utc) - last_claimed).total_seconds() + if elapsed < task_def.interval_seconds: + return False + + now_ts = datetime.now(timezone.utc).isoformat() + if row: + row.value = now_ts + else: + db_session.add(KVStore(key=kv_key, value=now_ts)) + db_session.commit() + + return True + + +def _try_run_periodic_task(task_def: _PeriodicTaskDef) -> None: + """Run *task_def* if 
its interval has elapsed and no peer holds the lock.""" + now = time.monotonic() + if now - task_def.last_run_at < task_def.interval_seconds: + return + + if not _try_claim_task(task_def): + return + + try: + task_def.run_fn() + task_def.last_run_at = now + except Exception: + logger.exception( + f"Periodic poller - Error running periodic task {task_def.name}" + ) + + +# ------------------------------------------------------------------ +# Recovery / drain loop runner +# ------------------------------------------------------------------ + + +def _run_drain_loops(tenant_id: str) -> None: + from onyx.background.task_utils import drain_delete_loop + from onyx.background.task_utils import drain_processing_loop + from onyx.background.task_utils import drain_project_sync_loop + + drain_processing_loop(tenant_id) + drain_delete_loop(tenant_id) + drain_project_sync_loop(tenant_id) + + +# ------------------------------------------------------------------ +# Startup recovery (10g) +# ------------------------------------------------------------------ + + +def recover_stuck_user_files(tenant_id: str) -> None: + """Run all drain loops once to re-process files left in intermediate states. + + Called from ``lifespan()`` on startup when ``DISABLE_VECTOR_DB`` is set. 
+    """
+    logger.info("recover_stuck_user_files - Checking for stuck user files")
+    try:
+        _run_drain_loops(tenant_id)
+    except Exception:
+        logger.exception("recover_stuck_user_files - Error during recovery")
+
+
+# ------------------------------------------------------------------
+# Daemon thread (10f)
+# ------------------------------------------------------------------
+
+_shutdown_event = threading.Event()
+_poller_thread: threading.Thread | None = None
+
+
+def _poller_loop(tenant_id: str) -> None:
+    from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
+
+    CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
+
+    periodic_tasks = _build_periodic_tasks()
+    logger.info(
+        f"Periodic poller started with {len(periodic_tasks)} periodic task(s): "
+        f"{[t.name for t in periodic_tasks]}"
+    )
+
+    while not _shutdown_event.is_set():
+        try:
+            _run_drain_loops(tenant_id)
+        except Exception:
+            logger.exception("Periodic poller - Error in recovery polling")
+
+        for task_def in periodic_tasks:
+            try:
+                _try_run_periodic_task(task_def)
+            except Exception:
+                logger.exception(
+                    f"Periodic poller - Unhandled error checking task {task_def.name}"
+                )
+
+        _shutdown_event.wait(RECOVERY_INTERVAL_SECONDS)
+
+
+def start_periodic_poller(tenant_id: str) -> None:
+    """Start the periodic poller daemon thread."""
+    global _poller_thread  # noqa: PLW0603
+    _shutdown_event.clear()
+    _poller_thread = threading.Thread(
+        target=_poller_loop,
+        args=(tenant_id,),
+        daemon=True,
+        name="no-vectordb-periodic-poller",
+    )
+    _poller_thread.start()
+    logger.info("Periodic poller thread started")
+
+
+def stop_periodic_poller() -> None:
+    """Signal the periodic poller to stop and wait for it to exit."""
+    global _poller_thread  # noqa: PLW0603
+    if _poller_thread is None:
+        return
+    _shutdown_event.set()
+    _poller_thread.join(timeout=10)
+    if _poller_thread.is_alive():
+        logger.warning("Periodic poller thread did not stop within timeout")
+    _poller_thread = None
+    logger.info("Periodic poller thread stopped")
diff --git a/backend/onyx/background/task_utils.py b/backend/onyx/background/task_utils.py
index a49a56745e6..191f0486140 100644
--- a/backend/onyx/background/task_utils.py
+++ b/backend/onyx/background/task_utils.py
@@ -1,3 +1,33 @@
+"""Background task utilities.
+
+Contains query-history report helpers (used by all deployment modes) and
+in-process background task execution helpers for NO_VECTOR_DB mode:
+
+- Atomic claim-and-mark helpers that prevent duplicate processing
+- Drain loops that process all pending user file work
+
+Each claim function runs a short-lived transaction: SELECT ... FOR UPDATE
+SKIP LOCKED, UPDATE the row to remove it from future queries, COMMIT.
+After the commit the row lock is released, but the row is no longer
+eligible for re-claiming. No long-lived sessions or advisory locks.
+"""
+
+from uuid import UUID
+
+import sqlalchemy as sa
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from onyx.db.enums import UserFileStatus
+from onyx.db.models import UserFile
+from onyx.utils.logger import setup_logger
+
+logger = setup_logger()
+
+# ------------------------------------------------------------------
+# Query-history report helpers (pre-existing, used by all modes)
+# ------------------------------------------------------------------
+
 QUERY_REPORT_NAME_PREFIX = "query-history"
@@ -9,3 +39,168 @@ def construct_query_history_report_name(
 def extract_task_id_from_query_history_report_name(name: str) -> str:
     return name.removeprefix(f"{QUERY_REPORT_NAME_PREFIX}-").removesuffix(".csv")
+
+
+# ------------------------------------------------------------------
+# Atomic claim-and-mark helpers
+# ------------------------------------------------------------------
+# Each function runs inside a single short-lived session/transaction:
+#   1. SELECT ... FOR UPDATE SKIP LOCKED (locks one eligible row)
+#   2. UPDATE the row so it is no longer eligible
+#   3. COMMIT (releases the row lock)
+# After the commit, no other drain loop can claim the same row.
+
+
+def _claim_next_processing_file(db_session: Session) -> UUID | None:
+    """Claim the next PROCESSING file by transitioning it to INDEXING.
+
+    Returns the file id, or None when no eligible files remain.
+    """
+    file_id = db_session.execute(
+        select(UserFile.id)
+        .where(UserFile.status == UserFileStatus.PROCESSING)
+        .order_by(UserFile.created_at)
+        .limit(1)
+        .with_for_update(skip_locked=True)
+    ).scalar_one_or_none()
+    if file_id is None:
+        return None
+
+    db_session.execute(
+        sa.update(UserFile)
+        .where(UserFile.id == file_id)
+        .values(status=UserFileStatus.INDEXING)
+    )
+    db_session.commit()
+    return file_id
+
+
+def _claim_next_deleting_file(
+    db_session: Session,
+    exclude_ids: set[UUID] | None = None,
+) -> UUID | None:
+    """Claim the next DELETING file.
+
+    No status transition needed — the impl deletes the row on success.
+    The short-lived FOR UPDATE lock prevents concurrent claims.
+    *exclude_ids* prevents re-processing the same file if the impl fails.
+    """
+    stmt = (
+        select(UserFile.id)
+        .where(UserFile.status == UserFileStatus.DELETING)
+        .order_by(UserFile.created_at)
+        .limit(1)
+        .with_for_update(skip_locked=True)
+    )
+    if exclude_ids:
+        stmt = stmt.where(UserFile.id.notin_(exclude_ids))
+    file_id = db_session.execute(stmt).scalar_one_or_none()
+    db_session.commit()
+    return file_id
+
+
+def _claim_next_sync_file(
+    db_session: Session,
+    exclude_ids: set[UUID] | None = None,
+) -> UUID | None:
+    """Claim the next file needing project/persona sync.
+
+    No status transition needed — the impl clears the sync flags on
+    success. The short-lived FOR UPDATE lock prevents concurrent claims.
+    *exclude_ids* prevents re-processing the same file if the impl fails.
+    """
+    stmt = (
+        select(UserFile.id)
+        .where(
+            sa.and_(
+                sa.or_(
+                    UserFile.needs_project_sync.is_(True),
+                    UserFile.needs_persona_sync.is_(True),
+                ),
+                UserFile.status == UserFileStatus.COMPLETED,
+            )
+        )
+        .order_by(UserFile.created_at)
+        .limit(1)
+        .with_for_update(skip_locked=True)
+    )
+    if exclude_ids:
+        stmt = stmt.where(UserFile.id.notin_(exclude_ids))
+    file_id = db_session.execute(stmt).scalar_one_or_none()
+    db_session.commit()
+    return file_id
+
+
+# ------------------------------------------------------------------
+# Drain loops — process *all* pending work of each type
+# ------------------------------------------------------------------
+
+
+def drain_processing_loop(tenant_id: str) -> None:
+    """Process all pending PROCESSING user files."""
+    from onyx.background.celery.tasks.user_file_processing.tasks import (
+        process_user_file_impl,
+    )
+    from onyx.db.engine.sql_engine import get_session_with_current_tenant
+
+    while True:
+        with get_session_with_current_tenant() as session:
+            file_id = _claim_next_processing_file(session)
+        if file_id is None:
+            break
+        try:
+            process_user_file_impl(
+                user_file_id=str(file_id),
+                tenant_id=tenant_id,
+                redis_locking=False,
+            )
+        except Exception:
+            logger.exception(f"Failed to process user file {file_id}")
+
+
+def drain_delete_loop(tenant_id: str) -> None:
+    """Delete all pending DELETING user files."""
+    from onyx.background.celery.tasks.user_file_processing.tasks import (
+        delete_user_file_impl,
+    )
+    from onyx.db.engine.sql_engine import get_session_with_current_tenant
+
+    failed: set[UUID] = set()
+    while True:
+        with get_session_with_current_tenant() as session:
+            file_id = _claim_next_deleting_file(session, exclude_ids=failed)
+        if file_id is None:
+            break
+        try:
+            delete_user_file_impl(
+                user_file_id=str(file_id),
+                tenant_id=tenant_id,
+                redis_locking=False,
+            )
+        except Exception:
+            logger.exception(f"Failed to delete user file {file_id}")
+            failed.add(file_id)
+
+
+def drain_project_sync_loop(tenant_id: str) -> None:
+    """Sync all pending project/persona metadata for user files."""
+    from onyx.background.celery.tasks.user_file_processing.tasks import (
+        project_sync_user_file_impl,
+    )
+    from onyx.db.engine.sql_engine import get_session_with_current_tenant
+
+    failed: set[UUID] = set()
+    while True:
+        with get_session_with_current_tenant() as session:
+            file_id = _claim_next_sync_file(session, exclude_ids=failed)
+        if file_id is None:
+            break
+        try:
+            project_sync_user_file_impl(
+                user_file_id=str(file_id),
+                tenant_id=tenant_id,
+                redis_locking=False,
+            )
+        except Exception:
+            logger.exception(f"Failed to sync user file {file_id}")
+            failed.add(file_id)
diff --git a/backend/onyx/cache/factory.py b/backend/onyx/cache/factory.py
new file mode 100644
index 00000000000..6b7c6694f50
--- /dev/null
+++ b/backend/onyx/cache/factory.py
@@ -0,0 +1,51 @@
+from collections.abc import Callable
+
+from onyx.cache.interface import CacheBackend
+from onyx.cache.interface import CacheBackendType
+from onyx.configs.app_configs import CACHE_BACKEND
+
+
+def _build_redis_backend(tenant_id: str) -> CacheBackend:
+    from onyx.cache.redis_backend import RedisCacheBackend
+    from onyx.redis.redis_pool import redis_pool
+
+    return RedisCacheBackend(redis_pool.get_client(tenant_id))
+
+
+def _build_postgres_backend(tenant_id: str) -> CacheBackend:
+    from onyx.cache.postgres_backend import PostgresCacheBackend
+
+    return PostgresCacheBackend(tenant_id)
+
+
+_BACKEND_BUILDERS: dict[CacheBackendType, Callable[[str], CacheBackend]] = {
+    CacheBackendType.REDIS: _build_redis_backend,
+    CacheBackendType.POSTGRES: _build_postgres_backend,
+}
+
+
+def get_cache_backend(*, tenant_id: str | None = None) -> CacheBackend:
+    """Return a tenant-aware ``CacheBackend``.
+
+    If *tenant_id* is ``None``, the current tenant is read from the
+    thread-local context variable (same behaviour as ``get_redis_client``).
+    """
+    if tenant_id is None:
+        from shared_configs.contextvars import get_current_tenant_id
+
+        tenant_id = get_current_tenant_id()
+
+    builder = _BACKEND_BUILDERS.get(CACHE_BACKEND)
+    if builder is None:
+        raise ValueError(
+            f"Unsupported CACHE_BACKEND={CACHE_BACKEND!r}. "
+            f"Supported values: {[t.value for t in CacheBackendType]}"
+        )
+    return builder(tenant_id)
+
+
+def get_shared_cache_backend() -> CacheBackend:
+    """Return a ``CacheBackend`` in the shared (cross-tenant) namespace."""
+    from shared_configs.configs import DEFAULT_REDIS_PREFIX
+
+    return get_cache_backend(tenant_id=DEFAULT_REDIS_PREFIX)
diff --git a/backend/onyx/cache/interface.py b/backend/onyx/cache/interface.py
new file mode 100644
index 00000000000..810d4cece9e
--- /dev/null
+++ b/backend/onyx/cache/interface.py
@@ -0,0 +1,115 @@
+import abc
+from enum import Enum
+
+from redis.exceptions import RedisError
+from sqlalchemy.exc import SQLAlchemyError
+
+TTL_KEY_NOT_FOUND = -2
+TTL_NO_EXPIRY = -1
+
+CACHE_TRANSIENT_ERRORS: tuple[type[Exception], ...] = (RedisError, SQLAlchemyError)
+"""Exception types that represent transient cache connectivity / operational
+failures. Callers that want to fail-open (or fail-closed) on cache errors
+should catch this tuple instead of bare ``Exception``.
+
+When adding a new ``CacheBackend`` implementation, add its transient error
+base class(es) here so all call-sites pick it up automatically."""
+
+
+class CacheBackendType(str, Enum):
+    REDIS = "redis"
+    POSTGRES = "postgres"
+
+
+class CacheLock(abc.ABC):
+    """Abstract distributed lock returned by CacheBackend.lock()."""
+
+    @abc.abstractmethod
+    def acquire(
+        self,
+        blocking: bool = True,
+        blocking_timeout: float | None = None,
+    ) -> bool:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def release(self) -> None:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def owned(self) -> bool:
+        raise NotImplementedError
+
+    def __enter__(self) -> "CacheLock":
+        if not self.acquire():
+            raise RuntimeError("Failed to acquire lock")
+        return self
+
+    def __exit__(self, *args: object) -> None:
+        self.release()
+
+
+class CacheBackend(abc.ABC):
+    """Thin abstraction over a key-value cache with TTL, locks, and blocking lists.
+
+    Covers the subset of Redis operations used outside of Celery. When
+    CACHE_BACKEND=postgres, a PostgreSQL-backed implementation is used instead.
+    """
+
+    # -- basic key/value ---------------------------------------------------
+
+    @abc.abstractmethod
+    def get(self, key: str) -> bytes | None:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def set(
+        self,
+        key: str,
+        value: str | bytes | int | float,
+        ex: int | None = None,
+    ) -> None:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def delete(self, key: str) -> None:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def exists(self, key: str) -> bool:
+        raise NotImplementedError
+
+    # -- TTL ---------------------------------------------------------------
+
+    @abc.abstractmethod
+    def expire(self, key: str, seconds: int) -> None:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def ttl(self, key: str) -> int:
+        """Return remaining TTL in seconds.
+
+        Returns ``TTL_NO_EXPIRY`` (-1) if key exists without expiry,
+        ``TTL_KEY_NOT_FOUND`` (-2) if key is missing or expired.
+        """
+        raise NotImplementedError
+
+    # -- distributed lock --------------------------------------------------
+
+    @abc.abstractmethod
+    def lock(self, name: str, timeout: float | None = None) -> CacheLock:
+        raise NotImplementedError
+
+    # -- blocking list (used by MCP OAuth BLPOP pattern) -------------------
+
+    @abc.abstractmethod
+    def rpush(self, key: str, value: str | bytes) -> None:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
+        """Block until a value is available on one of *keys*, or *timeout* expires.
+
+        Returns ``(key, value)`` or ``None`` on timeout.
+        """
+        raise NotImplementedError
diff --git a/backend/onyx/cache/postgres_backend.py b/backend/onyx/cache/postgres_backend.py
new file mode 100644
index 00000000000..03219761975
--- /dev/null
+++ b/backend/onyx/cache/postgres_backend.py
@@ -0,0 +1,323 @@
+"""PostgreSQL-backed ``CacheBackend`` for NO_VECTOR_DB deployments.
+
+Uses the ``cache_store`` table for key-value storage, PostgreSQL advisory locks
+for distributed locking, and a polling loop for the BLPOP pattern.
+"""
+
+import hashlib
+import struct
+import time
+import uuid
+from contextlib import AbstractContextManager
+from datetime import datetime
+from datetime import timedelta
+from datetime import timezone
+
+from sqlalchemy import delete
+from sqlalchemy import func
+from sqlalchemy import or_
+from sqlalchemy import select
+from sqlalchemy import update
+from sqlalchemy.dialects.postgresql import insert as pg_insert
+from sqlalchemy.orm import Session
+
+from onyx.cache.interface import CacheBackend
+from onyx.cache.interface import CacheLock
+from onyx.cache.interface import TTL_KEY_NOT_FOUND
+from onyx.cache.interface import TTL_NO_EXPIRY
+from onyx.db.models import CacheStore
+
+_LIST_KEY_PREFIX = "_q:"
+# ASCII: ':' (0x3A) < ';' (0x3B). Upper bound for range queries so [prefix + ":", prefix + ";")
+# captures all list-item keys (e.g. _q:mylist:123:uuid) without including other
+# lists whose names share a prefix (e.g. _q:mylist2:...).
+_LIST_KEY_RANGE_TERMINATOR = ";"
+_LIST_ITEM_TTL_SECONDS = 3600
+_LOCK_POLL_INTERVAL = 0.1
+_BLPOP_POLL_INTERVAL = 0.25
+
+
+def _list_item_key(key: str) -> str:
+    """Unique key for a list item. Timestamp for FIFO ordering; UUID prevents
+    collision when concurrent rpush calls occur within the same nanosecond.
+    """
+    return f"{_LIST_KEY_PREFIX}{key}:{time.time_ns()}:{uuid.uuid4().hex}"
+
+
+def _to_bytes(value: str | bytes | int | float) -> bytes:
+    if isinstance(value, bytes):
+        return value
+    return str(value).encode()
+
+
+# ------------------------------------------------------------------
+# Lock
+# ------------------------------------------------------------------
+
+
+class PostgresCacheLock(CacheLock):
+    """Advisory-lock-based distributed lock.
+
+    Uses ``get_session_with_tenant`` for connection lifecycle. The lock is tied
+    to the session's connection; releasing or closing the session frees it.
+
+    NOTE: Unlike Redis locks, advisory locks do not auto-expire after
+    ``timeout`` seconds. They are released when ``release()`` is
+    called or when the session is closed.
+    """
+
+    def __init__(self, lock_id: int, timeout: float | None, tenant_id: str) -> None:
+        self._lock_id = lock_id
+        self._timeout = timeout
+        self._tenant_id = tenant_id
+        self._session_cm: AbstractContextManager[Session] | None = None
+        self._session: Session | None = None
+        self._acquired = False
+
+    def acquire(
+        self,
+        blocking: bool = True,
+        blocking_timeout: float | None = None,
+    ) -> bool:
+        from onyx.db.engine.sql_engine import get_session_with_tenant
+
+        self._session_cm = get_session_with_tenant(tenant_id=self._tenant_id)
+        self._session = self._session_cm.__enter__()
+        try:
+            if not blocking:
+                return self._try_lock()
+
+            effective_timeout = blocking_timeout or self._timeout
+            deadline = (
+                (time.monotonic() + effective_timeout) if effective_timeout else None
+            )
+            while True:
+                if self._try_lock():
+                    return True
+                if deadline is not None and time.monotonic() >= deadline:
+                    return False
+                time.sleep(_LOCK_POLL_INTERVAL)
+        finally:
+            if not self._acquired:
+                self._close_session()
+
+    def release(self) -> None:
+        if not self._acquired or self._session is None:
+            return
+        try:
+            self._session.execute(select(func.pg_advisory_unlock(self._lock_id)))
+        finally:
+            self._acquired = False
+            self._close_session()
+
+    def owned(self) -> bool:
+        return self._acquired
+
+    def _close_session(self) -> None:
+        if self._session_cm is not None:
+            try:
+                self._session_cm.__exit__(None, None, None)
+            finally:
+                self._session_cm = None
+                self._session = None
+
+    def _try_lock(self) -> bool:
+        assert self._session is not None
+        result = self._session.execute(
+            select(func.pg_try_advisory_lock(self._lock_id))
+        ).scalar()
+        if result:
+            self._acquired = True
+            return True
+        return False
+
+
+# ------------------------------------------------------------------
+# Backend
+# ------------------------------------------------------------------
+
+
+class PostgresCacheBackend(CacheBackend):
+    """``CacheBackend`` backed by the ``cache_store`` table in PostgreSQL.
+
+    Each operation opens and closes its own database session so the backend
+    is safe to share across threads. Tenant isolation is handled by
+    SQLAlchemy's ``schema_translate_map`` (set by ``get_session_with_tenant``).
+    """
+
+    def __init__(self, tenant_id: str) -> None:
+        self._tenant_id = tenant_id
+
+    # -- basic key/value ---------------------------------------------------
+
+    def get(self, key: str) -> bytes | None:
+        from onyx.db.engine.sql_engine import get_session_with_tenant
+
+        stmt = select(CacheStore.value).where(
+            CacheStore.key == key,
+            or_(CacheStore.expires_at.is_(None), CacheStore.expires_at > func.now()),
+        )
+        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
+            value = session.execute(stmt).scalar_one_or_none()
+        if value is None:
+            return None
+        return bytes(value)
+
+    def set(
+        self,
+        key: str,
+        value: str | bytes | int | float,
+        ex: int | None = None,
+    ) -> None:
+        from onyx.db.engine.sql_engine import get_session_with_tenant
+
+        value_bytes = _to_bytes(value)
+        expires_at = (
+            datetime.now(timezone.utc) + timedelta(seconds=ex)
+            if ex is not None
+            else None
+        )
+        stmt = (
+            pg_insert(CacheStore)
+            .values(key=key, value=value_bytes, expires_at=expires_at)
+            .on_conflict_do_update(
+                index_elements=[CacheStore.key],
+                set_={"value": value_bytes, "expires_at": expires_at},
+            )
+        )
+        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
+            session.execute(stmt)
+            session.commit()
+
+    def delete(self, key: str) -> None:
+        from onyx.db.engine.sql_engine import get_session_with_tenant
+
+        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
+            session.execute(delete(CacheStore).where(CacheStore.key == key))
+            session.commit()
+
+    def exists(self, key: str) -> bool:
+        from onyx.db.engine.sql_engine import get_session_with_tenant
+
+        stmt = (
+            select(CacheStore.key)
+            .where(
+                CacheStore.key == key,
+                or_(
+                    CacheStore.expires_at.is_(None),
+                    CacheStore.expires_at > func.now(),
+                ),
+            )
+            .limit(1)
+        )
+        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
+            return session.execute(stmt).first() is not None
+
+    # -- TTL ---------------------------------------------------------------
+
+    def expire(self, key: str, seconds: int) -> None:
+        from onyx.db.engine.sql_engine import get_session_with_tenant
+
+        new_exp = datetime.now(timezone.utc) + timedelta(seconds=seconds)
+        stmt = (
+            update(CacheStore).where(CacheStore.key == key).values(expires_at=new_exp)
+        )
+        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
+            session.execute(stmt)
+            session.commit()
+
+    def ttl(self, key: str) -> int:
+        from onyx.db.engine.sql_engine import get_session_with_tenant
+
+        stmt = select(CacheStore.expires_at).where(CacheStore.key == key)
+        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
+            result = session.execute(stmt).first()
+        if result is None:
+            return TTL_KEY_NOT_FOUND
+        expires_at: datetime | None = result[0]
+        if expires_at is None:
+            return TTL_NO_EXPIRY
+        remaining = (expires_at - datetime.now(timezone.utc)).total_seconds()
+        if remaining <= 0:
+            return TTL_KEY_NOT_FOUND
+        return int(remaining)
+
+    # -- distributed lock --------------------------------------------------
+
+    def lock(self, name: str, timeout: float | None = None) -> CacheLock:
+        return PostgresCacheLock(
+            self._lock_id_for(name), timeout, tenant_id=self._tenant_id
+        )
+
+    # -- blocking list (MCP OAuth BLPOP pattern) ---------------------------
+
+    def rpush(self, key: str, value: str | bytes) -> None:
+        self.set(_list_item_key(key), value, ex=_LIST_ITEM_TTL_SECONDS)
+
+    def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
+        if timeout <= 0:
+            raise ValueError(
+                "PostgresCacheBackend.blpop requires timeout > 0. "
+                "timeout=0 would block the calling thread indefinitely "
+                "with no way to interrupt short of process termination."
+            )
+        from onyx.db.engine.sql_engine import get_session_with_tenant
+
+        deadline = time.monotonic() + timeout
+        while True:
+            for key in keys:
+                lower = f"{_LIST_KEY_PREFIX}{key}:"
+                upper = f"{_LIST_KEY_PREFIX}{key}{_LIST_KEY_RANGE_TERMINATOR}"
+                stmt = (
+                    select(CacheStore)
+                    .where(
+                        CacheStore.key >= lower,
+                        CacheStore.key < upper,
+                        or_(
+                            CacheStore.expires_at.is_(None),
+                            CacheStore.expires_at > func.now(),
+                        ),
+                    )
+                    .order_by(CacheStore.key)
+                    .limit(1)
+                    .with_for_update(skip_locked=True)
+                )
+                with get_session_with_tenant(tenant_id=self._tenant_id) as session:
+                    row = session.execute(stmt).scalars().first()
+                    if row is not None:
+                        value = bytes(row.value) if row.value else b""
+                        session.delete(row)
+                        session.commit()
+                        return (key.encode(), value)
+            if time.monotonic() >= deadline:
+                return None
+            time.sleep(_BLPOP_POLL_INTERVAL)
+
+    # -- helpers -----------------------------------------------------------
+
+    def _lock_id_for(self, name: str) -> int:
+        """Map *name* to a 64-bit signed int for ``pg_advisory_lock``."""
+        h = hashlib.md5(f"{self._tenant_id}:{name}".encode()).digest()
+        return struct.unpack("q", h[:8])[0]
+
+
+# ------------------------------------------------------------------
+# Periodic cleanup
+# ------------------------------------------------------------------
+
+
+def cleanup_expired_cache_entries() -> None:
+    """Delete rows whose ``expires_at`` is in the past.
+
+    Called by the periodic poller every 5 minutes.
+    """
+    from onyx.db.engine.sql_engine import get_session_with_current_tenant
+
+    with get_session_with_current_tenant() as session:
+        session.execute(
+            delete(CacheStore).where(
+                CacheStore.expires_at.is_not(None),
+                CacheStore.expires_at < func.now(),
+            )
+        )
+        session.commit()
diff --git a/backend/onyx/cache/redis_backend.py b/backend/onyx/cache/redis_backend.py
new file mode 100644
index 00000000000..6730307aee6
--- /dev/null
+++ b/backend/onyx/cache/redis_backend.py
@@ -0,0 +1,92 @@
+from typing import cast
+
+from redis.client import Redis
+from redis.lock import Lock as RedisLock
+
+from onyx.cache.interface import CacheBackend
+from onyx.cache.interface import CacheLock
+
+
+class RedisCacheLock(CacheLock):
+    """Wraps ``redis.lock.Lock`` behind the ``CacheLock`` interface."""
+
+    def __init__(self, lock: RedisLock) -> None:
+        self._lock = lock
+
+    def acquire(
+        self,
+        blocking: bool = True,
+        blocking_timeout: float | None = None,
+    ) -> bool:
+        return bool(
+            self._lock.acquire(
+                blocking=blocking,
+                blocking_timeout=blocking_timeout,
+            )
+        )
+
+    def release(self) -> None:
+        self._lock.release()
+
+    def owned(self) -> bool:
+        return bool(self._lock.owned())
+
+
+class RedisCacheBackend(CacheBackend):
+    """``CacheBackend`` implementation that delegates to a ``redis.Redis`` client.
+
+    This is a thin pass-through — every method maps 1-to-1 to the underlying
+    Redis command. ``TenantRedis`` key-prefixing is handled by the client
+    itself (provided by ``get_redis_client``).
+    """
+
+    def __init__(self, redis_client: Redis) -> None:
+        self._r = redis_client
+
+    # -- basic key/value ---------------------------------------------------
+
+    def get(self, key: str) -> bytes | None:
+        val = self._r.get(key)
+        if val is None:
+            return None
+        if isinstance(val, bytes):
+            return val
+        return str(val).encode()
+
+    def set(
+        self,
+        key: str,
+        value: str | bytes | int | float,
+        ex: int | None = None,
+    ) -> None:
+        self._r.set(key, value, ex=ex)
+
+    def delete(self, key: str) -> None:
+        self._r.delete(key)
+
+    def exists(self, key: str) -> bool:
+        return bool(self._r.exists(key))
+
+    # -- TTL ---------------------------------------------------------------
+
+    def expire(self, key: str, seconds: int) -> None:
+        self._r.expire(key, seconds)
+
+    def ttl(self, key: str) -> int:
+        return cast(int, self._r.ttl(key))
+
+    # -- distributed lock --------------------------------------------------
+
+    def lock(self, name: str, timeout: float | None = None) -> CacheLock:
+        return RedisCacheLock(self._r.lock(name, timeout=timeout))
+
+    # -- blocking list (MCP OAuth BLPOP pattern) ---------------------------
+
+    def rpush(self, key: str, value: str | bytes) -> None:
+        self._r.rpush(key, value)
+
+    def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
+        result = cast(list[bytes] | None, self._r.blpop(keys, timeout=timeout))
+        if result is None:
+            return None
+        return (result[0], result[1])
diff --git a/backend/onyx/chat/chat_processing_checker.py b/backend/onyx/chat/chat_processing_checker.py
index f7b6923ce8f..41dc9a207f8 100644
--- a/backend/onyx/chat/chat_processing_checker.py
+++ b/backend/onyx/chat/chat_processing_checker.py
@@ -1,57 +1,52 @@
 from uuid import UUID
 
-from redis.client import Redis
+from onyx.cache.interface import CacheBackend
 
-# Redis key prefixes for chat message processing
 PREFIX = "chatprocessing"
 FENCE_PREFIX = f"{PREFIX}_fence"
 FENCE_TTL = 30 * 60  # 30 minutes
 
 
 def _get_fence_key(chat_session_id: UUID) -> str:
-    """
-    Generate the Redis key for a chat session processing a message.
+    """Generate the cache key for a chat session processing fence.
 
     Args:
         chat_session_id: The UUID of the chat session
 
     Returns:
-        The fence key string (tenant_id is automatically added by the Redis client)
+        The fence key string. Tenant isolation is handled automatically
+        by the cache backend (Redis key-prefixing or Postgres schema routing).
     """
     return f"{FENCE_PREFIX}_{chat_session_id}"
 
 
 def set_processing_status(
-    chat_session_id: UUID, redis_client: Redis, value: bool
+    chat_session_id: UUID, cache: CacheBackend, value: bool
 ) -> None:
-    """
-    Set or clear the fence for a chat session processing a message.
+    """Set or clear the fence for a chat session processing a message.
 
-    If the key exists, we are processing a message. If the key does not exist, we are not processing a message.
+    If the key exists, a message is being processed.
 
     Args:
         chat_session_id: The UUID of the chat session
-        redis_client: The Redis client to use
+        cache: Tenant-aware cache backend
         value: True to set the fence, False to clear it
     """
     fence_key = _get_fence_key(chat_session_id)
-
     if value:
-        redis_client.set(fence_key, 0, ex=FENCE_TTL)
+        cache.set(fence_key, 0, ex=FENCE_TTL)
     else:
-        redis_client.delete(fence_key)
+        cache.delete(fence_key)
 
 
-def is_chat_session_processing(chat_session_id: UUID, redis_client: Redis) -> bool:
-    """
-    Check if the chat session is processing a message.
+def is_chat_session_processing(chat_session_id: UUID, cache: CacheBackend) -> bool:
+    """Check if the chat session is processing a message.
 
     Args:
         chat_session_id: The UUID of the chat session
-        redis_client: The Redis client to use
+        cache: Tenant-aware cache backend
 
     Returns:
         True if the chat session is processing a message, False otherwise
     """
-    fence_key = _get_fence_key(chat_session_id)
-    return bool(redis_client.exists(fence_key))
+    return cache.exists(_get_fence_key(chat_session_id))
diff --git a/backend/onyx/chat/llm_loop.py b/backend/onyx/chat/llm_loop.py
index bc802ce95bd..60dcb9de523 100644
--- a/backend/onyx/chat/llm_loop.py
+++ b/backend/onyx/chat/llm_loop.py
@@ -1,6 +1,7 @@
 import json
 import time
 from collections.abc import Callable
+from typing import Any
 from typing import Literal
 
 from sqlalchemy.orm import Session
@@ -35,7 +36,6 @@
 from onyx.db.memory import update_memory_at_index
 from onyx.db.memory import UserMemoryContext
 from onyx.db.models import Persona
-from onyx.llm.constants import LlmProviderNames
 from onyx.llm.interfaces import LLM
 from onyx.llm.interfaces import LLMUserIdentity
 from onyx.llm.interfaces import ToolChoiceOptions
@@ -50,7 +50,9 @@
 from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
 from onyx.tools.interface import Tool
 from onyx.tools.models import ChatFile
+from onyx.tools.models import CustomToolCallSummary
 from onyx.tools.models import MemoryToolResponseSnapshot
+from onyx.tools.models import PythonToolRichResponse
 from onyx.tools.models import ToolCallInfo
 from onyx.tools.models import ToolCallKickoff
 from onyx.tools.models import ToolResponse
@@ -82,28 +84,6 @@ def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
     )
 
 
-def _should_keep_bedrock_tool_definitions(
-    llm: object, simple_chat_history: list[ChatMessageSimple]
-) -> bool:
-    """Bedrock requires tool config when history includes toolUse/toolResult blocks."""
-    model_provider = getattr(getattr(llm, "config", None), "model_provider", None)
-    if model_provider not in {
-        LlmProviderNames.BEDROCK,
-        LlmProviderNames.BEDROCK_CONVERSE,
-    }:
-        return False
-
-    return any(
-        (
-            msg.message_type == MessageType.ASSISTANT
-            and msg.tool_calls
-            and len(msg.tool_calls) > 0
-        )
-        or msg.message_type == MessageType.TOOL_CALL_RESPONSE
-        for msg in simple_chat_history
-    )
-
-
 def _try_fallback_tool_extraction(
     llm_step_result: LlmStepResult,
     tool_choice: ToolChoiceOptions,
@@ -530,11 +510,13 @@ def _create_file_tool_metadata_message(
     """
     lines = [
         "You have access to the following files. Use the read_file tool to "
-        "read sections of any file:"
+        "read sections of any file. You MUST pass the file_id UUID (not the "
+        "filename) to read_file:"
     ]
     for meta in file_metadata:
         lines.append(
-            f'- {meta.file_id}: "{meta.filename}" (~{meta.approx_char_count:,} chars)'
+            f'- file_id="{meta.file_id}" filename="{meta.filename}" '
+            f"(~{meta.approx_char_count:,} chars)"
         )
@@ -558,12 +540,16 @@ def _create_context_files_message(
     # Format as documents JSON as described in README
     documents_list = []
     for idx, file_text in enumerate(context_files.file_texts, start=1):
-        documents_list.append(
-            {
-                "document": idx,
-                "contents": file_text,
-            }
+        title = (
+            context_files.file_metadata[idx - 1].filename
+            if idx - 1 < len(context_files.file_metadata)
+            else None
         )
+        entry: dict[str, Any] = {"document": idx}
+        if title:
+            entry["title"] = title
+        entry["contents"] = file_text
+        documents_list.append(entry)
     documents_json = json.dumps({"documents": documents_list}, indent=2)
     message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}"
@@ -678,12 +664,7 @@ def run_llm_loop(
     elif out_of_cycles or ran_image_gen:
         # Last cycle, no tools allowed, just answer!
         tool_choice = ToolChoiceOptions.NONE
-        # Bedrock requires tool config in requests that include toolUse/toolResult history.
- final_tools = ( - tools - if _should_keep_bedrock_tool_definitions(llm, simple_chat_history) - else [] - ) + final_tools = [] else: tool_choice = ToolChoiceOptions.AUTO final_tools = tools @@ -959,6 +940,13 @@ def run_llm_loop( ): generated_images = tool_response.rich_response.generated_images + # Extract generated_files if this is a code interpreter response + generated_files = None + if isinstance(tool_response.rich_response, PythonToolRichResponse): + generated_files = ( + tool_response.rich_response.generated_files or None + ) + # Persist memory if this is a memory tool response memory_snapshot: MemoryToolResponseSnapshot | None = None if isinstance(tool_response.rich_response, MemoryToolResponse): @@ -993,6 +981,10 @@ def run_llm_loop( if memory_snapshot: saved_response = json.dumps(memory_snapshot.model_dump()) + elif isinstance(tool_response.rich_response, CustomToolCallSummary): + saved_response = json.dumps( + tool_response.rich_response.model_dump() + ) elif isinstance(tool_response.rich_response, str): saved_response = tool_response.rich_response else: @@ -1010,6 +1002,7 @@ def run_llm_loop( tool_call_response=saved_response, search_docs=displayed_docs or search_docs, generated_images=generated_images, + generated_files=generated_files, ) # Add to state container for partial save support state_container.add_tool_call(tool_call_info) diff --git a/backend/onyx/chat/llm_step.py b/backend/onyx/chat/llm_step.py index 4c15f0654e7..ab151b66da3 100644 --- a/backend/onyx/chat/llm_step.py +++ b/backend/onyx/chat/llm_step.py @@ -15,6 +15,7 @@ from onyx.chat.emitter import Emitter from onyx.chat.models import ChatMessageSimple from onyx.chat.models import LlmStepResult +from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY from onyx.configs.constants import MessageType @@ -54,7 +55,9 @@ from onyx.tools.models import 
ToolCallKickoff from onyx.tracing.framework.create import generation_span from onyx.utils.b64 import get_image_type_from_bytes +from onyx.utils.jsonriver import Parser from onyx.utils.logger import setup_logger +from onyx.utils.postgres_sanitization import sanitize_string from onyx.utils.text_processing import find_all_json_objects logger = setup_logger() @@ -166,15 +169,6 @@ def _find_function_calls_open_marker(text_lower: str) -> int: search_from = idx + 1 -def _sanitize_llm_output(value: str) -> str: - """Remove characters that PostgreSQL's text/JSONB types cannot store. - - - NULL bytes (\x00): Not allowed in PostgreSQL text types - - UTF-16 surrogates (\ud800-\udfff): Invalid in UTF-8 encoding - """ - return "".join(c for c in value if c != "\x00" and not ("\ud800" <= c <= "\udfff")) - - def _try_parse_json_string(value: Any) -> Any: """Attempt to parse a JSON string value into its Python equivalent. @@ -222,9 +216,7 @@ def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]: if isinstance(raw_args, dict): # Parse any string values that look like JSON arrays/objects return { - k: _try_parse_json_string( - _sanitize_llm_output(v) if isinstance(v, str) else v - ) + k: _try_parse_json_string(sanitize_string(v) if isinstance(v, str) else v) for k, v in raw_args.items() } @@ -232,7 +224,7 @@ def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]: return {} # Sanitize before parsing to remove NULL bytes and surrogates - raw_args = _sanitize_llm_output(raw_args) + raw_args = sanitize_string(raw_args) try: parsed1: Any = json.loads(raw_args) @@ -545,12 +537,12 @@ def _extract_xml_attribute(attrs: str, attr_name: str) -> str | None: ) if not attr_match: return None - return _sanitize_llm_output(unescape(attr_match.group(2).strip())) + return sanitize_string(unescape(attr_match.group(2).strip())) def _parse_xml_parameter_value(raw_value: str, string_attr: str | None) -> Any: """Parse a parameter value from XML-style tool call payloads.""" - value = 
_sanitize_llm_output(unescape(raw_value).strip()) + value = sanitize_string(unescape(raw_value).strip()) if string_attr and string_attr.lower() == "true": return value @@ -569,6 +561,7 @@ def _resolve_tool_arguments(obj: dict[str, Any]) -> dict[str, Any] | None: """ arguments = obj.get("arguments", obj.get("parameters", {})) if isinstance(arguments, str): + arguments = sanitize_string(arguments) try: arguments = json.loads(arguments) except json.JSONDecodeError: @@ -1018,6 +1011,7 @@ def _current_placement() -> Placement: ) id_to_tool_call_map: dict[int, dict[str, Any]] = {} + arg_parsers: dict[int, Parser] = {} reasoning_start = False answer_start = False accumulated_reasoning = "" @@ -1224,7 +1218,14 @@ def _emit_content_chunk(content_chunk: str) -> Generator[Packet, None, None]: yield from _close_reasoning_if_active() for tool_call_delta in delta.tool_calls: + # maybe_emit_argument_delta depends on _update_tool_call_with_delta being called first to attach the delta _update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta) + yield from maybe_emit_argument_delta( + tool_calls_in_progress=id_to_tool_call_map, + tool_call_delta=tool_call_delta, + placement=_current_placement(), + parsers=arg_parsers, + ) # Flush any tail text buffered while checking for split " bool: - return check_stop_signal(chat_session.id, redis_client) + return check_stop_signal(chat_session.id, cache) set_processing_status( chat_session_id=chat_session.id, - redis_client=redis_client, + cache=cache, value=True, ) @@ -968,10 +968,10 @@ def llm_loop_completion_callback( reset_llm_mock_response(mock_response_token) try: - if redis_client is not None and chat_session is not None: + if cache is not None and chat_session is not None: set_processing_status( chat_session_id=chat_session.id, - redis_client=redis_client, + cache=cache, value=False, ) except Exception: diff --git a/backend/onyx/chat/save_chat.py b/backend/onyx/chat/save_chat.py index 4a0dd7a94d8..a36ebed9492 100644 --- a/backend/onyx/chat/save_chat.py +++
b/backend/onyx/chat/save_chat.py @@ -1,4 +1,5 @@ import json +import mimetypes from sqlalchemy.orm import Session @@ -12,14 +13,42 @@ from onyx.db.models import ChatMessage from onyx.db.models import ToolCall from onyx.db.tools import create_tool_call_no_commit +from onyx.file_store.models import FileDescriptor from onyx.natural_language_processing.utils import BaseTokenizer from onyx.natural_language_processing.utils import get_tokenizer +from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type from onyx.tools.models import ToolCallInfo from onyx.utils.logger import setup_logger +from onyx.utils.postgres_sanitization import sanitize_string logger = setup_logger() +def _extract_referenced_file_descriptors( + tool_calls: list[ToolCallInfo], + message_text: str, +) -> list[FileDescriptor]: + """Extract FileDescriptors for code interpreter files referenced in the message text.""" + descriptors: list[FileDescriptor] = [] + for tool_call_info in tool_calls: + if not tool_call_info.generated_files: + continue + for gen_file in tool_call_info.generated_files: + file_id = ( + gen_file.file_link.rsplit("/", 1)[-1] if gen_file.file_link else "" + ) + if file_id and file_id in message_text: + mime_type, _ = mimetypes.guess_type(gen_file.filename) + descriptors.append( + FileDescriptor( + id=file_id, + type=mime_type_to_chat_file_type(mime_type), + name=gen_file.filename, + ) + ) + return descriptors + + def _create_and_link_tool_calls( tool_calls: list[ToolCallInfo], assistant_message: ChatMessage, @@ -173,8 +202,13 @@ def save_chat_turn( pre_answer_processing_time: Duration of processing before answer starts (in seconds) """ # 1. 
Update ChatMessage with message content, reasoning tokens, and token count - assistant_message.message = message_text - assistant_message.reasoning_tokens = reasoning_tokens + sanitized_message_text = ( + sanitize_string(message_text) if message_text else message_text + ) + assistant_message.message = sanitized_message_text + assistant_message.reasoning_tokens = ( + sanitize_string(reasoning_tokens) if reasoning_tokens else reasoning_tokens + ) assistant_message.is_clarification = is_clarification # Use pre-answer processing time (captured when MESSAGE_START was emitted) @@ -184,8 +218,10 @@ def save_chat_turn( # Calculate token count using default tokenizer, when storing, this should not use the LLM # specific one so we use a system default tokenizer here. default_tokenizer = get_tokenizer(None, None) - if message_text: - assistant_message.token_count = len(default_tokenizer.encode(message_text)) + if sanitized_message_text: + assistant_message.token_count = len( + default_tokenizer.encode(sanitized_message_text) + ) else: assistant_message.token_count = 0 @@ -297,5 +333,16 @@ def save_chat_turn( citation_number_to_search_doc_id if citation_number_to_search_doc_id else None ) + # 8. Attach code interpreter generated files that the assistant actually + # referenced in its response, so they are available via load_all_chat_files + # on subsequent turns. Files not mentioned are intermediate artifacts. 
+ if sanitized_message_text: + referenced = _extract_referenced_file_descriptors( + tool_calls, sanitized_message_text + ) + if referenced: + existing_files = assistant_message.files or [] + assistant_message.files = existing_files + referenced + # Finally save the messages, tool calls, and docs db_session.commit() diff --git a/backend/onyx/chat/stop_signal_checker.py b/backend/onyx/chat/stop_signal_checker.py index 4caa55976f3..518f5cd2e1b 100644 --- a/backend/onyx/chat/stop_signal_checker.py +++ b/backend/onyx/chat/stop_signal_checker.py @@ -1,65 +1,58 @@ from uuid import UUID -from redis.client import Redis +from onyx.cache.interface import CacheBackend -# Redis key prefixes for chat session stop signals PREFIX = "chatsessionstop" FENCE_PREFIX = f"{PREFIX}_fence" -FENCE_TTL = 10 * 60 # 10 minutes - defensive TTL to prevent memory leaks +FENCE_TTL = 10 * 60 # 10 minutes def _get_fence_key(chat_session_id: UUID) -> str: - """ - Generate the Redis key for a chat session stop signal fence. + """Generate the cache key for a chat session stop signal fence. Args: chat_session_id: The UUID of the chat session Returns: - The fence key string (tenant_id is automatically added by the Redis client) + The fence key string. Tenant isolation is handled automatically + by the cache backend (Redis key-prefixing or Postgres schema routing). """ return f"{FENCE_PREFIX}_{chat_session_id}" -def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None: - """ - Set or clear the stop signal fence for a chat session. +def set_fence(chat_session_id: UUID, cache: CacheBackend, value: bool) -> None: + """Set or clear the stop signal fence for a chat session. 
Args: chat_session_id: The UUID of the chat session - redis_client: Redis client to use (tenant-aware client that auto-prefixes keys) + cache: Tenant-aware cache backend value: True to set the fence (stop signal), False to clear it """ fence_key = _get_fence_key(chat_session_id) if not value: - redis_client.delete(fence_key) + cache.delete(fence_key) return + cache.set(fence_key, 0, ex=FENCE_TTL) - redis_client.set(fence_key, 0, ex=FENCE_TTL) - -def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool: - """ - Check if the chat session should continue (not stopped). +def is_connected(chat_session_id: UUID, cache: CacheBackend) -> bool: + """Check if the chat session should continue (not stopped). Args: chat_session_id: The UUID of the chat session to check - redis_client: Redis client to use for checking the stop signal (tenant-aware client that auto-prefixes keys) + cache: Tenant-aware cache backend Returns: True if the session should continue, False if it should stop """ - fence_key = _get_fence_key(chat_session_id) - return not bool(redis_client.exists(fence_key)) + return not cache.exists(_get_fence_key(chat_session_id)) -def reset_cancel_status(chat_session_id: UUID, redis_client: Redis) -> None: - """ - Clear the stop signal for a chat session. +def reset_cancel_status(chat_session_id: UUID, cache: CacheBackend) -> None: + """Clear the stop signal for a chat session. 
Args: chat_session_id: The UUID of the chat session - redis_client: Redis client to use (tenant-aware client that auto-prefixes keys) + cache: Tenant-aware cache backend """ - fence_key = _get_fence_key(chat_session_id) - redis_client.delete(fence_key) + cache.delete(_get_fence_key(chat_session_id)) diff --git a/backend/onyx/chat/tool_call_args_streaming.py b/backend/onyx/chat/tool_call_args_streaming.py new file mode 100644 index 00000000000..2520042bed7 --- /dev/null +++ b/backend/onyx/chat/tool_call_args_streaming.py @@ -0,0 +1,77 @@ +from collections.abc import Generator +from collections.abc import Mapping +from typing import Any +from typing import Type + +from onyx.llm.model_response import ChatCompletionDeltaToolCall +from onyx.server.query_and_chat.placement import Placement +from onyx.server.query_and_chat.streaming_models import Packet +from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta +from onyx.tools.built_in_tools import TOOL_NAME_TO_CLASS +from onyx.tools.interface import Tool +from onyx.utils.jsonriver import Parser + + +def _get_tool_class( + tool_calls_in_progress: Mapping[int, Mapping[str, Any]], + tool_call_delta: ChatCompletionDeltaToolCall, +) -> Type[Tool] | None: + """Look up the Tool subclass for a streaming tool call delta.""" + tool_name = tool_calls_in_progress.get(tool_call_delta.index, {}).get("name") + if not tool_name: + return None + return TOOL_NAME_TO_CLASS.get(tool_name) + + +def maybe_emit_argument_delta( + tool_calls_in_progress: Mapping[int, Mapping[str, Any]], + tool_call_delta: ChatCompletionDeltaToolCall, + placement: Placement, + parsers: dict[int, Parser], +) -> Generator[Packet, None, None]: + """Emit decoded tool-call argument deltas to the frontend. + + Uses a ``jsonriver.Parser`` per tool-call index to incrementally parse + the JSON argument string and extract only the newly-appended content + for each string-valued argument. 
+ + NOTE: Non-string arguments (numbers, booleans, null, arrays, objects) + are skipped — they are available in the final tool-call kickoff packet. + + ``parsers`` is a mutable dict keyed by tool-call index. A new + ``Parser`` is created automatically for each new index. + """ + tool_cls = _get_tool_class(tool_calls_in_progress, tool_call_delta) + if not tool_cls or not tool_cls.should_emit_argument_deltas(): + return + + fn = tool_call_delta.function + delta_fragment = fn.arguments if fn else None + if not delta_fragment: + return + + idx = tool_call_delta.index + if idx not in parsers: + parsers[idx] = Parser() + parser = parsers[idx] + + deltas = parser.feed(delta_fragment) + + argument_deltas: dict[str, str] = {} + for delta in deltas: + if isinstance(delta, dict): + for key, value in delta.items(): + if isinstance(value, str): + argument_deltas[key] = argument_deltas.get(key, "") + value + + if not argument_deltas: + return + + tc_data = tool_calls_in_progress[tool_call_delta.index] + yield Packet( + placement=placement, + obj=ToolCallArgumentDelta( + tool_type=tc_data.get("name", ""), + argument_deltas=argument_deltas, + ), + ) diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index 7f54659192c..9cac16fd1df 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -6,6 +6,7 @@ from typing import cast from onyx.auth.schemas import AuthBackend +from onyx.cache.interface import CacheBackendType from onyx.configs.constants import AuthType from onyx.configs.constants import QueryHistoryType from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy @@ -54,6 +55,12 @@ # are disabled but core chat, tools, user file uploads, and Projects still work. DISABLE_VECTOR_DB = os.environ.get("DISABLE_VECTOR_DB", "").lower() == "true" +# Which backend to use for caching, locks, and ephemeral state. +# "redis" (default) or "postgres" (only valid when DISABLE_VECTOR_DB=true). 
+CACHE_BACKEND = CacheBackendType( + os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS) +) + # Maximum token count for a single uploaded file. Files exceeding this are rejected. # Defaults to 100k tokens (or 10M when vector DB is disabled). _DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000 @@ -61,6 +68,10 @@ os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT)) ) +# Maximum upload size for a single user file (chat/projects) in MB. +USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50) +USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024 + # If set to true, will show extra/uncommon connectors in the "Other" category SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true" @@ -85,19 +96,12 @@ ##### # Auth Configs ##### -# Upgrades users from disabled auth to basic auth and shows warning. -_auth_type_str = (os.environ.get("AUTH_TYPE") or "basic").lower() -if _auth_type_str == "disabled": - logger.warning( - "AUTH_TYPE='disabled' is no longer supported. " - "Defaulting to 'basic'. Please update your configuration. " - "Your existing data will be migrated automatically." - ) - _auth_type_str = AuthType.BASIC.value -try: +# Silently default to basic - warnings/errors logged in verify_auth_setting() +# which only runs on app startup, not during migrations/scripts +_auth_type_str = (os.environ.get("AUTH_TYPE") or "").lower() +if _auth_type_str in [auth_type.value for auth_type in AuthType]: AUTH_TYPE = AuthType(_auth_type_str) -except ValueError: - logger.error(f"Invalid AUTH_TYPE: {_auth_type_str}. Defaulting to 'basic'.") +else: AUTH_TYPE = AuthType.BASIC PASSWORD_MIN_LENGTH = int(os.getenv("PASSWORD_MIN_LENGTH", 8)) @@ -192,6 +196,10 @@ except Exception: pass +# Enables PKCE for OIDC login flow. Disabled by default to preserve +# backwards compatibility for existing OIDC deployments. 
+OIDC_PKCE_ENABLED = os.environ.get("OIDC_PKCE_ENABLED", "").lower() == "true" + # Applicable for SAML Auth SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/onyx/configs/saml_config" @@ -200,6 +208,12 @@ USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "") +if AUTH_TYPE == AuthType.BASIC and not USER_AUTH_SECRET: + logger.warning( + "USER_AUTH_SECRET is not set. This is required for secure password reset " + "and email verification tokens. Please set USER_AUTH_SECRET in production." + ) + # Duration (in seconds) for which the FastAPI Users JWT token remains valid in the user's browser. # By default, this is set to match the Redis expiry time for consistency. AUTH_COOKIE_EXPIRE_TIME_SECONDS = int( @@ -281,8 +295,9 @@ # environments we always want to be dual indexing into both OpenSearch and Vespa # to stress test the new codepaths. Only enable this if there is some instance # of OpenSearch running for the relevant Onyx instance. +# NOTE: Now enabled by default, unless the env indicates otherwise. ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = ( - os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true" + os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "true").lower() == "true" ) # NOTE: This effectively does nothing anymore, admins can now toggle whether # retrieval is through OpenSearch.
This value is only used as a final fallback @@ -300,6 +315,12 @@ os.environ.get("VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT", "true").lower() == "true" ) +OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE = int( + os.environ.get("OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE") or 500 +) +OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = int( + os.environ.get("OPENSEARCH_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES") or 0 +) VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost" # NOTE: this is used if and only if the vespa config server is accessible via a @@ -488,14 +509,7 @@ os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4 ) -# Consolidated background worker (light, docprocessing, docfetching, heavy, monitoring, user_file_processing) -# separate workers' defaults: light=24, docprocessing=6, docfetching=1, heavy=4, kg=2, monitoring=1, user_file=2 -# Total would be 40, but we use a more conservative default of 20 for the consolidated worker -CELERY_WORKER_BACKGROUND_CONCURRENCY = int( - os.environ.get("CELERY_WORKER_BACKGROUND_CONCURRENCY") or 20 -) - -# Individual worker concurrency settings (used when USE_LIGHTWEIGHT_BACKGROUND_WORKER is False or on Kuberenetes deployments) +# Individual worker concurrency settings CELERY_WORKER_HEAVY_CONCURRENCY = int( os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4 ) @@ -812,7 +826,9 @@ def get_current_tz_offset() -> int: # Tool Configs ##### # Code Interpreter Service Configuration -CODE_INTERPRETER_BASE_URL = os.environ.get("CODE_INTERPRETER_BASE_URL") +CODE_INTERPRETER_BASE_URL = os.environ.get( + "CODE_INTERPRETER_BASE_URL", "http://localhost:8000" +) CODE_INTERPRETER_DEFAULT_TIMEOUT_MS = int( os.environ.get("CODE_INTERPRETER_DEFAULT_TIMEOUT_MS") or 60_000 @@ -893,6 +909,9 @@ def get_current_tz_offset() -> int: ) VESPA_REQUEST_TIMEOUT = int(os.environ.get("VESPA_REQUEST_TIMEOUT") or "15") +VESPA_MIGRATION_REQUEST_TIMEOUT_S = int( + os.environ.get("VESPA_MIGRATION_REQUEST_TIMEOUT_S") or "120" +) 
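The new config constants above consistently use the `int(os.environ.get(NAME) or "default")` idiom rather than `os.environ.get(NAME, "default")`. The difference matters when a variable is present but empty: `or` falls through to the default, while a `.get` default would hand `int()` an empty string. A standalone sketch of the idiom, with an illustrative variable name that is not part of the Onyx config:

```python
import os


def int_env(name: str, default: int) -> int:
    """Read an integer env var, treating unset AND empty-string as default.

    os.environ.get(name, default) would return "" for an empty variable,
    and int("") raises ValueError; `or` falls through to the default.
    """
    return int(os.environ.get(name) or default)


# Set but empty: falls back to the default instead of crashing.
os.environ["TIMEOUT_DEMO_S"] = ""
print(int_env("TIMEOUT_DEMO_S", 120))

# Set to a real value: parsed normally.
os.environ["TIMEOUT_DEMO_S"] = "15"
print(int_env("TIMEOUT_DEMO_S", 120))
```

The trade-off is that a legitimately falsy value like `"0"` also falls through to the default, which is acceptable for timeouts and page sizes where zero is not meaningful.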
SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000") diff --git a/backend/onyx/configs/constants.py b/backend/onyx/configs/constants.py index dd5af81206d..b7c14b82967 100644 --- a/backend/onyx/configs/constants.py +++ b/backend/onyx/configs/constants.py @@ -84,7 +84,6 @@ POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing" POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching" POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child" -POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME = "celery_worker_background" POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy" POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring" POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = ( diff --git a/backend/onyx/connectors/confluence/connector.py b/backend/onyx/connectors/confluence/connector.py index bbd343bb1cf..462cd625d55 100644 --- a/backend/onyx/connectors/confluence/connector.py +++ b/backend/onyx/connectors/confluence/connector.py @@ -943,6 +943,9 @@ def get_external_access( if include_permissions else None ), + parent_hierarchy_raw_node_id=self._get_parent_hierarchy_raw_id( + page + ), ) ) @@ -992,6 +995,7 @@ def get_external_access( if include_permissions else None ), + parent_hierarchy_raw_node_id=page_id, ) ) diff --git a/backend/onyx/connectors/discord/connector.py b/backend/onyx/connectors/discord/connector.py index 2ff37fe8e2d..8cfc9f459b3 100644 --- a/backend/onyx/connectors/discord/connector.py +++ b/backend/onyx/connectors/discord/connector.py @@ -1,4 +1,5 @@ import asyncio +from collections.abc import AsyncGenerator from collections.abc import AsyncIterable from collections.abc import Iterable from datetime import datetime @@ -204,7 +205,7 @@ def _manage_async_retrieval( end_time: datetime | None = end - async def _async_fetch() -> AsyncIterable[Document]: + async def _async_fetch() -> AsyncGenerator[Document, None]: intents = Intents.default() 
intents.message_content = True async with Client(intents=intents) as discord_client: @@ -227,22 +228,23 @@ async def _async_fetch() -> AsyncIterable[Document]: def run_and_yield() -> Iterable[Document]: loop = asyncio.new_event_loop() + async_gen = _async_fetch() try: - # Get the async generator - async_gen = _async_fetch() - # Convert to AsyncIterator - async_iter = async_gen.__aiter__() while True: try: - # Create a coroutine by calling anext with the async iterator - next_coro = anext(async_iter) - # Run the coroutine to get the next document - doc = loop.run_until_complete(next_coro) + doc = loop.run_until_complete(anext(async_gen)) yield doc except StopAsyncIteration: break finally: - loop.close() + # Must close the async generator before the loop so the Discord + # client's `async with` block can await its shutdown coroutine. + # The nested try/finally ensures the loop always closes even if + # aclose() raises (same pattern as cursor.close() before conn.close()). + try: + loop.run_until_complete(async_gen.aclose()) + finally: + loop.close() return run_and_yield() diff --git a/backend/onyx/connectors/google_drive/connector.py b/backend/onyx/connectors/google_drive/connector.py index f28c7e78bca..944def29a5e 100644 --- a/backend/onyx/connectors/google_drive/connector.py +++ b/backend/onyx/connectors/google_drive/connector.py @@ -1722,6 +1722,7 @@ def _yield_slim_batch() -> list[SlimDocument | HierarchyNode]: primary_admin_email=self.primary_admin_email, google_domain=self.google_domain, ), + retriever_email=file.user_email, ): slim_batch.append(doc) diff --git a/backend/onyx/connectors/google_drive/doc_conversion.py b/backend/onyx/connectors/google_drive/doc_conversion.py index 2e8cf3d9f92..fdbe3a78a1f 100644 --- a/backend/onyx/connectors/google_drive/doc_conversion.py +++ b/backend/onyx/connectors/google_drive/doc_conversion.py @@ -476,6 +476,7 @@ def _get_external_access_for_raw_gdrive_file( company_domain: str, retriever_drive_service: GoogleDriveService | 
None, admin_drive_service: GoogleDriveService, + fallback_user_email: str, add_prefix: bool = False, ) -> ExternalAccess: """ @@ -484,6 +485,8 @@ def _get_external_access_for_raw_gdrive_file( add_prefix: When True, prefix group IDs with source type (for indexing path). When False (default), leave unprefixed (for permission sync path where upsert_document_external_perms handles prefixing). + fallback_user_email: When permission info can't be retrieved (e.g. externally-owned + files), fall back to granting access to this user. """ external_access_fn = cast( Callable[ @@ -492,6 +495,7 @@ def _get_external_access_for_raw_gdrive_file( str, GoogleDriveService | None, GoogleDriveService, + str, bool, ], ExternalAccess, @@ -507,6 +511,7 @@ def _get_external_access_for_raw_gdrive_file( company_domain, retriever_drive_service, admin_drive_service, + fallback_user_email, add_prefix, ) @@ -672,6 +677,7 @@ def _get_docs_service() -> GoogleDocsService: creds, user_email=permission_sync_context.primary_admin_email ), add_prefix=True, # Indexing path - prefix here + fallback_user_email=retriever_email, ) if permission_sync_context else None @@ -753,6 +759,7 @@ def build_slim_document( # if not specified, we will not sync permissions # will also be a no-op if EE is not enabled permission_sync_context: PermissionSyncContext | None, + retriever_email: str, ) -> SlimDocument | None: if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]: return None @@ -774,6 +781,7 @@ def build_slim_document( creds, user_email=permission_sync_context.primary_admin_email, ), + fallback_user_email=retriever_email, ) if permission_sync_context else None @@ -781,4 +789,5 @@ def build_slim_document( return SlimDocument( id=onyx_document_id_from_drive_file(file), external_access=external_access, + parent_hierarchy_raw_node_id=(file.get("parents") or [None])[0], ) diff --git a/backend/onyx/connectors/google_utils/google_kv.py b/backend/onyx/connectors/google_utils/google_kv.py index 
17d9d0c1e35..6974aea271c 100644 --- a/backend/onyx/connectors/google_utils/google_kv.py +++ b/backend/onyx/connectors/google_utils/google_kv.py @@ -44,6 +44,7 @@ from onyx.db.credentials import update_credential_json from onyx.db.models import User from onyx.key_value_store.factory import get_kv_store +from onyx.key_value_store.interface import unwrap_str from onyx.server.documents.models import CredentialBase from onyx.server.documents.models import GoogleAppCredentials from onyx.server.documents.models import GoogleServiceAccountKey @@ -89,7 +90,7 @@ def _get_current_oauth_user(creds: OAuthCredentials, source: DocumentSource) -> def verify_csrf(credential_id: int, state: str) -> None: - csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id))) + csrf = unwrap_str(get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))) if csrf != state: raise PermissionError( "State from Google Drive Connector callback does not match expected" @@ -178,7 +179,9 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str: params = parse_qs(parsed_url.query) get_kv_store().store( - KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True + KV_CRED_KEY.format(credential_id), + {"value": params.get("state", [None])[0]}, + encrypt=True, ) return str(auth_url) diff --git a/backend/onyx/connectors/jira/connector.py b/backend/onyx/connectors/jira/connector.py index 2050a7901ae..4751705307f 100644 --- a/backend/onyx/connectors/jira/connector.py +++ b/backend/onyx/connectors/jira/connector.py @@ -902,6 +902,11 @@ def retrieve_all_slim_docs_perm_sync( external_access=self._get_project_permissions( project_key, add_prefix=False ), + parent_hierarchy_raw_node_id=( + self._get_parent_hierarchy_raw_node_id(issue, project_key) + if project_key + else None + ), ) ) current_offset += 1 diff --git a/backend/onyx/connectors/models.py b/backend/onyx/connectors/models.py index 1298e18653e..a2cb446dcd7 100644 --- a/backend/onyx/connectors/models.py +++ 
b/backend/onyx/connectors/models.py @@ -385,6 +385,7 @@ def get_total_char_length(self) -> int: class SlimDocument(BaseModel): id: str external_access: ExternalAccess | None = None + parent_hierarchy_raw_node_id: str | None = None class HierarchyNode(BaseModel): diff --git a/backend/onyx/connectors/sharepoint/connector.py b/backend/onyx/connectors/sharepoint/connector.py index ab009f059e2..6088f3cf6f9 100644 --- a/backend/onyx/connectors/sharepoint/connector.py +++ b/backend/onyx/connectors/sharepoint/connector.py @@ -33,6 +33,7 @@ from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped] from pydantic import BaseModel from pydantic import Field +from requests.exceptions import HTTPError from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS @@ -258,6 +259,10 @@ class SharepointConnectorCheckpoint(ConnectorCheckpoint): # Track yielded hierarchy nodes by their raw_node_id (URLs) to avoid duplicates seen_hierarchy_node_raw_ids: set[str] = Field(default_factory=set) + # Track yielded document IDs to avoid processing the same document twice. + # The Microsoft Graph delta API can return the same item on multiple pages. 
+    seen_document_ids: set[str] = Field(default_factory=set)
+

 class SharepointAuthMethod(Enum):
     CLIENT_SECRET = "client_secret"
@@ -268,6 +273,15 @@ class SizeCapExceeded(Exception):
     """Exception raised when the size cap is exceeded."""


+def _log_and_raise_for_status(response: requests.Response) -> None:
+    """Log the response text and raise for status."""
+    try:
+        response.raise_for_status()
+    except Exception:
+        logger.error(f"HTTP request failed: {response.text}")
+        raise
+
+
 def load_certificate_from_pfx(pfx_data: bytes, password: str) -> CertificateData | None:
     """Load certificate from .pfx file for MSAL authentication"""
     try:
@@ -344,7 +358,7 @@ def _probe_remote_size(url: str, timeout: int) -> int | None:
     """Determine remote size using HEAD or a range GET probe. Returns None if unknown."""
     try:
         head_resp = requests.head(url, timeout=timeout, allow_redirects=True)
-        head_resp.raise_for_status()
+        _log_and_raise_for_status(head_resp)
         cl = head_resp.headers.get("Content-Length")
         if cl and cl.isdigit():
             return int(cl)
@@ -359,7 +373,7 @@ def _probe_remote_size(url: str, timeout: int) -> int | None:
             timeout=timeout,
             stream=True,
         ) as range_resp:
-            range_resp.raise_for_status()
+            _log_and_raise_for_status(range_resp)
             cr = range_resp.headers.get("Content-Range")  # e.g., "bytes 0-0/12345"
             if cr and "/" in cr:
                 total = cr.split("/")[-1]
@@ -384,7 +398,7 @@ def _download_with_cap(url: str, timeout: int, cap: int) -> bytes:
     - Returns the full bytes if the content fits within `cap`.
     """
     with requests.get(url, stream=True, timeout=timeout) as resp:
-        resp.raise_for_status()
+        _log_and_raise_for_status(resp)
         # If the server provides Content-Length, prefer an early decision.
         cl_header = resp.headers.get("Content-Length")
@@ -428,7 +442,7 @@ def _download_via_graph_api(
     with requests.get(
         url, headers=headers, stream=True, timeout=REQUEST_TIMEOUT_SECONDS
     ) as resp:
-        resp.raise_for_status()
+        _log_and_raise_for_status(resp)
         buf = io.BytesIO()
         for chunk in resp.iter_content(64 * 1024):
             if not chunk:
@@ -772,6 +786,7 @@ def _convert_driveitem_to_slim_document(
     drive_name: str,
     ctx: ClientContext,
     graph_client: GraphClient,
+    parent_hierarchy_raw_node_id: str | None = None,
 ) -> SlimDocument:
     if driveitem.id is None:
         raise ValueError("DriveItem ID is required")
@@ -787,11 +802,15 @@ def _convert_driveitem_to_slim_document(
     return SlimDocument(
         id=driveitem.id,
         external_access=external_access,
+        parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
     )


 def _convert_sitepage_to_slim_document(
-    site_page: dict[str, Any], ctx: ClientContext | None, graph_client: GraphClient
+    site_page: dict[str, Any],
+    ctx: ClientContext | None,
+    graph_client: GraphClient,
+    parent_hierarchy_raw_node_id: str | None = None,
 ) -> SlimDocument:
     """Convert a SharePoint site page to a SlimDocument object."""
     if site_page.get("id") is None:
@@ -808,6 +827,7 @@ def _convert_sitepage_to_slim_document(
     return SlimDocument(
         id=id,
         external_access=external_access,
+        parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
     )


@@ -1239,7 +1259,14 @@ def _fetch_site_pages(
         total_yielded = 0

         while page_url:
-            data = self._graph_api_get_json(page_url, params)
+            try:
+                data = self._graph_api_get_json(page_url, params)
+            except HTTPError as e:
+                if e.response.status_code == 404:
+                    logger.warning(f"Site page not found: {page_url}")
+                    break
+                raise
+
             params = None  # nextLink already embeds query params

             for page in data.get("value", []):
@@ -1303,7 +1330,7 @@ def _graph_api_get_json(
                     access_token = self._get_graph_access_token()
                     headers = {"Authorization": f"Bearer {access_token}"}
                     continue
-                response.raise_for_status()
+                _log_and_raise_for_status(response)
                 return response.json()
             except (requests.ConnectionError, requests.Timeout):
                 if attempt < GRAPH_API_MAX_RETRIES:
@@ -1551,6 +1578,7 @@ def _clear_drive_checkpoint_state(
         checkpoint.current_drive_id = None
         checkpoint.current_drive_web_url = None
         checkpoint.current_drive_delta_next_link = None
+        checkpoint.seen_document_ids.clear()

     def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
         site_descriptors = self.site_descriptors or self.fetch_sites()
@@ -1594,12 +1622,22 @@ def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
                     )
                 )

+                parent_hierarchy_url: str | None = None
+                if drive_web_url:
+                    parent_hierarchy_url = self._get_parent_hierarchy_url(
+                        site_url, drive_web_url, drive_name, driveitem
+                    )
+
                 try:
                     logger.debug(f"Processing: {driveitem.web_url}")
                     ctx = self._create_rest_client_context(site_descriptor.url)
                     doc_batch.append(
                         _convert_driveitem_to_slim_document(
-                            driveitem, drive_name, ctx, self.graph_client
+                            driveitem,
+                            drive_name,
+                            ctx,
+                            self.graph_client,
+                            parent_hierarchy_raw_node_id=parent_hierarchy_url,
                         )
                     )
                 except Exception as e:
@@ -1619,7 +1657,10 @@ def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
                 ctx = self._create_rest_client_context(site_descriptor.url)
                 doc_batch.append(
                     _convert_sitepage_to_slim_document(
-                        site_page, ctx, self.graph_client
+                        site_page,
+                        ctx,
+                        self.graph_client,
+                        parent_hierarchy_raw_node_id=site_descriptor.url,
                     )
                 )
                 if len(doc_batch) >= SLIM_BATCH_SIZE:
@@ -2118,6 +2159,14 @@ def _load_from_checkpoint(
             item_count = 0
             for driveitem in driveitems:
                 item_count += 1
+
+                if driveitem.id and driveitem.id in checkpoint.seen_document_ids:
+                    logger.debug(
+                        f"Skipping duplicate document {driveitem.id} "
+                        f"({driveitem.name})"
+                    )
+                    continue
+
                 driveitem_extension = get_file_ext(driveitem.name)
                 if driveitem_extension not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
                     logger.warning(
@@ -2170,11 +2219,13 @@ def _load_from_checkpoint(
                     if isinstance(doc_or_failure, Document):
                         if doc_or_failure.sections:
+                            checkpoint.seen_document_ids.add(doc_or_failure.id)
                             yield doc_or_failure
                         elif should_yield_if_empty:
                             doc_or_failure.sections = [
                                 TextSection(link=driveitem.web_url, text="")
                             ]
+                            checkpoint.seen_document_ids.add(doc_or_failure.id)
                             yield doc_or_failure
                         else:
                             logger.warning(
diff --git a/backend/onyx/connectors/slack/connector.py b/backend/onyx/connectors/slack/connector.py
index ce30399d2fe..ea6c2200656 100644
--- a/backend/onyx/connectors/slack/connector.py
+++ b/backend/onyx/connectors/slack/connector.py
@@ -565,6 +565,7 @@ def _get_all_doc_ids(
                         channel_id=channel_id, thread_ts=message["ts"]
                     ),
                     external_access=external_access,
+                    parent_hierarchy_raw_node_id=channel_id,
                 )
             )

diff --git a/backend/onyx/db/chat.py b/backend/onyx/db/chat.py
index 2ca7e5a2a24..1588508f19b 100644
--- a/backend/onyx/db/chat.py
+++ b/backend/onyx/db/chat.py
@@ -38,6 +38,7 @@
 from onyx.llm.override_models import PromptOverride
 from onyx.server.query_and_chat.models import ChatMessageDetail
 from onyx.utils.logger import setup_logger
+from onyx.utils.postgres_sanitization import sanitize_string

 logger = setup_logger()

@@ -98,6 +99,7 @@ def get_chat_sessions_by_user(
     db_session: Session,
     include_onyxbot_flows: bool = False,
     limit: int = 50,
+    before: datetime | None = None,
     project_id: int | None = None,
     only_non_project_chats: bool = False,
     include_failed_chats: bool = False,
@@ -112,6 +114,9 @@ def get_chat_sessions_by_user(
     if deleted is not None:
         stmt = stmt.where(ChatSession.deleted == deleted)

+    if before is not None:
+        stmt = stmt.where(ChatSession.time_updated < before)
+
     if limit:
         stmt = stmt.limit(limit)

@@ -671,58 +676,43 @@ def set_as_latest_chat_message(
     db_session.commit()


-def _sanitize_for_postgres(value: str) -> str:
-    """Remove NUL (0x00) characters from strings as PostgreSQL doesn't allow them."""
-    sanitized = value.replace("\x00", "")
-    if value and not sanitized:
-        logger.warning("Sanitization removed all characters from string")
-    return sanitized
-
-
-def _sanitize_list_for_postgres(values: list[str]) -> list[str]:
-    """Remove NUL (0x00) characters from all strings in a list."""
-    return [_sanitize_for_postgres(v) for v in values]
-
-
 def create_db_search_doc(
     server_search_doc: ServerSearchDoc,
     db_session: Session,
     commit: bool = True,
 ) -> DBSearchDoc:
-    # Sanitize string fields to remove NUL characters (PostgreSQL doesn't allow them)
     db_search_doc = DBSearchDoc(
-        document_id=_sanitize_for_postgres(server_search_doc.document_id),
+        document_id=sanitize_string(server_search_doc.document_id),
         chunk_ind=server_search_doc.chunk_ind,
-        semantic_id=_sanitize_for_postgres(server_search_doc.semantic_identifier),
+        semantic_id=sanitize_string(server_search_doc.semantic_identifier),
         link=(
-            _sanitize_for_postgres(server_search_doc.link)
+            sanitize_string(server_search_doc.link)
             if server_search_doc.link is not None
             else None
         ),
-        blurb=_sanitize_for_postgres(server_search_doc.blurb),
+        blurb=sanitize_string(server_search_doc.blurb),
         source_type=server_search_doc.source_type,
         boost=server_search_doc.boost,
         hidden=server_search_doc.hidden,
         doc_metadata=server_search_doc.metadata,
         is_relevant=server_search_doc.is_relevant,
         relevance_explanation=(
-            _sanitize_for_postgres(server_search_doc.relevance_explanation)
+            sanitize_string(server_search_doc.relevance_explanation)
             if server_search_doc.relevance_explanation is not None
             else None
         ),
-        # For docs further down that aren't reranked, we can't use the retrieval score
         score=server_search_doc.score or 0.0,
-        match_highlights=_sanitize_list_for_postgres(
-            server_search_doc.match_highlights
-        ),
+        match_highlights=[
+            sanitize_string(h) for h in server_search_doc.match_highlights
+        ],
         updated_at=server_search_doc.updated_at,
         primary_owners=(
-            _sanitize_list_for_postgres(server_search_doc.primary_owners)
+            [sanitize_string(o) for o in server_search_doc.primary_owners]
             if server_search_doc.primary_owners is not None
             else None
         ),
         secondary_owners=(
-            _sanitize_list_for_postgres(server_search_doc.secondary_owners)
+            [sanitize_string(o) for o in server_search_doc.secondary_owners]
             if server_search_doc.secondary_owners is not None
             else None
         ),
diff --git a/backend/onyx/db/document_set.py b/backend/onyx/db/document_set.py
index 277f41ac280..b39114ac88f 100644
--- a/backend/onyx/db/document_set.py
+++ b/backend/onyx/db/document_set.py
@@ -13,6 +13,7 @@
 from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session

+from onyx.configs.app_configs import DISABLE_VECTOR_DB
 from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
 from onyx.db.connector_credential_pair import get_connector_credential_pairs
 from onyx.db.enums import AccessType
@@ -246,6 +247,7 @@ def insert_document_set(
         description=document_set_creation_request.description,
         user_id=user_id,
         is_public=document_set_creation_request.is_public,
+        is_up_to_date=DISABLE_VECTOR_DB,
         time_last_modified_by_user=func.now(),
     )
     db_session.add(new_document_set_row)
@@ -336,7 +338,8 @@ def update_document_set(
         )

     document_set_row.description = document_set_update_request.description
-    document_set_row.is_up_to_date = False
+    if not DISABLE_VECTOR_DB:
+        document_set_row.is_up_to_date = False
     document_set_row.is_public = document_set_update_request.is_public
     document_set_row.time_last_modified_by_user = func.now()
     versioned_private_doc_set_fn = fetch_versioned_implementation(
diff --git a/backend/onyx/db/enums.py b/backend/onyx/db/enums.py
index e6191db1baa..de7b666d4b7 100644
--- a/backend/onyx/db/enums.py
+++ b/backend/onyx/db/enums.py
@@ -186,6 +186,7 @@ class EmbeddingPrecision(str, PyEnum):

 class UserFileStatus(str, PyEnum):
     PROCESSING = "PROCESSING"
+    INDEXING = "INDEXING"
     COMPLETED = "COMPLETED"
     FAILED = "FAILED"
     CANCELED = "CANCELED"
diff --git a/backend/onyx/db/hierarchy.py b/backend/onyx/db/hierarchy.py
index 4b9aceba2b4..b04eb4a3789 100644
--- a/backend/onyx/db/hierarchy.py
+++ b/backend/onyx/db/hierarchy.py
@@ -1,6 +1,11 @@
 """CRUD operations for HierarchyNode."""

+from collections import defaultdict
+
+from sqlalchemy import delete
 from sqlalchemy import select
+from sqlalchemy.dialects.postgresql import insert as pg_insert
+from sqlalchemy.engine import CursorResult
 from sqlalchemy.orm import Session

 from onyx.configs.constants import DocumentSource
@@ -8,6 +13,7 @@
 from onyx.db.enums import HierarchyNodeType
 from onyx.db.models import Document
 from onyx.db.models import HierarchyNode
+from onyx.db.models import HierarchyNodeByConnectorCredentialPair
 from onyx.utils.logger import setup_logger
 from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -456,7 +462,7 @@ def get_all_hierarchy_nodes_for_source(
 def _get_accessible_hierarchy_nodes_for_source(
     db_session: Session,
     source: DocumentSource,
-    user_email: str | None,  # noqa: ARG001
+    user_email: str,  # noqa: ARG001
     external_group_ids: list[str],  # noqa: ARG001
 ) -> list[HierarchyNode]:
     """
@@ -483,7 +489,7 @@ def _get_accessible_hierarchy_nodes_for_source(
 def get_accessible_hierarchy_nodes_for_source(
     db_session: Session,
     source: DocumentSource,
-    user_email: str | None,
+    user_email: str,
     external_group_ids: list[str],
 ) -> list[HierarchyNode]:
     """
@@ -525,6 +531,53 @@ def get_document_parent_hierarchy_node_ids(
     return {doc_id: parent_id for doc_id, parent_id in results}


+def update_document_parent_hierarchy_nodes(
+    db_session: Session,
+    doc_parent_map: dict[str, int | None],
+    commit: bool = True,
+) -> int:
+    """Bulk-update Document.parent_hierarchy_node_id for multiple documents.
+
+    Only updates rows whose current value differs from the desired value to
+    avoid unnecessary writes.
+
+    Args:
+        db_session: SQLAlchemy session
+        doc_parent_map: Mapping of document_id → desired parent_hierarchy_node_id
+        commit: Whether to commit the transaction
+
+    Returns:
+        Number of documents actually updated
+    """
+    if not doc_parent_map:
+        return 0
+
+    doc_ids = list(doc_parent_map.keys())
+    existing = get_document_parent_hierarchy_node_ids(db_session, doc_ids)
+
+    by_parent: dict[int | None, list[str]] = defaultdict(list)
+    for doc_id, desired_parent_id in doc_parent_map.items():
+        current = existing.get(doc_id)
+        if current == desired_parent_id or doc_id not in existing:
+            continue
+        by_parent[desired_parent_id].append(doc_id)
+
+    updated = 0
+    for desired_parent_id, ids in by_parent.items():
+        db_session.query(Document).filter(Document.id.in_(ids)).update(
+            {Document.parent_hierarchy_node_id: desired_parent_id},
+            synchronize_session=False,
+        )
+        updated += len(ids)
+
+    if commit:
+        db_session.commit()
+    elif updated:
+        db_session.flush()
+
+    return updated
+
+
 def update_hierarchy_node_permissions(
     db_session: Session,
     raw_node_id: str,
@@ -571,3 +624,154 @@ def update_hierarchy_node_permissions(
         db_session.flush()

     return True
+
+
+def upsert_hierarchy_node_cc_pair_entries(
+    db_session: Session,
+    hierarchy_node_ids: list[int],
+    connector_id: int,
+    credential_id: int,
+    commit: bool = True,
+) -> None:
+    """Insert rows into HierarchyNodeByConnectorCredentialPair, ignoring conflicts.
+
+    This records that the given cc_pair "owns" these hierarchy nodes. Used by
+    indexing, pruning, and hierarchy-fetching paths.
+    """
+    if not hierarchy_node_ids:
+        return
+
+    _M = HierarchyNodeByConnectorCredentialPair
+    stmt = pg_insert(_M).values(
+        [
+            {
+                _M.hierarchy_node_id: node_id,
+                _M.connector_id: connector_id,
+                _M.credential_id: credential_id,
+            }
+            for node_id in hierarchy_node_ids
+        ]
+    )
+    stmt = stmt.on_conflict_do_nothing()
+    db_session.execute(stmt)
+
+    if commit:
+        db_session.commit()
+    else:
+        db_session.flush()
+
+
+def remove_stale_hierarchy_node_cc_pair_entries(
+    db_session: Session,
+    connector_id: int,
+    credential_id: int,
+    live_hierarchy_node_ids: set[int],
+    commit: bool = True,
+) -> int:
+    """Delete join-table rows for this cc_pair that are NOT in the live set.
+
+    If ``live_hierarchy_node_ids`` is empty ALL rows for the cc_pair are deleted
+    (i.e. the connector no longer has any hierarchy nodes). Callers that want a
+    no-op when there are no live nodes must guard before calling.
+
+    Returns the number of deleted rows.
+    """
+    stmt = delete(HierarchyNodeByConnectorCredentialPair).where(
+        HierarchyNodeByConnectorCredentialPair.connector_id == connector_id,
+        HierarchyNodeByConnectorCredentialPair.credential_id == credential_id,
+    )
+    if live_hierarchy_node_ids:
+        stmt = stmt.where(
+            HierarchyNodeByConnectorCredentialPair.hierarchy_node_id.notin_(
+                live_hierarchy_node_ids
+            )
+        )
+
+    result: CursorResult = db_session.execute(stmt)  # type: ignore[assignment]
+    deleted = result.rowcount
+
+    if commit:
+        db_session.commit()
+    elif deleted:
+        db_session.flush()
+
+    return deleted
+
+
+def delete_orphaned_hierarchy_nodes(
+    db_session: Session,
+    source: DocumentSource,
+    commit: bool = True,
+) -> list[str]:
+    """Delete hierarchy nodes for a source that have zero cc_pair associations.
+
+    SOURCE-type nodes are excluded (they are synthetic roots).
+
+    Returns the list of raw_node_ids that were deleted (for cache eviction).
+    """
+    # Find orphaned nodes: no rows in the join table
+    orphan_stmt = (
+        select(HierarchyNode.id, HierarchyNode.raw_node_id)
+        .outerjoin(
+            HierarchyNodeByConnectorCredentialPair,
+            HierarchyNode.id
+            == HierarchyNodeByConnectorCredentialPair.hierarchy_node_id,
+        )
+        .where(
+            HierarchyNode.source == source,
+            HierarchyNode.node_type != HierarchyNodeType.SOURCE,
+            HierarchyNodeByConnectorCredentialPair.hierarchy_node_id.is_(None),
+        )
+    )
+    orphans = db_session.execute(orphan_stmt).all()
+    if not orphans:
+        return []
+
+    orphan_ids = [row[0] for row in orphans]
+    deleted_raw_ids = [row[1] for row in orphans]
+
+    db_session.execute(delete(HierarchyNode).where(HierarchyNode.id.in_(orphan_ids)))
+
+    if commit:
+        db_session.commit()
+    else:
+        db_session.flush()
+
+    return deleted_raw_ids
+
+
+def reparent_orphaned_hierarchy_nodes(
+    db_session: Session,
+    source: DocumentSource,
+    commit: bool = True,
+) -> list[HierarchyNode]:
+    """Re-parent hierarchy nodes whose parent_id is NULL to the SOURCE node.
+
+    After pruning deletes stale nodes, their former children get parent_id=NULL
+    via the SET NULL cascade. This function points them back to the SOURCE root.
+
+    Returns the reparented HierarchyNode objects (with updated parent_id)
+    so callers can refresh downstream caches.
+    """
+    source_node = get_source_hierarchy_node(db_session, source)
+    if not source_node:
+        return []
+
+    stmt = select(HierarchyNode).where(
+        HierarchyNode.source == source,
+        HierarchyNode.parent_id.is_(None),
+        HierarchyNode.node_type != HierarchyNodeType.SOURCE,
+    )
+    orphans = list(db_session.execute(stmt).scalars().all())
+    if not orphans:
+        return []
+
+    for node in orphans:
+        node.parent_id = source_node.id
+
+    if commit:
+        db_session.commit()
+    else:
+        db_session.flush()
+
+    return orphans
diff --git a/backend/onyx/db/llm.py b/backend/onyx/db/llm.py
index 4ee1e88251e..151504f7104 100644
--- a/backend/onyx/db/llm.py
+++ b/backend/onyx/db/llm.py
@@ -25,8 +25,12 @@
 from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
 from onyx.server.manage.llm.models import LLMProviderUpsertRequest
 from onyx.server.manage.llm.models import LLMProviderView
+from onyx.server.manage.llm.models import SyncModelEntry
+from onyx.utils.logger import setup_logger
 from shared_configs.enums import EmbeddingProvider

+logger = setup_logger()
+

 def update_group_llm_provider_relationships__no_commit(
     llm_provider_id: int,
@@ -109,45 +113,38 @@ def can_user_access_llm_provider(
         is_admin: If True, bypass user group restrictions but still respect persona restrictions

     Access logic:
-    1. If is_public=True → everyone has access (public override)
-    2. If is_public=False:
-       - Both groups AND personas set → must satisfy BOTH (AND logic, admins bypass group check)
-       - Only groups set → must be in one of the groups (OR across groups, admins bypass)
-       - Only personas set → must use one of the personas (OR across personas, applies to admins)
-       - Neither set → NOBODY has access unless admin (locked, admin-only)
+    - is_public controls USER access (group bypass): when True, all users can access
+      regardless of group membership. When False, user must be in a whitelisted group
+      (or be admin).
+    - Persona restrictions are ALWAYS enforced when set, regardless of is_public.
+      This allows admins to make a provider available to all users while still
+      restricting which personas (assistants) can use it.
+
+    Decision matrix:
+    1. is_public=True, no personas set → everyone has access
+    2. is_public=True, personas set → all users, but only whitelisted personas
+    3. is_public=False, groups+personas set → must satisfy BOTH (admins bypass groups)
+    4. is_public=False, only groups set → must be in group (admins bypass)
+    5. is_public=False, only personas set → must use whitelisted persona
+    6. is_public=False, neither set → admin-only (locked)
     """
-    # Public override - everyone has access
-    if provider.is_public:
-        return True
-
-    # Extract IDs once to avoid multiple iterations
-    provider_group_ids = (
-        {group.id for group in provider.groups} if provider.groups else set()
-    )
-    provider_persona_ids = (
-        {p.id for p in provider.personas} if provider.personas else set()
-    )
-
+    provider_group_ids = {g.id for g in (provider.groups or [])}
+    provider_persona_ids = {p.id for p in (provider.personas or [])}
     has_groups = bool(provider_group_ids)
     has_personas = bool(provider_persona_ids)

-    # Both groups AND personas set → AND logic (must satisfy both)
-    if has_groups and has_personas:
-        # Admins bypass group check but still must satisfy persona restrictions
-        user_in_group = is_admin or bool(user_group_ids & provider_group_ids)
-        persona_allowed = persona.id in provider_persona_ids if persona else False
-        return user_in_group and persona_allowed
+    # Persona restrictions are always enforced when set, regardless of is_public
+    if has_personas and not (persona and persona.id in provider_persona_ids):
+        return False
+
+    if provider.is_public:
+        return True

-    # Only groups set → user must be in one of the groups (admins bypass)
     if has_groups:
         return is_admin or bool(user_group_ids & provider_group_ids)

-    # Only personas set → persona must be in allowed list (applies to admins too)
-    if has_personas:
-        return persona.id in provider_persona_ids if persona else False
-
-    # Neither groups nor personas set, and not public → admins can access
-    return is_admin
+    # No groups: either persona-whitelisted (already passed) or admin-only if locked
+    return has_personas or is_admin


 def validate_persona_ids_exist(
@@ -274,10 +271,35 @@ def upsert_llm_provider(
         mc.name for mc in llm_provider_upsert_request.model_configurations
     }

+    # Build a lookup of requested visibility by model name
+    requested_visibility = {
+        mc.name: mc.is_visible
+        for mc in llm_provider_upsert_request.model_configurations
+    }
+
     # Delete removed models
     removed_ids = [
         mc.id for name, mc in existing_by_name.items() if name not in models_to_exist
     ]
+
+    default_model = fetch_default_llm_model(db_session)
+
+    # Prevent removing and hiding the default model
+    if default_model:
+        for name, mc in existing_by_name.items():
+            if mc.id == default_model.id:
+                if default_model.id in removed_ids:
+                    raise ValueError(
+                        f"Cannot remove the default model '{name}'. "
+                        "Please change the default model before removing."
+                    )
+                if not requested_visibility.get(name, True):
+                    raise ValueError(
+                        f"Cannot hide the default model '{name}'. "
+                        "Please change the default model before hiding."
+                    )
+                break
+
     if removed_ids:
         db_session.query(ModelConfiguration).filter(
             ModelConfiguration.id.in_(removed_ids)
@@ -348,9 +370,9 @@ def upsert_llm_provider(
 def sync_model_configurations(
     db_session: Session,
     provider_name: str,
-    models: list[dict],
+    models: list[SyncModelEntry],
 ) -> int:
-    """Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama).
+    """Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama, etc.).

     This inserts NEW models from the source API without overwriting existing ones.
     User preferences (is_visible, max_input_tokens) are preserved for existing models.
@@ -358,7 +380,7 @@ def sync_model_configurations(
     Args:
         db_session: Database session
         provider_name: Name of the LLM provider
-        models: List of model dicts with keys: name, display_name, max_input_tokens, supports_image_input
+        models: List of SyncModelEntry objects describing the fetched models

     Returns:
         Number of new models added
@@ -372,21 +394,20 @@ def sync_model_configurations(
     new_count = 0

     for model in models:
-        model_name = model["name"]
-        if model_name not in existing_names:
+        if model.name not in existing_names:
             # Insert new model with is_visible=False (user must explicitly enable)
             supported_flows = [LLMModelFlowType.CHAT]
-            if model.get("supports_image_input", False):
+            if model.supports_image_input:
                 supported_flows.append(LLMModelFlowType.VISION)

             insert_new_model_configuration__no_commit(
                 db_session=db_session,
                 llm_provider_id=provider.id,
-                model_name=model_name,
+                model_name=model.name,
                 supported_flows=supported_flows,
                 is_visible=False,
-                max_input_tokens=model.get("max_input_tokens"),
-                display_name=model.get("display_name"),
+                max_input_tokens=model.max_input_tokens,
+                display_name=model.display_name,
             )
             new_count += 1

@@ -539,9 +560,9 @@ def fetch_default_model(
 ) -> ModelConfiguration | None:
     model_config = db_session.scalar(
         select(ModelConfiguration)
+        .options(selectinload(ModelConfiguration.llm_provider))
         .join(LLMModelFlow)
         .where(
-            ModelConfiguration.is_visible == True,  # noqa: E712
             LLMModelFlow.llm_model_flow_type == flow_type,
             LLMModelFlow.is_default == True,  # noqa: E712
         )
@@ -817,6 +838,29 @@ def sync_auto_mode_models(
             )
             changes += 1

+    # Update the default if this provider currently holds the global CHAT default.
+    # We flush (but don't commit) so that _update_default_model can see the new
+    # model rows, then commit everything atomically to avoid a window where the
+    # old default is invisible but still pointed-to.
+    db_session.flush()
+
+    recommended_default = llm_recommendations.get_default_model(provider.provider)
+    if recommended_default:
+        current_default = fetch_default_llm_model(db_session)
+
+        if (
+            current_default
+            and current_default.llm_provider_id == provider.id
+            and current_default.name != recommended_default.name
+        ):
+            _update_default_model__no_commit(
+                db_session=db_session,
+                provider_id=provider.id,
+                model=recommended_default.name,
+                flow_type=LLMModelFlowType.CHAT,
+            )
+            changes += 1
+
     db_session.commit()
     return changes

@@ -948,7 +992,7 @@ def update_model_configuration__no_commit(
     db_session.flush()


-def _update_default_model(
+def _update_default_model__no_commit(
     db_session: Session,
     provider_id: int,
     model: str,
@@ -986,6 +1030,14 @@ def _update_default_model(
     new_default.is_default = True
     model_config.is_visible = True

+
+def _update_default_model(
+    db_session: Session,
+    provider_id: int,
+    model: str,
+    flow_type: LLMModelFlowType,
+) -> None:
+    _update_default_model__no_commit(db_session, provider_id, model, flow_type)
     db_session.commit()
diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py
index 5987385bfd3..042bdc25218 100644
--- a/backend/onyx/db/models.py
+++ b/backend/onyx/db/models.py
@@ -25,6 +25,7 @@
 from sqlalchemy import Enum
 from sqlalchemy import Float
 from sqlalchemy import ForeignKey
+from sqlalchemy import ForeignKeyConstraint
 from sqlalchemy import func
 from sqlalchemy import Index
 from sqlalchemy import Integer
@@ -36,9 +37,11 @@
 from sqlalchemy import text
 from sqlalchemy import UniqueConstraint
 from sqlalchemy.dialects import postgresql
+from sqlalchemy import event
 from sqlalchemy.engine.interfaces import Dialect
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import Mapper
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 from sqlalchemy.types import LargeBinary
@@ -117,10 +120,52 @@ class Base(DeclarativeBase):
     __abstract__ = True


-class EncryptedString(TypeDecorator):
+class _EncryptedBase(TypeDecorator):
+    """Base for encrypted column types that wrap values in SensitiveValue."""
+
     impl = LargeBinary
-    # This type's behavior is fully deterministic and doesn't depend on any external factors.
     cache_ok = True
+    _is_json: bool = False
+
+    def wrap_raw(self, value: Any) -> SensitiveValue:
+        """Encrypt a raw value and wrap it in SensitiveValue.
+
+        Called by the attribute set event so the Python-side type is always
+        SensitiveValue, regardless of whether the value was loaded from the DB
+        or assigned in application code.
+        """
+        if self._is_json:
+            if not isinstance(value, dict):
+                raise TypeError(
+                    f"EncryptedJson column expected dict, got {type(value).__name__}"
+                )
+            raw_str = json.dumps(value)
+        else:
+            if not isinstance(value, str):
+                raise TypeError(
+                    f"EncryptedString column expected str, got {type(value).__name__}"
+                )
+            raw_str = value
+        return SensitiveValue(
+            encrypted_bytes=encrypt_string_to_bytes(raw_str),
+            decrypt_fn=decrypt_bytes_to_string,
+            is_json=self._is_json,
+        )
+
+    def compare_values(self, x: Any, y: Any) -> bool:
+        if x is None or y is None:
+            return x == y
+        if isinstance(x, SensitiveValue):
+            x = x.get_value(apply_mask=False)
+        if isinstance(y, SensitiveValue):
+            y = y.get_value(apply_mask=False)
+        return x == y
+
+
+class EncryptedString(_EncryptedBase):
+    # Must redeclare cache_ok in this child class since we explicitly redeclare _is_json
+    cache_ok = True
+    _is_json: bool = False

     def process_bind_param(
         self, value: str | SensitiveValue[str] | None, dialect: Dialect  # noqa: ARG002
@@ -144,20 +189,10 @@ def process_result_value(
             )
         return None

-    def compare_values(self, x: Any, y: Any) -> bool:
-        if x is None or y is None:
-            return x == y
-        if isinstance(x, SensitiveValue):
-            x = x.get_value(apply_mask=False)
-        if isinstance(y, SensitiveValue):
-            y = y.get_value(apply_mask=False)
-        return x == y
-

-class EncryptedJson(TypeDecorator):
-    impl = LargeBinary
-    # This type's behavior is fully deterministic and doesn't depend on any external factors.
+class EncryptedJson(_EncryptedBase):
     cache_ok = True
+    _is_json: bool = True

     def process_bind_param(
         self,
@@ -165,9 +200,7 @@ def process_bind_param(
         dialect: Dialect,  # noqa: ARG002
     ) -> bytes | None:
         if value is not None:
-            # Handle both raw dicts and SensitiveValue wrappers
             if isinstance(value, SensitiveValue):
-                # Get raw value for storage
                 value = value.get_value(apply_mask=False)
             json_str = json.dumps(value)
             return encrypt_string_to_bytes(json_str)
@@ -184,14 +217,40 @@ def process_result_value(
             )
         return None

-    def compare_values(self, x: Any, y: Any) -> bool:
-        if x is None or y is None:
-            return x == y
-        if isinstance(x, SensitiveValue):
-            x = x.get_value(apply_mask=False)
-        if isinstance(y, SensitiveValue):
-            y = y.get_value(apply_mask=False)
-        return x == y
+
+_REGISTERED_ATTRS: set[str] = set()
+
+
+@event.listens_for(Mapper, "mapper_configured")
+def _register_sensitive_value_set_events(
+    mapper: Mapper,
+    class_: type,
+) -> None:
+    """Auto-wrap raw values in SensitiveValue when assigned to encrypted columns."""
+    for prop in mapper.column_attrs:
+        for col in prop.columns:
+            if isinstance(col.type, _EncryptedBase):
+                col_type = col.type
+                attr = getattr(class_, prop.key)
+
+                # Guard against double-registration (e.g. if mapper is
+                # re-configured in test setups)
+                attr_key = f"{class_.__qualname__}.{prop.key}"
+                if attr_key in _REGISTERED_ATTRS:
+                    continue
+                _REGISTERED_ATTRS.add(attr_key)
+
+                @event.listens_for(attr, "set", retval=True)
+                def _wrap_value(
+                    target: Any,  # noqa: ARG001
+                    value: Any,
+                    oldvalue: Any,  # noqa: ARG001
+                    initiator: Any,  # noqa: ARG001
+                    _col_type: _EncryptedBase = col_type,
+                ) -> Any:
+                    if value is not None and not isinstance(value, SensitiveValue):
+                        return _col_type.wrap_raw(value)
+                    return value


 class NullFilteredString(TypeDecorator):
@@ -280,6 +339,16 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
         TIMESTAMPAware(timezone=True), nullable=True
     )

+    created_at: Mapped[datetime.datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now(), nullable=False
+    )
+    updated_at: Mapped[datetime.datetime] = mapped_column(
+        DateTime(timezone=True),
+        server_default=func.now(),
+        onupdate=func.now(),
+        nullable=False,
+    )
+
     default_model: Mapped[str] = mapped_column(Text, nullable=True)
     # organized in typical structured fashion
     # formatted as `displayName__provider__modelName`
@@ -2370,6 +2439,38 @@ class SyncRecord(Base):
     )


+class HierarchyNodeByConnectorCredentialPair(Base):
+    """Tracks which cc_pairs reference each hierarchy node.
+
+    During pruning, stale entries are removed for the current cc_pair.
+    Hierarchy nodes with zero remaining entries are then deleted.
+    """
+
+    __tablename__ = "hierarchy_node_by_connector_credential_pair"
+
+    hierarchy_node_id: Mapped[int] = mapped_column(
+        ForeignKey("hierarchy_node.id", ondelete="CASCADE"), primary_key=True
+    )
+    connector_id: Mapped[int] = mapped_column(primary_key=True)
+    credential_id: Mapped[int] = mapped_column(primary_key=True)
+
+    __table_args__ = (
+        ForeignKeyConstraint(
+            ["connector_id", "credential_id"],
+            [
+                "connector_credential_pair.connector_id",
+                "connector_credential_pair.credential_id",
+            ],
+            ondelete="CASCADE",
+        ),
+        Index(
+            "ix_hierarchy_node_cc_pair_connector_credential",
+            "connector_id",
+            "credential_id",
+        ),
+    )
+
+
 class DocumentByConnectorCredentialPair(Base):
     """Represents an indexing of a document by a specific connector / credential pair"""

@@ -4926,7 +5027,9 @@ class ScimUserMapping(Base):
     __tablename__ = "scim_user_mapping"

     id: Mapped[int] = mapped_column(Integer, primary_key=True)
-    external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
+    external_id: Mapped[str | None] = mapped_column(
+        String, unique=True, index=True, nullable=True
+    )
     user_id: Mapped[UUID] = mapped_column(
         ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
     )
@@ -4983,3 +5086,25 @@ class CodeInterpreterServer(Base):

     id: Mapped[int] = mapped_column(Integer, primary_key=True)
     server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
+
+
+class CacheStore(Base):
+    """Key-value cache table used by ``PostgresCacheBackend``.
+
+    Replaces Redis for simple KV caching, locks, and list operations
+    when ``CACHE_BACKEND=postgres`` (NO_VECTOR_DB deployments).
+
+    Intentionally separate from ``KVStore``:
+    - Stores raw bytes (LargeBinary) vs JSONB, matching Redis semantics.
+    - Has ``expires_at`` for TTL; rows are periodically garbage-collected.
+    - Holds ephemeral data (tokens, stop signals, lock state) not
+      persistent application config, so cleanup can be aggressive.
+    """
+
+    __tablename__ = "cache_store"
+
+    key: Mapped[str] = mapped_column(String, primary_key=True)
+    value: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
+    expires_at: Mapped[datetime.datetime | None] = mapped_column(
+        DateTime(timezone=True), nullable=True
+    )
diff --git a/backend/onyx/db/persona.py b/backend/onyx/db/persona.py
index a35c0c6d26f..3c396b563eb 100644
--- a/backend/onyx/db/persona.py
+++ b/backend/onyx/db/persona.py
@@ -205,7 +205,9 @@ def update_persona_access(

     NOTE: Callers are responsible for committing."""

+    needs_sync = False
     if is_public is not None:
+        needs_sync = True
         persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
         if persona:
             persona.is_public = is_public
@@ -213,6 +215,7 @@ def update_persona_access(
     # NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
     # and a non-empty list means "replace with these shares".
     if user_ids is not None:
+        needs_sync = True
         db_session.query(Persona__User).filter(
             Persona__User.persona_id == persona_id
         ).delete(synchronize_session="fetch")
@@ -233,6 +236,7 @@ def update_persona_access(
     # MIT doesn't support group-based sharing, so we allow clearing (no-op since
     # there shouldn't be any) but raise an error if trying to add actual groups.
     if group_ids is not None:
+        needs_sync = True
         db_session.query(Persona__UserGroup).filter(
             Persona__UserGroup.persona_id == persona_id
         ).delete(synchronize_session="fetch")
@@ -240,6 +244,10 @@ def update_persona_access(
         if group_ids:
             raise NotImplementedError("Onyx MIT does not support group-based sharing")

+    # When sharing changes, user file ACLs need to be updated in the vector DB
+    if needs_sync:
+        mark_persona_user_files_for_sync(persona_id, db_session)
+

 def create_update_persona(
     persona_id: int | None,
@@ -851,6 +859,24 @@ def update_personas_display_priority(
     db_session.commit()


+def mark_persona_user_files_for_sync(
+    persona_id: int,
+    db_session: Session,
+) -> None:
+    """When persona sharing changes, mark all of its user files for sync
+    so that their ACLs get updated in the vector DB."""
+    persona = (
+        db_session.query(Persona)
+        .options(selectinload(Persona.user_files))
+        .filter(Persona.id == persona_id)
+        .first()
+    )
+    if not persona:
+        return
+    file_ids = [uf.id for uf in persona.user_files]
+    _mark_files_need_persona_sync(db_session, file_ids)
+
+
 def _mark_files_need_persona_sync(
     db_session: Session,
     user_file_ids: list[UUID],
diff --git a/backend/onyx/db/projects.py b/backend/onyx/db/projects.py
index 428e9cb56cf..ce74a650c24 100644
--- a/backend/onyx/db/projects.py
+++ b/backend/onyx/db/projects.py
@@ -9,8 +9,9 @@
 from pydantic import ConfigDict
 from sqlalchemy import func
 from sqlalchemy.orm import Session
+from starlette.background import BackgroundTasks

-from onyx.background.celery.versioned_apps.client import app as client_app
+from onyx.configs.app_configs import DISABLE_VECTOR_DB
 from onyx.configs.constants import FileOrigin
 from onyx.configs.constants import OnyxCeleryPriority
 from onyx.configs.constants import OnyxCeleryQueues
@@ -51,7 +52,7 @@ def create_user_files(
 ) -> CategorizedFilesResult:

     # Categorize the files
-    categorized_files = categorize_uploaded_files(files)
+    categorized_files = categorize_uploaded_files(files, db_session)
     # NOTE: At the moment, zip metadata is not used for user files.
     # Should revisit to decide whether this should be a feature.
     upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)
@@ -105,8 +106,8 @@ def upload_files_to_user_files_with_indexing(
     user: User,
     temp_id_map: dict[str, str] | None,
     db_session: Session,
+    background_tasks: BackgroundTasks | None = None,
 ) -> CategorizedFilesResult:
-    # Validate project ownership if a project_id is provided
     if project_id is not None and user is not None:
         if not check_project_ownership(project_id, user.id, db_session):
             raise HTTPException(status_code=404, detail="Project not found")
@@ -127,16 +128,27 @@ def upload_files_to_user_files_with_indexing(
             logger.warning(
                 f"File {rejected_file.filename} rejected for {rejected_file.reason}"
             )
-    for user_file in user_files:
-        task = client_app.send_task(
-            OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
-            kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
-            queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
-            priority=OnyxCeleryPriority.HIGH,
-        )
-        logger.info(
-            f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}"
-        )
+
+    if DISABLE_VECTOR_DB and background_tasks is not None:
+        from onyx.background.task_utils import drain_processing_loop
+
+        background_tasks.add_task(drain_processing_loop, tenant_id)
+        for user_file in user_files:
+            logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
+    else:
+        from onyx.background.celery.versioned_apps.client import app as client_app
+
+        for user_file in user_files:
+            task = client_app.send_task(
+                OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
+                kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
+                queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
+                priority=OnyxCeleryPriority.HIGH,
+            )
+            logger.info(
+                f"Triggered indexing for user_file_id={user_file.id} "
+                f"with task_id={task.id}"
+            )

     return CategorizedFilesResult(
         user_files=user_files,
diff --git
a/backend/onyx/db/rotate_encryption_key.py b/backend/onyx/db/rotate_encryption_key.py new file mode 100644 index 00000000000..ddb99fc9c43 --- /dev/null +++ b/backend/onyx/db/rotate_encryption_key.py @@ -0,0 +1,161 @@ +"""Rotate encryption key for all encrypted columns. + +Dynamically discovers all columns using EncryptedString / EncryptedJson, +decrypts each value with the old key, and re-encrypts with the current +ENCRYPTION_KEY_SECRET. + +The operation is idempotent: rows already encrypted with the current key +are skipped. Commits are made in batches so a crash mid-rotation can be +safely resumed by re-running. +""" + +import json +from typing import Any + +from sqlalchemy import LargeBinary +from sqlalchemy import select +from sqlalchemy import update +from sqlalchemy.orm import Session + +from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET +from onyx.db.models import Base +from onyx.db.models import EncryptedJson +from onyx.db.models import EncryptedString +from onyx.utils.encryption import decrypt_bytes_to_string +from onyx.utils.logger import setup_logger +from onyx.utils.variable_functionality import global_version + +logger = setup_logger() + +_BATCH_SIZE = 500 + + +def _can_decrypt_with_current_key(data: bytes) -> bool: + """Check if data is already encrypted with the current key. + + Passes the key explicitly so the fallback-to-raw-decode path in + _decrypt_bytes is NOT triggered — a clean success/failure signal. + """ + try: + decrypt_bytes_to_string(data, key=ENCRYPTION_KEY_SECRET) + return True + except Exception: + return False + + +def _discover_encrypted_columns() -> list[tuple[type, str, list[str], bool]]: + """Walk all ORM models and find columns using EncryptedString/EncryptedJson. + + Returns list of (ModelClass, column_attr_name, [pk_attr_names], is_json). 
+ """ + results: list[tuple[type, str, list[str], bool]] = [] + + for mapper in Base.registry.mappers: + model_cls = mapper.class_ + pk_names = [col.key for col in mapper.primary_key] + + for prop in mapper.column_attrs: + for col in prop.columns: + if isinstance(col.type, EncryptedJson): + results.append((model_cls, prop.key, pk_names, True)) + elif isinstance(col.type, EncryptedString): + results.append((model_cls, prop.key, pk_names, False)) + + return results + + +def rotate_encryption_key( + db_session: Session, + old_key: str | None, + dry_run: bool = False, +) -> dict[str, int]: + """Decrypt all encrypted columns with old_key and re-encrypt with the current key. + + Args: + db_session: Active database session. + old_key: The previous encryption key. Pass None or "" if values were + not previously encrypted with a key. + dry_run: If True, count rows that need rotation without modifying data. + + Returns: + Dict of "table.column" -> number of rows re-encrypted (or would be). + + Commits every _BATCH_SIZE rows so that locks are held briefly and progress + is preserved on crash. Already-rotated rows are detected and skipped, + making the operation safe to re-run. + """ + if not global_version.is_ee_version(): + raise RuntimeError("EE mode is not enabled — rotation requires EE encryption.") + + if not ENCRYPTION_KEY_SECRET: + raise RuntimeError( + "ENCRYPTION_KEY_SECRET is not set — cannot rotate. " + "Set the target encryption key in the environment before running." 
+ ) + + encrypted_columns = _discover_encrypted_columns() + totals: dict[str, int] = {} + + for model_cls, col_name, pk_names, is_json in encrypted_columns: + table_name: str = model_cls.__tablename__ # type: ignore[attr-defined] + col_attr = getattr(model_cls, col_name) + pk_attrs = [getattr(model_cls, pk) for pk in pk_names] + + # Read raw bytes directly, bypassing the TypeDecorator + raw_col = col_attr.property.columns[0] + + stmt = select(*pk_attrs, raw_col.cast(LargeBinary)).where(col_attr.is_not(None)) + rows = db_session.execute(stmt).all() + + reencrypted = 0 + batch_pending = 0 + for row in rows: + raw_bytes: bytes | None = row[-1] + if raw_bytes is None: + continue + + if _can_decrypt_with_current_key(raw_bytes): + continue + + try: + if not old_key: + decrypted_str = raw_bytes.decode("utf-8") + else: + decrypted_str = decrypt_bytes_to_string(raw_bytes, key=old_key) + + # For EncryptedJson, parse back to dict so the TypeDecorator + # can json.dumps() it cleanly (avoids double-encoding). 
+ value: Any = json.loads(decrypted_str) if is_json else decrypted_str + except (ValueError, UnicodeDecodeError) as e: + pk_vals = [row[i] for i in range(len(pk_names))] + logger.warning( + f"Could not decrypt/parse {table_name}.{col_name} " + f"row {pk_vals} — skipping: {e}" + ) + continue + + if not dry_run: + pk_filters = [pk_attr == row[i] for i, pk_attr in enumerate(pk_attrs)] + update_stmt = ( + update(model_cls).where(*pk_filters).values({col_name: value}) + ) + db_session.execute(update_stmt) + batch_pending += 1 + + if batch_pending >= _BATCH_SIZE: + db_session.commit() + batch_pending = 0 + reencrypted += 1 + + # Flush remaining rows in this column + if batch_pending > 0: + db_session.commit() + + if reencrypted > 0: + totals[f"{table_name}.{col_name}"] = reencrypted + logger.info( + f"{'[DRY RUN] Would re-encrypt' if dry_run else 'Re-encrypted'} " + f"{reencrypted} value(s) in {table_name}.{col_name}" + ) + + return totals diff --git a/backend/onyx/db/search_settings.py b/backend/onyx/db/search_settings.py index ed6477205e4..b16f517eec2 100644 --- a/backend/onyx/db/search_settings.py +++ b/backend/onyx/db/search_settings.py @@ -129,7 +129,7 @@ def get_current_search_settings(db_session: Session) -> SearchSettings: latest_settings = result.scalars().first() if not latest_settings: - raise RuntimeError("No search settings specified, DB is not in a valid state") + raise RuntimeError("No search settings specified; DB is not in a valid state.") return latest_settings diff --git a/backend/onyx/db/tools.py b/backend/onyx/db/tools.py index af2606525e8..e51409bb49f 100644 --- a/backend/onyx/db/tools.py +++ b/backend/onyx/db/tools.py @@ -13,12 +13,15 @@ from onyx.db.constants import UnsetType from onyx.db.enums import MCPServerStatus from onyx.db.models import MCPServer +from onyx.db.models import OAuthConfig from onyx.db.models import Tool from onyx.db.models import ToolCall from onyx.server.features.tool.models import Header from onyx.tools.built_in_tools import 
BUILT_IN_TOOL_TYPES from onyx.utils.headers import HeaderItemDict from onyx.utils.logger import setup_logger +from onyx.utils.postgres_sanitization import sanitize_json_like +from onyx.utils.postgres_sanitization import sanitize_string if TYPE_CHECKING: pass @@ -159,10 +162,26 @@ def update_tool( ] if passthrough_auth is not None: tool.passthrough_auth = passthrough_auth + old_oauth_config_id = tool.oauth_config_id if not isinstance(oauth_config_id, UnsetType): tool.oauth_config_id = oauth_config_id - db_session.commit() + db_session.flush() + # Clean up orphaned OAuthConfig if the oauth_config_id was changed + if ( + old_oauth_config_id is not None + and not isinstance(oauth_config_id, UnsetType) + and old_oauth_config_id != oauth_config_id + ): + other_tools = db_session.scalars( + select(Tool).where(Tool.oauth_config_id == old_oauth_config_id) + ).all() + if not other_tools: + oauth_config = db_session.get(OAuthConfig, old_oauth_config_id) + if oauth_config: + db_session.delete(oauth_config) + + db_session.commit() return tool @@ -171,8 +190,21 @@ def delete_tool__no_commit(tool_id: int, db_session: Session) -> None: if tool is None: raise ValueError(f"Tool with ID {tool_id} does not exist") + oauth_config_id = tool.oauth_config_id + db_session.delete(tool) - db_session.flush() # Don't commit yet, let caller decide when to commit + db_session.flush() + + # Clean up orphaned OAuthConfig if no other tools reference it + if oauth_config_id is not None: + other_tools = db_session.scalars( + select(Tool).where(Tool.oauth_config_id == oauth_config_id) + ).all() + if not other_tools: + oauth_config = db_session.get(OAuthConfig, oauth_config_id) + if oauth_config: + db_session.delete(oauth_config) + db_session.flush() def get_builtin_tool( @@ -256,11 +288,13 @@ def create_tool_call_no_commit( tab_index=tab_index, tool_id=tool_id, tool_call_id=tool_call_id, - reasoning_tokens=reasoning_tokens, - tool_call_arguments=tool_call_arguments, - 
tool_call_response=tool_call_response, + reasoning_tokens=( + sanitize_string(reasoning_tokens) if reasoning_tokens else reasoning_tokens + ), + tool_call_arguments=sanitize_json_like(tool_call_arguments), + tool_call_response=sanitize_json_like(tool_call_response), tool_call_tokens=tool_call_tokens, - generated_images=generated_images, + generated_images=sanitize_json_like(generated_images), ) db_session.add(tool_call) diff --git a/backend/onyx/db/user_file.py b/backend/onyx/db/user_file.py index 5e4800a2149..1de96ff917c 100644 --- a/backend/onyx/db/user_file.py +++ b/backend/onyx/db/user_file.py @@ -3,9 +3,11 @@ from sqlalchemy import func from sqlalchemy import select +from sqlalchemy.orm import joinedload from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session +from onyx.db.models import Persona from onyx.db.models import Project__UserFile from onyx.db.models import UserFile @@ -118,3 +120,31 @@ def get_file_ids_by_user_file_ids( ) -> list[str]: user_files = db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).all() return [user_file.file_id for user_file in user_files] + + +def fetch_user_files_with_access_relationships( + user_file_ids: list[str], + db_session: Session, + eager_load_groups: bool = False, +) -> list[UserFile]: + """Fetch user files with the owner and assistant relationships + eagerly loaded (needed for computing access control). 
+ + When eager_load_groups is True, Persona.groups is also loaded so that + callers can extract user-group names without a second DB round-trip.""" + persona_sub_options = [ + selectinload(Persona.users), + selectinload(Persona.user), + ] + if eager_load_groups: + persona_sub_options.append(selectinload(Persona.groups)) + + return ( + db_session.query(UserFile) + .options( + joinedload(UserFile.user), + selectinload(UserFile.assistants).options(*persona_sub_options), + ) + .filter(UserFile.id.in_(user_file_ids)) + .all() + ) diff --git a/backend/onyx/db/users.py b/backend/onyx/db/users.py index 17851bc542e..8f29737c6ef 100644 --- a/backend/onyx/db/users.py +++ b/backend/onyx/db/users.py @@ -4,6 +4,7 @@ from fastapi import HTTPException from fastapi_users.password import PasswordHelper +from sqlalchemy import case from sqlalchemy import func from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -11,6 +12,7 @@ from sqlalchemy.sql import expression from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import KeyedColumnElement +from sqlalchemy.sql.expression import or_ from onyx.auth.invited_users import remove_user_from_invited_users from onyx.auth.schemas import UserRole @@ -24,6 +26,7 @@ from onyx.db.models import SamlAccount from onyx.db.models import User from onyx.db.models import User__UserGroup +from onyx.db.models import UserGroup from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop @@ -162,7 +165,13 @@ def _get_accepted_user_where_clause( where_clause.append(User.role != UserRole.EXT_PERM_USER) if email_filter_string is not None: - where_clause.append(email_col.ilike(f"%{email_filter_string}%")) + personal_name_col: KeyedColumnElement[Any] = User.__table__.c.personal_name + where_clause.append( + or_( + email_col.ilike(f"%{email_filter_string}%"), + personal_name_col.ilike(f"%{email_filter_string}%"), + ) + ) if roles_filter: where_clause.append(User.role.in_(roles_filter)) @@ -173,6 
+182,21 @@ def _get_accepted_user_where_clause( return where_clause +def get_all_accepted_users( + db_session: Session, + include_external: bool = False, +) -> Sequence[User]: + """Returns all accepted users without pagination. + Uses the same filtering as the paginated endpoint but without + search, role, or active filters.""" + stmt = select(User) + where_clause = _get_accepted_user_where_clause( + include_external=include_external, + ) + stmt = stmt.where(*where_clause).order_by(User.email) + return db_session.scalars(stmt).unique().all() + + def get_page_of_filtered_users( db_session: Session, page_size: int, @@ -218,6 +242,41 @@ def get_total_filtered_users_count( return db_session.scalar(total_count_stmt) or 0 +def get_user_counts_by_role_and_status( + db_session: Session, +) -> dict[str, dict[str, int]]: + """Returns user counts grouped by role and by active/inactive status. + + Excludes API key users, anonymous users, and no-auth placeholder users. + Uses a single query with conditional aggregation. 
+ """ + base_where = _get_accepted_user_where_clause() + role_col = User.__table__.c.role + is_active_col = User.__table__.c.is_active + + stmt = ( + select( + role_col, + func.count().label("total"), + func.sum(case((is_active_col.is_(True), 1), else_=0)).label("active"), + func.sum(case((is_active_col.is_(False), 1), else_=0)).label("inactive"), + ) + .where(*base_where) + .group_by(role_col) + ) + + role_counts: dict[str, int] = {} + status_counts: dict[str, int] = {"active": 0, "inactive": 0} + + for role_val, total, active, inactive in db_session.execute(stmt).all(): + key = role_val.value if hasattr(role_val, "value") else str(role_val) + role_counts[key] = total + status_counts["active"] += active or 0 + status_counts["inactive"] += inactive or 0 + + return {"role_counts": role_counts, "status_counts": status_counts} + + def get_user_by_email(email: str, db_session: Session) -> User | None: user = ( db_session.query(User) @@ -294,24 +353,23 @@ def batch_add_ext_perm_user_if_not_exists( lower_emails = [email.lower() for email in emails] found_users, missing_lower_emails = _get_users_by_emails(db_session, lower_emails) - new_users: list[User] = [] + # Use savepoints (begin_nested) so that a failed insert only rolls back + # that single user, not the entire transaction. A plain rollback() would + # discard all previously flushed users in the same transaction. + # We also avoid add_all() because SQLAlchemy 2.0's insertmanyvalues + # batch path hits a UUID sentinel mismatch with server_default columns. 
for email in missing_lower_emails: - new_users.append(_generate_ext_permissioned_user(email=email)) - - try: - db_session.add_all(new_users) - db_session.commit() - except IntegrityError: - db_session.rollback() - if not continue_on_error: - raise - for user in new_users: - try: - db_session.add(user) - db_session.commit() - except IntegrityError: - db_session.rollback() - continue + user = _generate_ext_permissioned_user(email=email) + savepoint = db_session.begin_nested() + try: + db_session.add(user) + savepoint.commit() + except IntegrityError: + savepoint.rollback() + if not continue_on_error: + raise + + db_session.commit() # Fetch all users again to ensure we have the most up-to-date list all_users, _ = _get_users_by_emails(db_session, lower_emails) return all_users @@ -358,3 +416,28 @@ def delete_user_from_db( # NOTE: edge case may exist with race conditions # with this `invited user` scheme generally. remove_user_from_invited_users(user_to_delete.email) + + +def batch_get_user_groups( + db_session: Session, + user_ids: list[UUID], +) -> dict[UUID, list[tuple[int, str]]]: + """Fetch group memberships for a batch of users in a single query. 
+ Returns a mapping of user_id -> list of (group_id, group_name) tuples.""" + if not user_ids: + return {} + + rows = db_session.execute( + select( + User__UserGroup.user_id, + UserGroup.id, + UserGroup.name, + ) + .join(UserGroup, UserGroup.id == User__UserGroup.user_group_id) + .where(User__UserGroup.user_id.in_(user_ids)) + ).all() + + result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids} + for user_id, group_id, group_name in rows: + result[user_id].append((group_id, group_name)) + return result diff --git a/backend/onyx/document_index/FILTER_SEMANTICS.md b/backend/onyx/document_index/FILTER_SEMANTICS.md new file mode 100644 index 00000000000..5b0cc763e8f --- /dev/null +++ b/backend/onyx/document_index/FILTER_SEMANTICS.md @@ -0,0 +1,103 @@ +# Vector DB Filter Semantics + +How `IndexFilters` fields combine into the final query filter. Applies to both Vespa and OpenSearch. + +## Filter categories + +| Category | Fields | Join logic | +|---|---|---| +| **Visibility** | `hidden` | Always applied (unless `include_hidden`) | +| **Tenant** | `tenant_id` | AND (multi-tenant only) | +| **ACL** | `access_control_list` | OR within, AND with rest | +| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest | +| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest | +| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists | + +## How filters combine + +All categories are AND'd together. Within the knowledge scope category, individual filters are OR'd. + +``` +NOT hidden +AND tenant = T -- if multi-tenant +AND (acl contains A1 OR acl contains A2) +AND (source_type = S1 OR ...) -- if set +AND (tag = T1 OR ...) 
-- if set
+AND <knowledge scope> -- see below
+AND time >= cutoff -- if set
+```
+
+## Knowledge scope rules
+
+The knowledge scope filter controls **what knowledge an assistant can access**.
+
+### No explicit knowledge attached
+
+When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:
+
+- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
+- `project_id` and `persona_id` are ignored — they never restrict on their own.
+
+### One explicit knowledge type
+
+```
+-- Only document sets
+AND (document_sets contains "Engineering" OR document_sets contains "Legal")
+
+-- Only user files
+AND (document_id = "uuid-1" OR document_id = "uuid-2")
+```
+
+### Multiple explicit knowledge types (OR'd)
+
+```
+-- Document sets + user files
+AND (
+  document_sets contains "Engineering"
+  OR document_id = "uuid-1"
+)
+```
+
+### Explicit knowledge + overflowing user files
+
+When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:
+
+```
+-- Document sets + persona user files overflowed
+AND (
+  document_sets contains "Engineering"
+  OR personas contains 42
+)
+
+-- User files + project files overflowed
+AND (
+  document_id = "uuid-1"
+  OR user_project contains 7
+)
+```
+
+### Only project_id or persona_id (no explicit knowledge)
+
+No knowledge scope filter. The assistant searches everything.
+
+```
+-- Just ACL, no restriction
+NOT hidden
+AND (acl contains ...)
+``` + +## Field reference + +| Filter field | Vespa field | Vespa type | Purpose | +|---|---|---|---| +| `document_set` | `document_sets` | `weightedset` | Connector doc sets attached to assistant | +| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant | +| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) | +| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array` | Folder/space nodes (OpenSearch only) | +| `project_id` | `user_project` | `array` | Project tag for overflowing user files | +| `persona_id` | `personas` | `array` | Persona tag for overflowing user files | +| `access_control_list` | `access_control_list` | `weightedset` | ACL entries for the requesting user | +| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) | +| `tags` | `metadata_list` | `array` | Document metadata tags | +| `time_cutoff` | `doc_updated_at` | `long` | Minimum document update timestamp | +| `tenant_id` | `tenant_id` | `string` | Tenant isolation (multi-tenant) | diff --git a/backend/onyx/document_index/document_index_utils.py b/backend/onyx/document_index/document_index_utils.py index dc78921412d..f3b8489e79e 100644 --- a/backend/onyx/document_index/document_index_utils.py +++ b/backend/onyx/document_index/document_index_utils.py @@ -32,9 +32,6 @@ def get_multipass_config(search_settings: SearchSettings) -> MultipassConfig: Determines whether to enable multipass and large chunks by examining the current search settings and the embedder configuration. 
""" - if not search_settings: - return MultipassConfig(multipass_indexing=False, enable_large_chunks=False) - multipass = should_use_multipass(search_settings) enable_large_chunks = SearchSettings.can_use_large_chunks( multipass, search_settings.model_name, search_settings.provider_type diff --git a/backend/onyx/document_index/factory.py b/backend/onyx/document_index/factory.py index 1bf6a03f026..a7db8e54cca 100644 --- a/backend/onyx/document_index/factory.py +++ b/backend/onyx/document_index/factory.py @@ -26,11 +26,10 @@ def get_default_document_index( To be used for retrieval only. Indexing should be done through both indices until Vespa is deprecated. - Pre-existing docstring for this function, although secondary indices are not - currently supported: Primary index is the index that is used for querying/updating etc. Secondary index is for when both the currently used index and the upcoming index both - need to be updated, updates are applied to both indices. + need to be updated. Updates are applied to both indices. + WARNING: In that case, get_all_document_indices should be used. 
""" if DISABLE_VECTOR_DB: return DisabledDocumentIndex( @@ -51,11 +50,26 @@ def get_default_document_index( opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session) if opensearch_retrieval_enabled: indexing_setting = IndexingSetting.from_db_model(search_settings) + secondary_indexing_setting = ( + IndexingSetting.from_db_model(secondary_search_settings) + if secondary_search_settings + else None + ) return OpenSearchOldDocumentIndex( index_name=search_settings.index_name, embedding_dim=indexing_setting.final_embedding_dim, embedding_precision=indexing_setting.embedding_precision, secondary_index_name=secondary_index_name, + secondary_embedding_dim=( + secondary_indexing_setting.final_embedding_dim + if secondary_indexing_setting + else None + ), + secondary_embedding_precision=( + secondary_indexing_setting.embedding_precision + if secondary_indexing_setting + else None + ), large_chunks_enabled=search_settings.large_chunks_enabled, secondary_large_chunks_enabled=secondary_large_chunks_enabled, multitenant=MULTI_TENANT, @@ -86,8 +100,7 @@ def get_all_document_indices( Used for indexing only. Until Vespa is deprecated we will index into both document indices. Retrieval is done through only one index however. - Large chunks and secondary indices are not currently supported so we - hardcode appropriate values. + Large chunks are not currently supported so we hardcode appropriate values. NOTE: Make sure the Vespa index object is returned first. 
In the rare event that there is some conflict between indexing and the migration task, it is @@ -123,13 +136,36 @@ def get_all_document_indices( opensearch_document_index: OpenSearchOldDocumentIndex | None = None if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX: indexing_setting = IndexingSetting.from_db_model(search_settings) + secondary_indexing_setting = ( + IndexingSetting.from_db_model(secondary_search_settings) + if secondary_search_settings + else None + ) opensearch_document_index = OpenSearchOldDocumentIndex( index_name=search_settings.index_name, embedding_dim=indexing_setting.final_embedding_dim, embedding_precision=indexing_setting.embedding_precision, - secondary_index_name=None, - large_chunks_enabled=False, - secondary_large_chunks_enabled=None, + secondary_index_name=( + secondary_search_settings.index_name + if secondary_search_settings + else None + ), + secondary_embedding_dim=( + secondary_indexing_setting.final_embedding_dim + if secondary_indexing_setting + else None + ), + secondary_embedding_precision=( + secondary_indexing_setting.embedding_precision + if secondary_indexing_setting + else None + ), + large_chunks_enabled=search_settings.large_chunks_enabled, + secondary_large_chunks_enabled=( + secondary_search_settings.large_chunks_enabled + if secondary_search_settings + else None + ), multitenant=MULTI_TENANT, httpx_client=httpx_client, ) diff --git a/backend/onyx/document_index/opensearch/client.py b/backend/onyx/document_index/opensearch/client.py index 7fe80358796..351bbe89278 100644 --- a/backend/onyx/document_index/opensearch/client.py +++ b/backend/onyx/document_index/opensearch/client.py @@ -61,6 +61,25 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]): explanation: dict[str, Any] | None = None +class IndexInfo(BaseModel): + """ + Represents information about an OpenSearch index. 
+ """ + + model_config = {"frozen": True} + + name: str + health: str + status: str + num_primary_shards: str + num_replica_shards: str + docs_count: str + docs_deleted: str + created_at: str + total_size: str + primary_shards_size: str + + def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]: """Recursively replaces vectors in the body with their length. @@ -159,8 +178,8 @@ def create_search_pipeline( Raises: Exception: There was an error creating the search pipeline. """ - result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body) - if not result.get("acknowledged", False): + response = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body) + if not response.get("acknowledged", False): raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.") @log_function_time(print_only=True, debug_only=True, include_args=True) @@ -173,8 +192,8 @@ def delete_search_pipeline(self, pipeline_id: str) -> None: Raises: Exception: There was an error deleting the search pipeline. """ - result = self._client.search_pipeline.delete(id=pipeline_id) - if not result.get("acknowledged", False): + response = self._client.search_pipeline.delete(id=pipeline_id) + if not response.get("acknowledged", False): raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.") @log_function_time(print_only=True, debug_only=True, include_args=True) @@ -198,6 +217,34 @@ def put_cluster_settings(self, settings: dict[str, Any]) -> bool: logger.error(f"Failed to put cluster settings: {response}.") return False + @log_function_time(print_only=True, debug_only=True) + def list_indices_with_info(self) -> list[IndexInfo]: + """ + Lists the indices in the OpenSearch cluster with information about each + index. + + Returns: + A list of IndexInfo objects for each index. 
+ """ + response = self._client.cat.indices(format="json") + indices: list[IndexInfo] = [] + for raw_index_info in response: + indices.append( + IndexInfo( + name=raw_index_info.get("index", ""), + health=raw_index_info.get("health", ""), + status=raw_index_info.get("status", ""), + num_primary_shards=raw_index_info.get("pri", ""), + num_replica_shards=raw_index_info.get("rep", ""), + docs_count=raw_index_info.get("docs.count", ""), + docs_deleted=raw_index_info.get("docs.deleted", ""), + created_at=raw_index_info.get("creation.date.string", ""), + total_size=raw_index_info.get("store.size", ""), + primary_shards_size=raw_index_info.get("pri.store.size", ""), + ) + ) + return indices + @log_function_time(print_only=True, debug_only=True) def ping(self) -> bool: """Pings the OpenSearch cluster. diff --git a/backend/onyx/document_index/opensearch/constants.py b/backend/onyx/document_index/opensearch/constants.py index 61257190a03..1447570e056 100644 --- a/backend/onyx/document_index/opensearch/constants.py +++ b/backend/onyx/document_index/opensearch/constants.py @@ -1,5 +1,10 @@ # Default value for the maximum number of tokens a chunk can hold, if none is # specified when creating an index. +from onyx.configs.app_configs import ( + OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES, +) + + DEFAULT_MAX_CHUNK_SIZE = 512 # Size of the dynamic list used to consider elements during kNN graph creation. @@ -10,27 +15,43 @@ # quality but increase memory footprint. Values typically range between 12 - 48. M = 32 # Set relatively high for better accuracy. -# When performing hybrid search, we need to consider more candidates than the number of results to be returned. -# This is because the scoring is hybrid and the results are reordered due to the hybrid scoring. -# Higher = more candidates for hybrid fusion = better retrieval accuracy, but results in more computation per query. 
-# Imagine a simple case with a single keyword query and a single vector query and we want 10 final docs. -# If we only fetch 10 candidates from each of keyword and vector, they would have to have perfect overlap to get a good hybrid -# ranking for the 10 results. If we fetch 1000 candidates from each, we have a much higher chance of all 10 of the final desired -# docs showing up and getting scored. In worse situations, the final 10 docs don't even show up as the final 10 (worse than just -# a miss at the reranking step). -DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = 750 - -# Number of vectors to examine for top k neighbors for the HNSW method. +# When performing hybrid search, we need to consider more candidates than the +# number of results to be returned. This is because the scoring is hybrid and +# the results are reordered due to the hybrid scoring. Higher = more candidates +# for hybrid fusion = better retrieval accuracy, but results in more computation +# per query. Imagine a simple case with a single keyword query and a single +# vector query and we want 10 final docs. If we only fetch 10 candidates from +# each of keyword and vector, they would have to have perfect overlap to get a +# good hybrid ranking for the 10 results. If we fetch 1000 candidates from each, +# we have a much higher chance of all 10 of the final desired docs showing up +# and getting scored. In worse situations, the final 10 docs don't even show up +# as the final 10 (worse than just a miss at the reranking step). +DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = ( + OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES + if OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES > 0 + else 750 +) + +# Number of vectors to examine to decide the top k neighbors for the HNSW +# method. +# NOTE: "When creating a search query, you must specify k. If you provide both k +# and ef_search, then the larger value is passed to the engine. 
If ef_search is +# larger than k, you can provide the size parameter to limit the final number of +# results to k." from +# https://docs.opensearch.org/latest/query-dsl/specialized/k-nn/index/#ef_search EF_SEARCH = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES -# Since the titles are included in the contents, they are heavily downweighted as they act as a boost -# rather than an independent scoring component. +# Since the titles are included in the contents, the embedding matches are +# heavily downweighted as they act as a boost rather than an independent scoring +# component. SEARCH_TITLE_VECTOR_WEIGHT = 0.1 SEARCH_CONTENT_VECTOR_WEIGHT = 0.45 -# Single keyword weight for both title and content (merged from former title keyword + content keyword). +# Single keyword weight for both title and content (merged from former title +# keyword + content keyword). SEARCH_KEYWORD_WEIGHT = 0.45 -# NOTE: it is critical that the order of these weights matches the order of the sub-queries in the hybrid search. +# NOTE: It is critical that the order of these weights matches the order of the +# sub-queries in the hybrid search. 
HYBRID_SEARCH_NORMALIZATION_WEIGHTS = [ SEARCH_TITLE_VECTOR_WEIGHT, SEARCH_CONTENT_VECTOR_WEIGHT, diff --git a/backend/onyx/document_index/opensearch/opensearch_document_index.py b/backend/onyx/document_index/opensearch/opensearch_document_index.py index 2013f5ede90..7f233493bff 100644 --- a/backend/onyx/document_index/opensearch/opensearch_document_index.py +++ b/backend/onyx/document_index/opensearch/opensearch_document_index.py @@ -6,7 +6,6 @@ from opensearchpy import NotFoundError from onyx.access.models import DocumentAccess -from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT from onyx.configs.chat_configs import NUM_RETURNED_HITS from onyx.configs.chat_configs import TITLE_CONTENT_RATIO @@ -272,6 +271,9 @@ def __init__( embedding_dim: int, embedding_precision: EmbeddingPrecision, secondary_index_name: str | None, + secondary_embedding_dim: int | None, + secondary_embedding_precision: EmbeddingPrecision | None, + # NOTE: We do not support large chunks right now. large_chunks_enabled: bool, # noqa: ARG002 secondary_large_chunks_enabled: bool | None, # noqa: ARG002 multitenant: bool = False, @@ -287,12 +289,25 @@ def __init__( f"Expected {MULTI_TENANT}, got {multitenant}." ) tenant_id = get_current_tenant_id() + tenant_state = TenantState(tenant_id=tenant_id, multitenant=multitenant) self._real_index = OpenSearchDocumentIndex( - tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant), + tenant_state=tenant_state, index_name=index_name, embedding_dim=embedding_dim, embedding_precision=embedding_precision, ) + self._secondary_real_index: OpenSearchDocumentIndex | None = None + if self.secondary_index_name: + if secondary_embedding_dim is None or secondary_embedding_precision is None: + raise ValueError( + "Bug: Secondary index embedding dimension and precision are not set." 
+ ) + self._secondary_real_index = OpenSearchDocumentIndex( + tenant_state=tenant_state, + index_name=self.secondary_index_name, + embedding_dim=secondary_embedding_dim, + embedding_precision=secondary_embedding_precision, + ) @staticmethod def register_multitenant_indices( @@ -308,19 +323,38 @@ def ensure_indices_exist( self, primary_embedding_dim: int, primary_embedding_precision: EmbeddingPrecision, - secondary_index_embedding_dim: int | None, # noqa: ARG002 - secondary_index_embedding_precision: EmbeddingPrecision | None, # noqa: ARG002 + secondary_index_embedding_dim: int | None, + secondary_index_embedding_precision: EmbeddingPrecision | None, ) -> None: - # Only handle primary index for now, ignore secondary. - return self._real_index.verify_and_create_index_if_necessary( + self._real_index.verify_and_create_index_if_necessary( primary_embedding_dim, primary_embedding_precision ) + if self.secondary_index_name: + if ( + secondary_index_embedding_dim is None + or secondary_index_embedding_precision is None + ): + raise ValueError( + "Bug: Secondary index embedding dimension and precision are not set." + ) + assert ( + self._secondary_real_index is not None + ), "Bug: Secondary index is not initialized." + self._secondary_real_index.verify_and_create_index_if_necessary( + secondary_index_embedding_dim, secondary_index_embedding_precision + ) def index( self, chunks: list[DocMetadataAwareIndexChunk], index_batch_params: IndexBatchParams, ) -> set[OldDocumentInsertionRecord]: + """ + NOTE: Do NOT consider the secondary index here. A separate indexing + pipeline will be responsible for indexing to the secondary index. This + design is not ideal and we should reconsider this when revamping index + swapping. + """ # Convert IndexBatchParams to IndexingMetadata. 
chunk_counts: dict[str, IndexingMetadata.ChunkCounts] = {} for doc_id in index_batch_params.doc_id_to_new_chunk_cnt: @@ -352,7 +386,20 @@ def delete_single( tenant_id: str, # noqa: ARG002 chunk_count: int | None, ) -> int: - return self._real_index.delete(doc_id, chunk_count) + """ + NOTE: Remember to handle the secondary index here. There is no separate + pipeline for deleting chunks in the secondary index. This design is not + ideal and we should reconsider this when revamping index swapping. + """ + total_chunks_deleted = self._real_index.delete(doc_id, chunk_count) + if self.secondary_index_name: + assert ( + self._secondary_real_index is not None + ), "Bug: Secondary index is not initialized." + total_chunks_deleted += self._secondary_real_index.delete( + doc_id, chunk_count + ) + return total_chunks_deleted def update_single( self, @@ -363,6 +410,11 @@ def update_single( fields: VespaDocumentFields | None, user_fields: VespaDocumentUserFields | None, ) -> None: + """ + NOTE: Remember to handle the secondary index here. There is no separate + pipeline for updating chunks in the secondary index. This design is not + ideal and we should reconsider this when revamping index swapping. + """ if fields is None and user_fields is None: logger.warning( f"Tried to update document {doc_id} with no updated fields or user fields." @@ -381,18 +433,27 @@ def update_single( hidden=fields.hidden if fields else None, project_ids=( set(user_fields.user_projects) - if user_fields and user_fields.user_projects + # NOTE: Empty user_projects is semantically different from None + # user_projects. + if user_fields and user_fields.user_projects is not None else None ), persona_ids=( set(user_fields.personas) - if user_fields and user_fields.personas + # NOTE: Empty personas is semantically different from None + # personas. 
+ if user_fields and user_fields.personas is not None else None ), ) try: self._real_index.update([update_request]) + if self.secondary_index_name: + assert ( + self._secondary_real_index is not None + ), "Bug: Secondary index is not initialized." + self._secondary_real_index.update([update_request]) except NotFoundError: logger.exception( f"Tried to update document {doc_id} but at least one of its chunks was not found in OpenSearch. " @@ -563,12 +624,7 @@ def verify_and_create_index_if_necessary( ) if not self._client.index_exists(): - if USING_AWS_MANAGED_OPENSEARCH: - index_settings = ( - DocumentSchema.get_index_settings_for_aws_managed_opensearch() - ) - else: - index_settings = DocumentSchema.get_index_settings() + index_settings = DocumentSchema.get_index_settings_based_on_environment() self._client.create_index( mappings=expected_mappings, settings=index_settings, @@ -687,7 +743,8 @@ def delete( The number of chunks successfully deleted. """ logger.debug( - f"[OpenSearchDocumentIndex] Deleting document {document_id} from index {self._index_name}." + f"[OpenSearchDocumentIndex] Deleting document {document_id} from index " + f"{self._index_name}." ) query_body = DocumentQuery.delete_from_document_id_query( document_id=document_id, @@ -723,7 +780,8 @@ def update( specified documents. """ logger.debug( - f"[OpenSearchDocumentIndex] Updating {len(update_requests)} chunks for index {self._index_name}." + f"[OpenSearchDocumentIndex] Updating {len(update_requests)} chunks for index " + f"{self._index_name}." ) for update_request in update_requests: properties_to_update: dict[str, Any] = dict() @@ -779,9 +837,11 @@ def update( # here. # TODO(andrei): Fix the aforementioned race condition. raise ChunkCountNotFoundError( - f"Tried to update document {doc_id} but its chunk count is not known. Older versions of the " - "application used to permit this but is not a supported state for a document when using OpenSearch. 
" - "The document was likely just added to the indexing pipeline and the chunk count will be updated shortly." + f"Tried to update document {doc_id} but its chunk count is not known. " + "Older versions of the application used to permit this but is not a " + "supported state for a document when using OpenSearch. The document was " + "likely just added to the indexing pipeline and the chunk count will be " + "updated shortly." ) if doc_chunk_count == 0: raise ValueError( @@ -813,7 +873,8 @@ def id_based_retrieval( chunk IDs vs querying for matching document chunks. """ logger.debug( - f"[OpenSearchDocumentIndex] Retrieving {len(chunk_requests)} chunks for index {self._index_name}." + f"[OpenSearchDocumentIndex] Retrieving {len(chunk_requests)} chunks for index " + f"{self._index_name}." ) results: list[InferenceChunk] = [] for chunk_request in chunk_requests: @@ -860,7 +921,8 @@ def hybrid_retrieval( num_to_retrieve: int, ) -> list[InferenceChunk]: logger.debug( - f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index {self._index_name}." + f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index " + f"{self._index_name}." ) # TODO(andrei): This could be better, the caller should just make this # decision when passing in the query param. See the above comment in the @@ -880,8 +942,10 @@ def hybrid_retrieval( index_filters=filters, include_hidden=False, ) - # NOTE: Using z-score normalization here because it's better for hybrid search from a theoretical standpoint. - # Empirically on a small dataset of up to 10K docs, it's not very different. Likely more impactful at scale. + # NOTE: Using z-score normalization here because it's better for hybrid + # search from a theoretical standpoint. Empirically on a small dataset + # of up to 10K docs, it's not very different. Likely more impactful at + # scale. 
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/ search_hits: list[SearchHit[DocumentChunk]] = self._client.search( body=query_body, @@ -908,7 +972,8 @@ def random_retrieval( dirty: bool | None = None, # noqa: ARG002 ) -> list[InferenceChunk]: logger.debug( - f"[OpenSearchDocumentIndex] Randomly retrieving {num_to_retrieve} chunks for index {self._index_name}." + f"[OpenSearchDocumentIndex] Randomly retrieving {num_to_retrieve} chunks for index " + f"{self._index_name}." ) query_body = DocumentQuery.get_random_search_query( tenant_state=self._tenant_state, @@ -938,7 +1003,8 @@ def index_raw_chunks(self, chunks: list[DocumentChunk]) -> None: complete. """ logger.debug( - f"[OpenSearchDocumentIndex] Indexing {len(chunks)} raw chunks for index {self._index_name}." + f"[OpenSearchDocumentIndex] Indexing {len(chunks)} raw chunks for index " + f"{self._index_name}." ) # Do not raise if the document already exists, just update. This is # because the document may already have been indexed during the diff --git a/backend/onyx/document_index/opensearch/schema.py b/backend/onyx/document_index/opensearch/schema.py index cac59aaad09..10e8a19792d 100644 --- a/backend/onyx/document_index/opensearch/schema.py +++ b/backend/onyx/document_index/opensearch/schema.py @@ -12,6 +12,7 @@ from pydantic import SerializerFunctionWrapHandler from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER +from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH from onyx.document_index.interfaces_new import TenantState from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE from onyx.document_index.opensearch.constants import EF_CONSTRUCTION @@ -242,7 +243,8 @@ def parse_epoch_seconds_to_datetime(cls, value: Any) -> datetime | None: return value if not isinstance(value, int): raise ValueError( - f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead." 
+ f"Bug: Expected an int for the last_updated property from OpenSearch, got " + f"{type(value)} instead." ) return datetime.fromtimestamp(value, tz=timezone.utc) @@ -283,19 +285,22 @@ def parse_tenant_id(cls, value: Any) -> TenantState: elif isinstance(value, TenantState): if MULTI_TENANT != value.multitenant: raise ValueError( - f"Bug: An existing TenantState object was supplied to the DocumentChunk model but its multi-tenant mode " - f"({value.multitenant}) does not match the program's current global tenancy state." + f"Bug: An existing TenantState object was supplied to the DocumentChunk model " + f"but its multi-tenant mode ({value.multitenant}) does not match the program's " + "current global tenancy state." ) return value elif not isinstance(value, str): raise ValueError( - f"Bug: Expected a str for the tenant_id property from OpenSearch, got {type(value)} instead." + f"Bug: Expected a str for the tenant_id property from OpenSearch, got " + f"{type(value)} instead." ) else: if not MULTI_TENANT: raise ValueError( - "Bug: Got a non-null str for the tenant_id property from OpenSearch but multi-tenant mode is not enabled. " - "This is unexpected because in single-tenant mode we don't expect to see a tenant_id." + "Bug: Got a non-null str for the tenant_id property from OpenSearch but " + "multi-tenant mode is not enabled. This is unexpected because in single-tenant " + "mode we don't expect to see a tenant_id." ) return TenantState(tenant_id=value, multitenant=MULTI_TENANT) @@ -351,8 +356,10 @@ def get_document_schema(vector_dimension: int, multitenant: bool) -> dict[str, A "properties": { TITLE_FIELD_NAME: { "type": "text", - # Language analyzer (e.g. english) stems at index and search time for variant matching. - # Configure via OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing after a change. + # Language analyzer (e.g. english) stems at index and search + # time for variant matching. Configure via + # OPENSEARCH_TEXT_ANALYZER. 
Existing indices need reindexing + # after a change. "analyzer": OPENSEARCH_TEXT_ANALYZER, "fields": { # Subfield accessed as title.keyword. Not indexed for @@ -525,7 +532,7 @@ def get_index_settings() -> dict[str, Any]: } @staticmethod - def get_index_settings_for_aws_managed_opensearch() -> dict[str, Any]: + def get_index_settings_for_aws_managed_opensearch_st_dev() -> dict[str, Any]: """ Settings for AWS-managed OpenSearch. @@ -546,3 +553,41 @@ def get_index_settings_for_aws_managed_opensearch() -> dict[str, Any]: "knn.algo_param.ef_search": EF_SEARCH, } } + + @staticmethod + def get_index_settings_for_aws_managed_opensearch_mt_cloud() -> dict[str, Any]: + """ + Settings for AWS-managed OpenSearch in multi-tenant cloud. + + 324 shards very roughly targets a storage load of ~30Gb per shard, which + according to AWS OpenSearch documentation is within a good target range. + + As documented above we need 2 replicas for a total of 3 copies of the + data because the cluster is configured with 3-AZ awareness. + """ + return { + "index": { + "number_of_shards": 324, + "number_of_replicas": 2, + # Required for vector search. + "knn": True, + "knn.algo_param.ef_search": EF_SEARCH, + } + } + + @staticmethod + def get_index_settings_based_on_environment() -> dict[str, Any]: + """ + Returns the index settings based on the environment. 
+ """ + if USING_AWS_MANAGED_OPENSEARCH: + if MULTI_TENANT: + return ( + DocumentSchema.get_index_settings_for_aws_managed_opensearch_mt_cloud() + ) + else: + return ( + DocumentSchema.get_index_settings_for_aws_managed_opensearch_st_dev() + ) + else: + return DocumentSchema.get_index_settings() diff --git a/backend/onyx/document_index/opensearch/search.py b/backend/onyx/document_index/opensearch/search.py index 7ff7a5002f7..b9a6e7fc2dc 100644 --- a/backend/onyx/document_index/opensearch/search.py +++ b/backend/onyx/document_index/opensearch/search.py @@ -255,8 +255,12 @@ def get_hybrid_search_query( f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})." ) + # TODO(andrei, yuhong): We can tune this more dynamically based on + # num_hits. + max_results_per_subquery = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES + hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries( - query_text, query_vector + query_text, query_vector, vector_candidates=max_results_per_subquery ) hybrid_search_filters = DocumentQuery._get_search_filters( tenant_state=tenant_state, @@ -285,13 +289,16 @@ def get_hybrid_search_query( hybrid_search_query: dict[str, Any] = { "hybrid": { "queries": hybrid_search_subqueries, - # Max results per subquery per shard before aggregation. Ensures keyword and vector - # subqueries contribute equally to the candidate pool for hybrid fusion. + # Max results per subquery per shard before aggregation. Ensures + # keyword and vector subqueries contribute equally to the + # candidate pool for hybrid fusion. # Sources: # https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/ # https://opensearch.org/blog/navigating-pagination-in-hybrid-queries-with-the-pagination_depth-parameter/ - "pagination_depth": DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES, - # Applied to all the sub-queries independently (this avoids having subqueries having a lot of results thrown out). 
+ "pagination_depth": max_results_per_subquery, + # Applied to all the sub-queries independently (this avoids + # subqueries having a lot of results thrown out during + # aggregation). # Sources: # https://docs.opensearch.org/latest/query-dsl/compound/hybrid/ # https://opensearch.org/blog/introducing-common-filter-support-for-hybrid-search-queries @@ -374,9 +381,10 @@ def get_random_search_query( def _get_hybrid_search_subqueries( query_text: str, query_vector: list[float], - # The default number of neighbors to consider for knn vector similarity search. - # This is higher than the number of results because the scoring is hybrid. - # for a detailed breakdown, see where the default value is set. + # The default number of neighbors to consider for knn vector similarity + # search. This is higher than the number of results because the scoring + # is hybrid. For a detailed breakdown, see where the default value is + # set. vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES, ) -> list[dict[str, Any]]: """Returns subqueries for hybrid search. @@ -400,20 +408,27 @@ def _get_hybrid_search_subqueries( in a single hybrid query. Source: https://docs.opensearch.org/latest/query-dsl/compound/hybrid/ - NOTE: Each query is independent during the search phase, there is no backfilling of scores for missing query components. - What this means is that if a document was a good vector match but did not show up for keyword, it gets a score of 0 for - the keyword component of the hybrid scoring. This is not as bad as just disregarding a score though as there is - normalization applied after. So really it is "increasing" the missing score compared to if it was included and the range - was renormalized. This does however mean that between docs that have high scores for say the vector field, the keyword - scores between them are completely ignored unless they also showed up in the keyword query as a reasonably high match. 
- TLDR, this is a bit of unique funky behavior but it seems ok. + NOTE: Each query is independent during the search phase, there is no + backfilling of scores for missing query components. What this means is + that if a document was a good vector match but did not show up for + keyword, it gets a score of 0 for the keyword component of the hybrid + scoring. This is not as bad as just disregarding a score though as there + is normalization applied after. So really it is "increasing" the missing + score compared to if it was included and the range was renormalized. + This does however mean that between docs that have high scores for say + the vector field, the keyword scores between them are completely ignored + unless they also showed up in the keyword query as a reasonably high + match. TLDR, this is a bit of unique funky behavior but it seems ok. NOTE: Options considered and rejected: - - minimum_should_match: Since it's hybrid search and users often provide semantic queries, there is often a lot of terms, - and very low number of meaningful keywords (and a low ratio of keywords). - - fuzziness AUTO: typo tolerance (0/1/2 edit distance by term length). It's mostly for typos as the analyzer ("english by - default") already does some stemming and tokenization. In testing datasets, this makes recall slightly worse. It also is - less performant so not really any reason to do it. + - minimum_should_match: Since it's hybrid search and users often provide + semantic queries, there is often a lot of terms, and very low number + of meaningful keywords (and a low ratio of keywords). + - fuzziness AUTO: Typo tolerance (0/1/2 edit distance by term length). + It's mostly for typos as the analyzer ("english" by default) already + does some stemming and tokenization. In testing datasets, this makes + recall slightly worse. It also is less performant so not really any + reason to do it. Args: query_text: The text of the query to search for. 
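The "no backfilling" behavior described in the NOTE above (a doc absent from one sub-query's candidate pool contributes a 0 for that component after normalization) can be sketched in isolation. This is a toy model, not Onyx code: it uses plain z-score normalization and two equal sub-query weights purely for illustration.

```python
# Toy sketch of hybrid score fusion with z-score normalization.
# A doc missing from one sub-query's candidate pool gets 0 for that
# component after normalization, rather than being dropped entirely.
from statistics import mean, pstdev


def z_normalize(scores: dict[str, float]) -> dict[str, float]:
    """Z-score normalize one sub-query's candidate scores."""
    vals = list(scores.values())
    mu, sigma = mean(vals), pstdev(vals)
    if sigma == 0:
        return {doc: 0.0 for doc in scores}
    return {doc: (s - mu) / sigma for doc, s in scores.items()}


def fuse(
    subquery_scores: list[dict[str, float]], weights: list[float]
) -> dict[str, float]:
    """Weighted sum of normalized sub-query scores; missing docs score 0."""
    normalized = [z_normalize(s) for s in subquery_scores]
    all_docs = {doc for scores in normalized for doc in scores}
    return {
        doc: sum(w * n.get(doc, 0.0) for w, n in zip(weights, normalized))
        for doc in all_docs
    }


keyword = {"a": 9.0, "b": 5.0, "c": 1.0}
vector = {"b": 0.9, "d": 0.8, "a": 0.2}  # "d" never matched on keyword
fused = fuse([keyword, vector], [0.45, 0.45])
# "d" still gets a positive fused score from its vector component alone,
# outranking "a", whose strong keyword score is offset by a weak vector one.
```

This also illustrates why the candidate pools need to be much larger than the final result count: a doc that falls outside one sub-query's pool is scored as if it had zero affinity for that component.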
@@ -698,41 +713,6 @@ def _get_hierarchy_node_filter( """ return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: node_ids}} - def _get_assistant_knowledge_filter( - attached_doc_ids: list[str] | None, - node_ids: list[int] | None, - file_ids: list[UUID] | None, - document_sets: list[str] | None, - ) -> dict[str, Any]: - """Combined filter for assistant knowledge. - - When an assistant has attached knowledge, search should be scoped to: - - Documents explicitly attached (by document ID), OR - - Documents under attached hierarchy nodes (by ancestor node IDs), OR - - User-uploaded files attached to the assistant, OR - - Documents in the assistant's document sets (if any) - """ - knowledge_filter: dict[str, Any] = { - "bool": {"should": [], "minimum_should_match": 1} - } - if attached_doc_ids: - knowledge_filter["bool"]["should"].append( - _get_attached_document_id_filter(attached_doc_ids) - ) - if node_ids: - knowledge_filter["bool"]["should"].append( - _get_hierarchy_node_filter(node_ids) - ) - if file_ids: - knowledge_filter["bool"]["should"].append( - _get_user_file_id_filter(file_ids) - ) - if document_sets: - knowledge_filter["bool"]["should"].append( - _get_document_set_filter(document_sets) - ) - return knowledge_filter - filter_clauses: list[dict[str, Any]] = [] if not include_hidden: @@ -758,41 +738,51 @@ def _get_assistant_knowledge_filter( # document's metadata list. filter_clauses.append(_get_tag_filter(tags)) - # Check if this is an assistant knowledge search (has any assistant-scoped knowledge) - has_assistant_knowledge = ( + # Knowledge scope: explicit knowledge attachments restrict what an + # assistant can see. When none are set the assistant searches + # everything. + # + # project_id / persona_id are additive: they make overflowing user files + # findable but must NOT trigger the restriction on their own (an agent + # with no explicit knowledge should search everything). 
+ has_knowledge_scope = ( attached_document_ids or hierarchy_node_ids or user_file_ids or document_sets ) - if has_assistant_knowledge: - # If assistant has attached knowledge, scope search to that knowledge. - # Document sets are included in the OR filter so directly attached - # docs are always findable even if not in the document sets. - filter_clauses.append( - _get_assistant_knowledge_filter( - attached_document_ids, - hierarchy_node_ids, - user_file_ids, - document_sets, + if has_knowledge_scope: + knowledge_filter: dict[str, Any] = { + "bool": {"should": [], "minimum_should_match": 1} + } + if attached_document_ids: + knowledge_filter["bool"]["should"].append( + _get_attached_document_id_filter(attached_document_ids) ) - ) - elif user_file_ids: - # Fallback for non-assistant user file searches (e.g., project searches) - # If at least one user file ID is provided, the caller will only - # retrieve documents where the document ID is in this input list of - # file IDs. - filter_clauses.append(_get_user_file_id_filter(user_file_ids)) - - if project_id is not None: - # If a project ID is provided, the caller will only retrieve - # documents where the project ID provided here is present in the - # document's user projects list. - filter_clauses.append(_get_user_project_filter(project_id)) - - if persona_id is not None: - filter_clauses.append(_get_persona_filter(persona_id)) + if hierarchy_node_ids: + knowledge_filter["bool"]["should"].append( + _get_hierarchy_node_filter(hierarchy_node_ids) + ) + if user_file_ids: + knowledge_filter["bool"]["should"].append( + _get_user_file_id_filter(user_file_ids) + ) + if document_sets: + knowledge_filter["bool"]["should"].append( + _get_document_set_filter(document_sets) + ) + # Additive: widen scope to also cover overflowing user files, but + # only when an explicit restriction is already in effect. 
+ if project_id is not None: + knowledge_filter["bool"]["should"].append( + _get_user_project_filter(project_id) + ) + if persona_id is not None: + knowledge_filter["bool"]["should"].append( + _get_persona_filter(persona_id) + ) + filter_clauses.append(knowledge_filter) if time_cutoff is not None: # If a time cutoff is provided, the caller will only retrieve diff --git a/backend/onyx/document_index/vespa/chunk_retrieval.py b/backend/onyx/document_index/vespa/chunk_retrieval.py index 49124995b4e..dd3258fab65 100644 --- a/backend/onyx/document_index/vespa/chunk_retrieval.py +++ b/backend/onyx/document_index/vespa/chunk_retrieval.py @@ -1,5 +1,6 @@ import json import string +import time from collections.abc import Callable from collections.abc import Mapping from datetime import datetime @@ -18,6 +19,7 @@ ) from onyx.configs.app_configs import LOG_VESPA_TIMING_INFORMATION from onyx.configs.app_configs import VESPA_LANGUAGE_OVERRIDE +from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S from onyx.context.search.models import IndexFilters from onyx.context.search.models import InferenceChunkUncleaned from onyx.document_index.interfaces import VespaChunkRequest @@ -338,12 +340,18 @@ def _get_all_chunks_paginated_for_slice( params["continuation"] = continuation_token response: httpx.Response | None = None + start_time = time.monotonic() try: - with get_vespa_http_client() as http_client: + with get_vespa_http_client( + timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S + ) as http_client: response = http_client.get(url, params=params) response.raise_for_status() except httpx.HTTPError as e: - error_base = f"Failed to get chunks from Vespa slice {slice_id} with continuation token {continuation_token}." + error_base = ( + f"Failed to get chunks from Vespa slice {slice_id} with continuation token " + f"{continuation_token} in {time.monotonic() - start_time:.3f} seconds." 
+ ) logger.exception( f"Request URL: {e.request.url}\n" f"Request Headers: {e.request.headers}\n" diff --git a/backend/onyx/document_index/vespa/index.py b/backend/onyx/document_index/vespa/index.py index da029f47cae..71322f90f67 100644 --- a/backend/onyx/document_index/vespa/index.py +++ b/backend/onyx/document_index/vespa/index.py @@ -465,6 +465,12 @@ def index( chunks: list[DocMetadataAwareIndexChunk], index_batch_params: IndexBatchParams, ) -> set[OldDocumentInsertionRecord]: + """ + NOTE: Do NOT consider the secondary index here. A separate indexing + pipeline will be responsible for indexing to the secondary index. This + design is not ideal and we should reconsider this when revamping index + swapping. + """ if len(index_batch_params.doc_id_to_previous_chunk_cnt) != len( index_batch_params.doc_id_to_new_chunk_cnt ): @@ -659,6 +665,10 @@ def update_single( """Note: if the document id does not exist, the update will be a no-op and the function will complete with no errors or exceptions. Handle other exceptions if you wish to implement retry behavior + + NOTE: Remember to handle the secondary index here. There is no separate + pipeline for updating chunks in the secondary index. This design is not + ideal and we should reconsider this when revamping index swapping. """ if fields is None and user_fields is None: logger.warning( @@ -679,17 +689,13 @@ def update_single( f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}." ) - vespa_document_index = VespaDocumentIndex( - index_name=self.index_name, - tenant_state=tenant_state, - large_chunks_enabled=self.large_chunks_enabled, - httpx_client=self.httpx_client, - ) - project_ids: set[int] | None = None + # NOTE: Empty user_projects is semantically different from None + # user_projects. 
if user_fields is not None and user_fields.user_projects is not None: project_ids = set(user_fields.user_projects) persona_ids: set[int] | None = None + # NOTE: Empty personas is semantically different from None personas. if user_fields is not None and user_fields.personas is not None: persona_ids = set(user_fields.personas) update_request = MetadataUpdateRequest( @@ -705,7 +711,20 @@ def update_single( persona_ids=persona_ids, ) - vespa_document_index.update([update_request]) + indices = [self.index_name] + if self.secondary_index_name: + indices.append(self.secondary_index_name) + + for index_name in indices: + vespa_document_index = VespaDocumentIndex( + index_name=index_name, + tenant_state=tenant_state, + large_chunks_enabled=self.index_to_large_chunks_enabled.get( + index_name, False + ), + httpx_client=self.httpx_client, + ) + vespa_document_index.update([update_request]) def delete_single( self, @@ -714,6 +733,11 @@ def delete_single( tenant_id: str, chunk_count: int | None, ) -> int: + """ + NOTE: Remember to handle the secondary index here. There is no separate + pipeline for deleting chunks in the secondary index. This design is not + ideal and we should reconsider this when revamping index swapping. + """ tenant_state = TenantState( tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT, @@ -726,13 +750,25 @@ def delete_single( raise ValueError( f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}." 
             )
 
-        vespa_document_index = VespaDocumentIndex(
-            index_name=self.index_name,
-            tenant_state=tenant_state,
-            large_chunks_enabled=self.large_chunks_enabled,
-            httpx_client=self.httpx_client,
-        )
-        return vespa_document_index.delete(document_id=doc_id, chunk_count=chunk_count)
+        indices = [self.index_name]
+        if self.secondary_index_name:
+            indices.append(self.secondary_index_name)
+
+        total_chunks_deleted = 0
+        for index_name in indices:
+            vespa_document_index = VespaDocumentIndex(
+                index_name=index_name,
+                tenant_state=tenant_state,
+                large_chunks_enabled=self.index_to_large_chunks_enabled.get(
+                    index_name, False
+                ),
+                httpx_client=self.httpx_client,
+            )
+            total_chunks_deleted += vespa_document_index.delete(
+                document_id=doc_id, chunk_count=chunk_count
+            )
+
+        return total_chunks_deleted
 
     def id_based_retrieval(
         self,
diff --git a/backend/onyx/document_index/vespa/shared_utils/utils.py b/backend/onyx/document_index/vespa/shared_utils/utils.py
index 74e185cdef4..c52eefa9bd2 100644
--- a/backend/onyx/document_index/vespa/shared_utils/utils.py
+++ b/backend/onyx/document_index/vespa/shared_utils/utils.py
@@ -52,7 +52,9 @@ def replace_invalid_doc_id_characters(text: str) -> str:
     return text.replace("'", "_")
 
 
-def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx.Client:
+def get_vespa_http_client(
+    no_timeout: bool = False, http2: bool = True, timeout: int | None = None
+) -> httpx.Client:
     """
     Configures and returns an HTTP client for communicating with Vespa,
     including authentication if needed.
@@ -64,7 +66,7 @@ def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx
             else None
         ),
         verify=False if not MANAGED_VESPA else True,
-        timeout=None if no_timeout else VESPA_REQUEST_TIMEOUT,
+        timeout=None if no_timeout else (timeout or VESPA_REQUEST_TIMEOUT),
         http2=http2,
     )
diff --git a/backend/onyx/document_index/vespa/shared_utils/vespa_request_builders.py b/backend/onyx/document_index/vespa/shared_utils/vespa_request_builders.py
index 276f32b9aae..28653d6c29b 100644
--- a/backend/onyx/document_index/vespa/shared_utils/vespa_request_builders.py
+++ b/backend/onyx/document_index/vespa/shared_utils/vespa_request_builders.py
@@ -23,11 +23,8 @@
 logger = setup_logger()
 
 
-def build_tenant_id_filter(tenant_id: str, include_trailing_and: bool = False) -> str:
-    filter_str = f'({TENANT_ID} contains "{tenant_id}")'
-    if include_trailing_and:
-        filter_str += " and "
-    return filter_str
+def build_tenant_id_filter(tenant_id: str) -> str:
+    return f'({TENANT_ID} contains "{tenant_id}")'
 
 
 def build_vespa_filters(
@@ -37,30 +34,22 @@ def build_vespa_filters(
     remove_trailing_and: bool = False,  # Set to True when using as a complete Vespa query
 ) -> str:
     def _build_or_filters(key: str, vals: list[str] | None) -> str:
-        """For string-based 'contains' filters, e.g. WSET fields or array fields."""
+        """For string-based 'contains' filters, e.g. WSET fields or array fields.
+        Returns a bare clause like '(key contains "v1" or key contains "v2")' or ""."""
         if not key or not vals:
             return ""
         eq_elems = [f'{key} contains "{val}"' for val in vals if val]
         if not eq_elems:
             return ""
-        or_clause = " or ".join(eq_elems)
-        return f"({or_clause}) and "
+        return f"({' or '.join(eq_elems)})"
 
     def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
-        """
-        For an integer field filter.
-        If vals is not None, we want *only* docs whose key matches one of vals.
-        """
-        # If `vals` is None => skip the filter entirely
+        """For an integer field filter.
+        Returns a bare clause or ""."""
         if vals is None or not vals:
             return ""
-
-        # Otherwise build the OR filter
         eq_elems = [f"{key} = {val}" for val in vals]
-        or_clause = " or ".join(eq_elems)
-        result = f"({or_clause}) and "
-
-        return result
+        return f"({' or '.join(eq_elems)})"
 
     def _build_kg_filter(
         kg_entities: list[str] | None,
@@ -73,16 +62,12 @@ def _build_kg_filter(
         combined_filter_parts = []
 
         def _build_kge(entity: str) -> str:
-            # TYPE-SUBTYPE::ID -> "TYPE-SUBTYPE::ID"
-            # TYPE-SUBTYPE::* -> ({prefix: true}"TYPE-SUBTYPE")
-            # TYPE::* -> ({prefix: true}"TYPE")
             GENERAL = "::*"
             if entity.endswith(GENERAL):
                 return f'({{prefix: true}}"{entity.split(GENERAL, 1)[0]}")'
             else:
                 return f'"{entity}"'
 
-        # OR the entities (give new design)
         if kg_entities:
             filter_parts = []
             for kg_entity in kg_entities:
@@ -104,8 +89,7 @@ def _build_kge(entity: str) -> str:
 
         # TODO: remove kg terms entirely from prompts and codebase
 
-        # AND the combined filter parts
-        return f"({' and '.join(combined_filter_parts)}) and "
+        return f"({' and '.join(combined_filter_parts)})"
 
     def _build_kg_source_filters(
         kg_sources: list[str] | None,
@@ -114,16 +98,14 @@
             return ""
 
         source_phrases = [f'{DOCUMENT_ID} contains "{source}"' for source in kg_sources]
-
-        return f"({' or '.join(source_phrases)}) and "
+        return f"({' or '.join(source_phrases)})"
 
     def _build_kg_chunk_id_zero_only_filter(
         kg_chunk_id_zero_only: bool,
     ) -> str:
         if not kg_chunk_id_zero_only:
             return ""
-
-        return "(chunk_id = 0 ) and "
+        return "(chunk_id = 0)"
 
     def _build_time_filter(
         cutoff: datetime | None,
@@ -135,8 +117,8 @@
         cutoff_secs = int(cutoff.timestamp())
 
         if include_untimed:
-            return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
-        return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
+            return f"!({DOC_UPDATED_AT} < {cutoff_secs})"
+        return f"({DOC_UPDATED_AT} >= {cutoff_secs})"
 
     def _build_user_project_filter(
         project_id: int | None,
@@ -147,8 +129,7 @@ def _build_user_project_filter(
             pid = int(project_id)
         except Exception:
             return ""
-        # Vespa YQL 'contains' expects a string literal; quote the integer
-        return f'({USER_PROJECT} contains "{pid}") and '
+        return f'({USER_PROJECT} contains "{pid}")'
 
     def _build_persona_filter(
         persona_id: int | None,
@@ -160,73 +141,94 @@ def _build_user_project_filter(
         except Exception:
             logger.warning(f"Invalid persona ID: {persona_id}")
             return ""
-        return f'({PERSONAS} contains "{pid}") and '
+        return f'({PERSONAS} contains "{pid}")'
+
+    def _append(parts: list[str], clause: str) -> None:
+        if clause:
+            parts.append(clause)
 
-    # Start building the filter string
-    filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
+    # Collect all top-level filter clauses, then join with " and " at the end.
+    filter_parts: list[str] = []
+
+    if not include_hidden:
+        filter_parts.append(f"!({HIDDEN}=true)")
 
     # TODO: add error condition if MULTI_TENANT and no tenant_id filter is set
-    # If running in multi-tenant mode
     if filters.tenant_id and MULTI_TENANT:
-        filter_str += build_tenant_id_filter(
-            filters.tenant_id, include_trailing_and=True
-        )
+        filter_parts.append(build_tenant_id_filter(filters.tenant_id))
 
     # ACL filters
     if filters.access_control_list is not None:
-        filter_str += _build_or_filters(
-            ACCESS_CONTROL_LIST, filters.access_control_list
+        _append(
+            filter_parts,
+            _build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list),
         )
 
     # Source type filters
     source_strs = (
         [s.value for s in filters.source_type] if filters.source_type else None
     )
-    filter_str += _build_or_filters(SOURCE_TYPE, source_strs)
+    _append(filter_parts, _build_or_filters(SOURCE_TYPE, source_strs))
 
     # Tag filters
     tag_attributes = None
     if filters.tags:
-        # build e.g. "tag_key|tag_value"
         tag_attributes = [
             f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in filters.tags
         ]
-    filter_str += _build_or_filters(METADATA_LIST, tag_attributes)
-
-    # Document sets
-    filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
+    _append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))
+
+    # Knowledge scope: explicit knowledge attachments (document_sets,
+    # user_file_ids) restrict what an assistant can see. When none are
+    # set, the assistant can see everything.
+    #
+    # project_id / persona_id are additive: they make overflowing user
+    # files findable in Vespa but must NOT trigger the restriction on
+    # their own (an agent with no explicit knowledge should search
+    # everything).
+    knowledge_scope_parts: list[str] = []
+
+    _append(
+        knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
+    )
 
-    # Convert UUIDs to strings for user_file_ids
     user_file_ids_str = (
         [str(uuid) for uuid in filters.user_file_ids]
         if filters.user_file_ids
         else None
    )
-    filter_str += _build_or_filters(DOCUMENT_ID, user_file_ids_str)
+    _append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))
 
-    # User project filter (array attribute membership)
-    filter_str += _build_user_project_filter(filters.project_id)
+    # Only include project/persona scopes when an explicit knowledge
+    # restriction is already in effect — they widen the scope to also
+    # cover overflowing user files but never restrict on their own.
+    if knowledge_scope_parts:
+        _append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
+        _append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))
 
-    # Persona filter (array attribute membership)
-    filter_str += _build_persona_filter(filters.persona_id)
+    if len(knowledge_scope_parts) > 1:
+        filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")
+    elif len(knowledge_scope_parts) == 1:
+        filter_parts.append(knowledge_scope_parts[0])
 
     # Time filter
-    filter_str += _build_time_filter(filters.time_cutoff)
+    _append(filter_parts, _build_time_filter(filters.time_cutoff))
 
     # # Knowledge Graph Filters
-    # filter_str += _build_kg_filter(
+    # _append(filter_parts, _build_kg_filter(
     #     kg_entities=filters.kg_entities,
     #     kg_relationships=filters.kg_relationships,
    #     kg_terms=filters.kg_terms,
-    # )
+    # ))
 
-    # filter_str += _build_kg_source_filters(filters.kg_sources)
+    # _append(filter_parts, _build_kg_source_filters(filters.kg_sources))
 
-    # filter_str += _build_kg_chunk_id_zero_only_filter(
+    # _append(filter_parts, _build_kg_chunk_id_zero_only_filter(
     #     filters.kg_chunk_id_zero_only or False
-    # )
+    # ))
+
+    filter_str = " and ".join(filter_parts)
 
-    # Trim trailing " and "
-    if remove_trailing_and and filter_str.endswith(" and "):
-        filter_str = filter_str[:-5]
+    if filter_str and not remove_trailing_and:
+        filter_str += " and "
 
     return filter_str
diff --git a/backend/onyx/error_handling/__init__.py b/backend/onyx/error_handling/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/backend/onyx/error_handling/error_codes.py b/backend/onyx/error_handling/error_codes.py
new file mode 100644
index 00000000000..42a4c5d60a8
--- /dev/null
+++ b/backend/onyx/error_handling/error_codes.py
@@ -0,0 +1,101 @@
+"""
+Standardized error codes for the Onyx backend.
+
+Usage:
+    from onyx.error_handling.error_codes import OnyxErrorCode
+    from onyx.error_handling.exceptions import OnyxError
+
+    raise OnyxError(OnyxErrorCode.UNAUTHENTICATED, "Token expired")
+"""
+
+from enum import Enum
+
+
+class OnyxErrorCode(Enum):
+    """
+    Each member is a tuple of (error_code_string, http_status_code).
+
+    The error_code_string is a stable, machine-readable identifier that
+    API consumers can match on. The http_status_code is the default HTTP
+    status to return.
+    """
+
+    # ------------------------------------------------------------------
+    # Authentication (401)
+    # ------------------------------------------------------------------
+    UNAUTHENTICATED = ("UNAUTHENTICATED", 401)
+    INVALID_TOKEN = ("INVALID_TOKEN", 401)
+    TOKEN_EXPIRED = ("TOKEN_EXPIRED", 401)
+    CSRF_FAILURE = ("CSRF_FAILURE", 403)
+
+    # ------------------------------------------------------------------
+    # Authorization (403)
+    # ------------------------------------------------------------------
+    UNAUTHORIZED = ("UNAUTHORIZED", 403)
+    INSUFFICIENT_PERMISSIONS = ("INSUFFICIENT_PERMISSIONS", 403)
+    ADMIN_ONLY = ("ADMIN_ONLY", 403)
+    EE_REQUIRED = ("EE_REQUIRED", 403)
+
+    # ------------------------------------------------------------------
+    # Validation / Bad Request (400)
+    # ------------------------------------------------------------------
+    VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
+    INVALID_INPUT = ("INVALID_INPUT", 400)
+    MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
+
+    # ------------------------------------------------------------------
+    # Not Found (404)
+    # ------------------------------------------------------------------
+    NOT_FOUND = ("NOT_FOUND", 404)
+    CONNECTOR_NOT_FOUND = ("CONNECTOR_NOT_FOUND", 404)
+    CREDENTIAL_NOT_FOUND = ("CREDENTIAL_NOT_FOUND", 404)
+    PERSONA_NOT_FOUND = ("PERSONA_NOT_FOUND", 404)
+    DOCUMENT_NOT_FOUND = ("DOCUMENT_NOT_FOUND", 404)
+    SESSION_NOT_FOUND = ("SESSION_NOT_FOUND", 404)
+    USER_NOT_FOUND = ("USER_NOT_FOUND", 404)
+
+    # ------------------------------------------------------------------
+    # Conflict (409)
+    # ------------------------------------------------------------------
+    CONFLICT = ("CONFLICT", 409)
+    DUPLICATE_RESOURCE = ("DUPLICATE_RESOURCE", 409)
+
+    # ------------------------------------------------------------------
+    # Rate Limiting / Quotas (429 / 402)
+    # ------------------------------------------------------------------
+    RATE_LIMITED = ("RATE_LIMITED", 429)
+    SEAT_LIMIT_EXCEEDED = ("SEAT_LIMIT_EXCEEDED", 402)
+
+    # ------------------------------------------------------------------
+    # Connector / Credential Errors (400-range)
+    # ------------------------------------------------------------------
+    CONNECTOR_VALIDATION_FAILED = ("CONNECTOR_VALIDATION_FAILED", 400)
+    CREDENTIAL_INVALID = ("CREDENTIAL_INVALID", 400)
+    CREDENTIAL_EXPIRED = ("CREDENTIAL_EXPIRED", 401)
+
+    # ------------------------------------------------------------------
+    # Server Errors (5xx)
+    # ------------------------------------------------------------------
+    INTERNAL_ERROR = ("INTERNAL_ERROR", 500)
+    NOT_IMPLEMENTED = ("NOT_IMPLEMENTED", 501)
+    SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
+    BAD_GATEWAY = ("BAD_GATEWAY", 502)
+    LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
+    GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
+
+    def __init__(self, code: str, status_code: int) -> None:
+        self.code = code
+        self.status_code = status_code
+
+    def detail(self, message: str | None = None) -> dict[str, str]:
+        """Build a structured error detail dict.
+
+        Returns a dict like:
+            {"error_code": "UNAUTHENTICATED", "detail": "Token expired"}
+
+        If no message is supplied, the error code itself is used as the detail.
+ """ + return { + "error_code": self.code, + "detail": message or self.code, + } diff --git a/backend/onyx/error_handling/exceptions.py b/backend/onyx/error_handling/exceptions.py new file mode 100644 index 00000000000..357c82c26bf --- /dev/null +++ b/backend/onyx/error_handling/exceptions.py @@ -0,0 +1,91 @@ +"""OnyxError — the single exception type for all Onyx business errors. + +Raise ``OnyxError`` instead of ``HTTPException`` in business code. A global +FastAPI exception handler (registered via ``register_onyx_exception_handlers``) +converts it into a JSON response with the standard +``{"error_code": "...", "detail": "..."}`` shape. + +Usage:: + + from onyx.error_handling.error_codes import OnyxErrorCode + from onyx.error_handling.exceptions import OnyxError + + raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found") + +For upstream errors with a dynamic HTTP status (e.g. billing service), +use ``status_code_override``:: + + raise OnyxError( + OnyxErrorCode.BAD_GATEWAY, + detail, + status_code_override=upstream_status, + ) +""" + +from fastapi import FastAPI +from fastapi import Request +from fastapi.responses import JSONResponse + +from onyx.error_handling.error_codes import OnyxErrorCode +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +class OnyxError(Exception): + """Structured error that maps to a specific ``OnyxErrorCode``. + + Attributes: + error_code: The ``OnyxErrorCode`` enum member. + detail: Human-readable detail (defaults to the error code string). + status_code: HTTP status — either overridden or from the error code. 
+ """ + + def __init__( + self, + error_code: OnyxErrorCode, + detail: str | None = None, + *, + status_code_override: int | None = None, + ) -> None: + resolved_detail = detail or error_code.code + super().__init__(resolved_detail) + self.error_code = error_code + self.detail = resolved_detail + self._status_code_override = status_code_override + + @property + def status_code(self) -> int: + return self._status_code_override or self.error_code.status_code + + +def log_onyx_error(exc: OnyxError) -> None: + detail = exc.detail + status_code = exc.status_code + if status_code >= 500: + logger.error(f"OnyxError {exc.error_code.code}: {detail}") + elif status_code >= 400: + logger.warning(f"OnyxError {exc.error_code.code}: {detail}") + + +def onyx_error_to_json_response(exc: OnyxError) -> JSONResponse: + return JSONResponse( + status_code=exc.status_code, + content=exc.error_code.detail(exc.detail), + ) + + +def register_onyx_exception_handlers(app: FastAPI) -> None: + """Register a global handler that converts ``OnyxError`` to JSON responses. + + Must be called *after* the app is created but *before* it starts serving. + The handler logs at WARNING for 4xx and ERROR for 5xx. 
+ """ + + @app.exception_handler(OnyxError) + async def _handle_onyx_error( + request: Request, # noqa: ARG001 + exc: OnyxError, + ) -> JSONResponse: + log_onyx_error(exc) + return onyx_error_to_json_response(exc) diff --git a/backend/onyx/federated_connectors/oauth_utils.py b/backend/onyx/federated_connectors/oauth_utils.py index 7876d3fda8b..0b17d97f723 100644 --- a/backend/onyx/federated_connectors/oauth_utils.py +++ b/backend/onyx/federated_connectors/oauth_utils.py @@ -4,39 +4,33 @@ import json import uuid from typing import Any -from typing import cast -from typing import Dict -from typing import Optional +from onyx.cache.factory import get_cache_backend from onyx.configs.app_configs import WEB_DOMAIN -from onyx.redis.redis_pool import get_redis_client from onyx.utils.logger import setup_logger logger = setup_logger() -# Redis key prefix for OAuth state OAUTH_STATE_PREFIX = "federated_oauth" -# Default TTL for OAuth state (5 minutes) -OAUTH_STATE_TTL = 300 +OAUTH_STATE_TTL = 300 # 5 minutes class OAuthSession: - """Represents an OAuth session stored in Redis.""" + """Represents an OAuth session stored in the cache backend.""" def __init__( self, federated_connector_id: int, user_id: str, - redirect_uri: Optional[str] = None, - additional_data: Optional[Dict[str, Any]] = None, + redirect_uri: str | None = None, + additional_data: dict[str, Any] | None = None, ): self.federated_connector_id = federated_connector_id self.user_id = user_id self.redirect_uri = redirect_uri self.additional_data = additional_data or {} - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary for Redis storage.""" + def to_dict(self) -> dict[str, Any]: return { "federated_connector_id": self.federated_connector_id, "user_id": self.user_id, @@ -45,8 +39,7 @@ def to_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "OAuthSession": - """Create from dictionary retrieved from Redis.""" + def from_dict(cls, data: dict[str, Any]) -> 
"OAuthSession": return cls( federated_connector_id=data["federated_connector_id"], user_id=data["user_id"], @@ -58,31 +51,27 @@ def from_dict(cls, data: Dict[str, Any]) -> "OAuthSession": def generate_oauth_state( federated_connector_id: int, user_id: str, - redirect_uri: Optional[str] = None, - additional_data: Optional[Dict[str, Any]] = None, + redirect_uri: str | None = None, + additional_data: dict[str, Any] | None = None, ttl: int = OAUTH_STATE_TTL, ) -> str: """ - Generate a secure state parameter and store session data in Redis. + Generate a secure state parameter and store session data in the cache backend. Args: federated_connector_id: ID of the federated connector user_id: ID of the user initiating OAuth redirect_uri: Optional redirect URI after OAuth completion additional_data: Any additional data to store with the session - ttl: Time-to-live in seconds for the Redis key + ttl: Time-to-live in seconds for the cache key Returns: Base64-encoded state parameter """ # Generate a random UUID for the state state_uuid = uuid.uuid4() + state_b64 = base64.urlsafe_b64encode(state_uuid.bytes).decode("utf-8").rstrip("=") - # Convert UUID to base64 for URL-safe state parameter - state_bytes = state_uuid.bytes - state_b64 = base64.urlsafe_b64encode(state_bytes).decode("utf-8").rstrip("=") - - # Create session object session = OAuthSession( federated_connector_id=federated_connector_id, user_id=user_id, @@ -90,15 +79,9 @@ def generate_oauth_state( additional_data=additional_data, ) - # Store in Redis with TTL - redis_client = get_redis_client() - redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}" - - redis_client.set( - redis_key, - json.dumps(session.to_dict()), - ex=ttl, - ) + cache = get_cache_backend() + cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}" + cache.set(cache_key, json.dumps(session.to_dict()), ex=ttl) logger.info( f"Generated OAuth state for federated_connector_id={federated_connector_id}, " @@ -125,18 +108,15 @@ def verify_oauth_state(state: str) -> 
OAuthSession: state_bytes = base64.urlsafe_b64decode(padded_state) state_uuid = uuid.UUID(bytes=state_bytes) - # Look up in Redis - redis_client = get_redis_client() - redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}" + cache = get_cache_backend() + cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}" - session_data = cast(bytes, redis_client.get(redis_key)) + session_data = cache.get(cache_key) if not session_data: - raise ValueError(f"OAuth state not found in Redis: {state}") + raise ValueError(f"OAuth state not found: {state}") - # Delete the key after retrieval (one-time use) - redis_client.delete(redis_key) + cache.delete(cache_key) - # Parse and return session session_dict = json.loads(session_data) return OAuthSession.from_dict(session_dict) diff --git a/backend/onyx/file_processing/file_types.py b/backend/onyx/file_processing/file_types.py index 9f419fc222d..68f547b7c2b 100644 --- a/backend/onyx/file_processing/file_types.py +++ b/backend/onyx/file_processing/file_types.py @@ -19,12 +19,16 @@ class OnyxMimeTypes: PLAIN_TEXT_MIME_TYPE, "text/markdown", "text/x-markdown", + "text/x-log", "text/x-config", "text/tab-separated-values", "application/json", "application/xml", "text/xml", "application/x-yaml", + "application/yaml", + "text/yaml", + "text/x-yaml", } DOCUMENT_MIME_TYPES = { PDF_MIME_TYPE, diff --git a/backend/onyx/indexing/adapters/document_indexing_adapter.py b/backend/onyx/indexing/adapters/document_indexing_adapter.py index 012e438dc21..3bd91c067c7 100644 --- a/backend/onyx/indexing/adapters/document_indexing_adapter.py +++ b/backend/onyx/indexing/adapters/document_indexing_adapter.py @@ -123,15 +123,11 @@ def build_metadata_aware_chunks( } doc_id_to_new_chunk_cnt: dict[str, int] = { - document_id: len( - [ - chunk - for chunk in chunks_with_embeddings - if chunk.source_document.id == document_id - ] - ) - for document_id in updatable_ids + doc_id: 0 for doc_id in updatable_ids } + for chunk in chunks_with_embeddings: + if chunk.source_document.id in 
doc_id_to_new_chunk_cnt: + doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1 # Get ancestor hierarchy node IDs for each document doc_id_to_ancestor_ids = self._get_ancestor_ids_for_documents( diff --git a/backend/onyx/indexing/embedder.py b/backend/onyx/indexing/embedder.py index 40fc5c7c8bb..bf43857b12c 100644 --- a/backend/onyx/indexing/embedder.py +++ b/backend/onyx/indexing/embedder.py @@ -16,6 +16,7 @@ from onyx.indexing.models import IndexChunk from onyx.natural_language_processing.search_nlp_models import EmbeddingModel from onyx.utils.logger import setup_logger +from onyx.utils.pydantic_util import shallow_model_dump from onyx.utils.timing import log_function_time from shared_configs.configs import INDEXING_MODEL_SERVER_HOST from shared_configs.configs import INDEXING_MODEL_SERVER_PORT @@ -210,8 +211,8 @@ def embed_chunks( )[0] title_embed_dict[title] = title_embedding - new_embedded_chunk = IndexChunk( - **chunk.model_dump(), + new_embedded_chunk = IndexChunk.model_construct( + **shallow_model_dump(chunk), embeddings=ChunkEmbedding( full_embedding=chunk_embeddings[0], mini_chunk_embeddings=chunk_embeddings[1:], diff --git a/backend/onyx/indexing/indexing_pipeline.py b/backend/onyx/indexing/indexing_pipeline.py index 251f70194ee..da5ac417fcc 100644 --- a/backend/onyx/indexing/indexing_pipeline.py +++ b/backend/onyx/indexing/indexing_pipeline.py @@ -49,7 +49,6 @@ from onyx.indexing.models import DocAwareChunk from onyx.indexing.models import IndexingBatchAdapter from onyx.indexing.models import UpdatableChunkData -from onyx.indexing.postgres_sanitization import sanitize_documents_for_postgres from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff from onyx.llm.factory import get_default_llm_with_vision from onyx.llm.factory import get_llm_for_contextual_rag @@ -65,6 +64,7 @@ from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT2 from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_PROMPT from 
onyx.utils.logger import setup_logger +from onyx.utils.postgres_sanitization import sanitize_documents_for_postgres from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel from onyx.utils.timing import log_function_time diff --git a/backend/onyx/indexing/models.py b/backend/onyx/indexing/models.py index 36e51ceed2f..4fcd6850c61 100644 --- a/backend/onyx/indexing/models.py +++ b/backend/onyx/indexing/models.py @@ -12,6 +12,7 @@ from onyx.db.enums import EmbeddingPrecision from onyx.db.enums import SwitchoverType from onyx.utils.logger import setup_logger +from onyx.utils.pydantic_util import shallow_model_dump from shared_configs.enums import EmbeddingProvider from shared_configs.model_server_models import Embedding @@ -133,9 +134,8 @@ def from_index_chunk( tenant_id: str, ancestor_hierarchy_node_ids: list[int] | None = None, ) -> "DocMetadataAwareIndexChunk": - index_chunk_data = index_chunk.model_dump() - return cls( - **index_chunk_data, + return cls.model_construct( + **shallow_model_dump(index_chunk), access=access, document_sets=document_sets, user_project=user_project, diff --git a/backend/onyx/key_value_store/interface.py b/backend/onyx/key_value_store/interface.py index 9d6348076c0..7766513a85d 100644 --- a/backend/onyx/key_value_store/interface.py +++ b/backend/onyx/key_value_store/interface.py @@ -1,4 +1,5 @@ import abc +from typing import cast from onyx.utils.special_types import JSON_ro @@ -7,6 +8,19 @@ class KvKeyNotFoundError(Exception): pass +def unwrap_str(val: JSON_ro) -> str: + """Unwrap a string stored as {"value": str} in the encrypted KV store. 
+    Also handles legacy plain-string values cached in Redis."""
+    if isinstance(val, dict):
+        try:
+            return cast(str, val["value"])
+        except KeyError:
+            raise ValueError(
+                f"Expected dict with 'value' key, got keys: {list(val.keys())}"
+            )
+    return cast(str, val)
+
+
 class KeyValueStore:
     # In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
     # It's read from the global thread level variable
diff --git a/backend/onyx/key_value_store/store.py b/backend/onyx/key_value_store/store.py
index 58dc2e8abe9..b44d6bc2712 100644
--- a/backend/onyx/key_value_store/store.py
+++ b/backend/onyx/key_value_store/store.py
@@ -1,13 +1,11 @@
 import json
 from typing import cast
 
-from redis.client import Redis
-
+from onyx.cache.interface import CacheBackend
 from onyx.db.engine.sql_engine import get_session_with_current_tenant
 from onyx.db.models import KVStore
 from onyx.key_value_store.interface import KeyValueStore
 from onyx.key_value_store.interface import KvKeyNotFoundError
-from onyx.redis.redis_pool import get_redis_client
 from onyx.utils.logger import setup_logger
 from onyx.utils.special_types import JSON_ro
 
@@ -20,22 +18,27 @@
 class PgRedisKVStore(KeyValueStore):
-    def __init__(self, redis_client: Redis | None = None) -> None:
-        # If no redis_client is provided, fall back to the context var
-        if redis_client is not None:
-            self.redis_client = redis_client
-        else:
-            self.redis_client = get_redis_client()
+    def __init__(self, cache: CacheBackend | None = None) -> None:
+        self._cache = cache
+
+    def _get_cache(self) -> CacheBackend:
+        if self._cache is None:
+            from onyx.cache.factory import get_cache_backend
+
+            self._cache = get_cache_backend()
+        return self._cache
 
     def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
-        # Not encrypted in Redis, but encrypted in Postgres
+        # Not encrypted in Cache backend (typically Redis), but encrypted in Postgres
         try:
-            self.redis_client.set(
+            self._get_cache().set(
                 REDIS_KEY_PREFIX + key, json.dumps(val), ex=KV_REDIS_KEY_EXPIRATION
             )
         except Exception as e:
-            # Fallback gracefully to Postgres if Redis fails
-            logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
+            # Fallback gracefully to Postgres if Cache backend fails
+            logger.error(
+                f"Failed to set value in Cache backend for key '{key}': {str(e)}"
+            )
 
         encrypted_val = val if encrypt else None
         plain_val = val if not encrypt else None
@@ -53,16 +56,12 @@ def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
     def load(self, key: str, refresh_cache: bool = False) -> JSON_ro:
         if not refresh_cache:
             try:
-                redis_value = self.redis_client.get(REDIS_KEY_PREFIX + key)
-                if redis_value:
-                    if not isinstance(redis_value, bytes):
-                        raise ValueError(
-                            f"Redis value for key '{key}' is not a bytes object"
-                        )
-                    return json.loads(redis_value.decode("utf-8"))
+                cached = self._get_cache().get(REDIS_KEY_PREFIX + key)
+                if cached is not None:
+                    return json.loads(cached.decode("utf-8"))
             except Exception as e:
                 logger.error(
-                    f"Failed to get value from Redis for key '{key}': {str(e)}"
+                    f"Failed to get value from cache for key '{key}': {str(e)}"
                 )
 
         with get_session_with_current_tenant() as db_session:
@@ -79,21 +78,21 @@ def load(self, key: str, refresh_cache: bool = False) -> JSON_ro:
                 value = None
 
         try:
-            self.redis_client.set(
+            self._get_cache().set(
                 REDIS_KEY_PREFIX + key,
                 json.dumps(value),
                 ex=KV_REDIS_KEY_EXPIRATION,
             )
         except Exception as e:
-            logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
+            logger.error(f"Failed to set value in cache for key '{key}': {str(e)}")
 
         return cast(JSON_ro, value)
 
     def delete(self, key: str) -> None:
         try:
-            self.redis_client.delete(REDIS_KEY_PREFIX + key)
+            self._get_cache().delete(REDIS_KEY_PREFIX + key)
         except Exception as e:
-            logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
+            logger.error(f"Failed to delete value from cache for key '{key}': {str(e)}")
 
         with get_session_with_current_tenant() as db_session:
             result = db_session.query(KVStore).filter_by(key=key).delete()
diff --git a/backend/onyx/llm/constants.py b/backend/onyx/llm/constants.py
index 5480e4622be..0327d9faa41 100644
--- a/backend/onyx/llm/constants.py
+++ b/backend/onyx/llm/constants.py
@@ -22,6 +22,7 @@ class LlmProviderNames(str, Enum):
     OPENROUTER = "openrouter"
     AZURE = "azure"
     OLLAMA_CHAT = "ollama_chat"
+    LM_STUDIO = "lm_studio"
     MISTRAL = "mistral"
     LITELLM_PROXY = "litellm_proxy"
 
@@ -41,6 +42,8 @@ def __str__(self) -> str:
     LlmProviderNames.OPENROUTER,
     LlmProviderNames.AZURE,
     LlmProviderNames.OLLAMA_CHAT,
+    LlmProviderNames.LM_STUDIO,
+    LlmProviderNames.LITELLM_PROXY,
 ]
 
@@ -56,6 +59,8 @@ def __str__(self) -> str:
     LlmProviderNames.AZURE: "Azure",
     "ollama": "Ollama",
     LlmProviderNames.OLLAMA_CHAT: "Ollama",
+    LlmProviderNames.LM_STUDIO: "LM Studio",
+    LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
     "groq": "Groq",
     "anyscale": "Anyscale",
     "deepseek": "DeepSeek",
@@ -103,8 +108,10 @@ def __str__(self) -> str:
     LlmProviderNames.BEDROCK_CONVERSE,
     LlmProviderNames.OPENROUTER,
     LlmProviderNames.OLLAMA_CHAT,
+    LlmProviderNames.LM_STUDIO,
     LlmProviderNames.VERTEX_AI,
     LlmProviderNames.AZURE,
+    LlmProviderNames.LITELLM_PROXY,
 }
 
 # Model family name mappings for display name generation
diff --git a/backend/onyx/llm/factory.py b/backend/onyx/llm/factory.py
index 14ce113fa88..085a7c411f8 100644
--- a/backend/onyx/llm/factory.py
+++ b/backend/onyx/llm/factory.py
@@ -20,7 +20,9 @@
 from onyx.llm.override_models import LLMOverride
 from onyx.llm.utils import get_max_input_tokens_from_llm_provider
 from onyx.llm.utils import model_supports_image_input
-from onyx.llm.well_known_providers.constants import OLLAMA_API_KEY_CONFIG_KEY
+from onyx.llm.well_known_providers.constants import (
+    PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING,
+)
 from onyx.natural_language_processing.utils import get_tokenizer
 from onyx.server.manage.llm.models import LLMProviderView
 from onyx.utils.headers import build_llm_extra_headers
@@ -32,14 +34,18 @@
 def _build_provider_extra_headers(
     provider: str, custom_config: dict[str, str] | None
 ) -> dict[str, str]:
-    if provider == LlmProviderNames.OLLAMA_CHAT and custom_config:
-        raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY)
-        api_key = raw_api_key.strip() if raw_api_key else None
+    if provider in PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING and custom_config:
+        raw = custom_config.get(PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING[provider])
+        api_key = raw.strip() if raw else None
         if not api_key:
             return {}
-        if not api_key.lower().startswith("bearer "):
-            api_key = f"Bearer {api_key}"
-        return {"Authorization": api_key}
+        return {
+            "Authorization": (
+                api_key
+                if api_key.lower().startswith("bearer ")
+                else f"Bearer {api_key}"
+            )
+        }
 
     # Passing these will put Onyx on the OpenRouter leaderboard
     elif provider == LlmProviderNames.OPENROUTER:
diff --git a/backend/onyx/llm/litellm_singleton/monkey_patches.py b/backend/onyx/llm/litellm_singleton/monkey_patches.py
index 78587bd3dff..8feb6329fd9 100644
--- a/backend/onyx/llm/litellm_singleton/monkey_patches.py
+++ b/backend/onyx/llm/litellm_singleton/monkey_patches.py
@@ -67,6 +67,18 @@
    STATUS: STILL NEEDED - litellm_core_utils/litellm_logging.py lines 3185-3199 set
    usage as a dict with chat completion format instead of keeping it as
    ResponseAPIUsage. Our patch creates a deep copy before modification.
+
+7. Responses API metadata=None TypeError (_patch_responses_metadata_none):
+   - LiteLLM's @client decorator wrapper in utils.py uses kwargs.get("metadata", {})
+     to check for router calls, but when metadata is explicitly None (key exists with
+     value None), the default {} is not used
+   - This causes "argument of type 'NoneType' is not iterable" TypeError which swallows
+     the real exception (e.g. AuthenticationError for wrong API key)
+   - Surfaces as: APIConnectionError: OpenAIException - argument of type 'NoneType' is
+     not iterable
+   STATUS: STILL NEEDED - litellm/utils.py wrapper function (line 1721) does not guard
+     against metadata being explicitly None. Triggered when Responses API bridge
+     passes **litellm_params containing metadata=None.
 """
 
 import time
@@ -725,6 +737,44 @@ def _patched_get_assembled_streaming_response(
 LiteLLMLoggingObj._get_assembled_streaming_response = _patched_get_assembled_streaming_response  # type: ignore[method-assign]
 
 
+def _patch_responses_metadata_none() -> None:
+    """
+    Patches litellm.responses to normalize metadata=None to metadata={} in kwargs.
+
+    LiteLLM's @client decorator wrapper in utils.py (line 1721) does:
+        _is_litellm_router_call = "model_group" in kwargs.get("metadata", {})
+    When metadata is explicitly None in kwargs, kwargs.get("metadata", {}) returns
+    None (the key exists, so the default is not used), causing:
+        TypeError: argument of type 'NoneType' is not iterable
+
+    This swallows the real exception (e.g. AuthenticationError) and surfaces as:
+        APIConnectionError: OpenAIException - argument of type 'NoneType' is not iterable
+
+    This happens when the Responses API bridge calls litellm.responses() with
+    **litellm_params which may contain metadata=None.
+
+    STATUS: STILL NEEDED - litellm/utils.py wrapper function uses kwargs.get("metadata", {})
+    which does not guard against metadata being explicitly None. Same pattern exists
+    on line 1407 for async path.
+ """ + import litellm as _litellm + from functools import wraps + + original_responses = _litellm.responses + + if getattr(original_responses, "_metadata_patched", False): + return + + @wraps(original_responses) + def _patched_responses(*args: Any, **kwargs: Any) -> Any: + if kwargs.get("metadata") is None: + kwargs["metadata"] = {} + return original_responses(*args, **kwargs) + + _patched_responses._metadata_patched = True # type: ignore[attr-defined] + _litellm.responses = _patched_responses + + def apply_monkey_patches() -> None: """ Apply all necessary monkey patches to LiteLLM for compatibility. @@ -736,6 +786,7 @@ def apply_monkey_patches() -> None: - Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming - Patching ResponsesAPIResponse.model_construct to fix usage format in all code paths - Patching LiteLLMLoggingObj._get_assembled_streaming_response to avoid mutating original response + - Patching litellm.responses to fix metadata=None causing TypeError in error handling """ _patch_ollama_chunk_parser() _patch_openai_responses_parallel_tool_calls() @@ -743,3 +794,4 @@ def apply_monkey_patches() -> None: _patch_azure_responses_should_fake_stream() _patch_responses_api_usage_format() _patch_logging_assembled_streaming_response() + _patch_responses_metadata_none() diff --git a/backend/onyx/llm/model_metadata_enrichments.json b/backend/onyx/llm/model_metadata_enrichments.json index 4b6c2452dd4..23db46c9310 100644 --- a/backend/onyx/llm/model_metadata_enrichments.json +++ b/backend/onyx/llm/model_metadata_enrichments.json @@ -1512,6 +1512,10 @@ "display_name": "Claude Opus 4.5", "model_vendor": "anthropic" }, + "claude-opus-4-6": { + "display_name": "Claude Opus 4.6", + "model_vendor": "anthropic" + }, "claude-opus-4-5-20251101": { "display_name": "Claude Opus 4.5", "model_vendor": "anthropic", @@ -1526,6 +1530,10 @@ "display_name": "Claude Sonnet 4.5", "model_vendor": "anthropic" }, + "claude-sonnet-4-6": { + "display_name": 
"Claude Sonnet 4.6", + "model_vendor": "anthropic" + }, "claude-sonnet-4-5-20250929": { "display_name": "Claude Sonnet 4.5", "model_vendor": "anthropic", @@ -2516,6 +2524,10 @@ "model_vendor": "openai", "model_version": "2025-10-06" }, + "gpt-5.4": { + "display_name": "GPT-5.4", + "model_vendor": "openai" + }, "gpt-5.2-pro-2025-12-11": { "display_name": "GPT-5.2 Pro", "model_vendor": "openai", @@ -3770,16 +3782,6 @@ "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic" }, - "vertex_ai/claude-3-5-sonnet-v2": { - "display_name": "Claude Sonnet 3.5", - "model_vendor": "anthropic", - "model_version": "v2" - }, - "vertex_ai/claude-3-5-sonnet-v2@20241022": { - "display_name": "Claude Sonnet 3.5 v2", - "model_vendor": "anthropic", - "model_version": "20241022" - }, "vertex_ai/claude-3-5-sonnet@20240620": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic", diff --git a/backend/onyx/llm/multi_llm.py b/backend/onyx/llm/multi_llm.py index fc10b8a98f0..9b2087b9a75 100644 --- a/backend/onyx/llm/multi_llm.py +++ b/backend/onyx/llm/multi_llm.py @@ -42,6 +42,7 @@ from onyx.llm.well_known_providers.constants import ( AWS_SECRET_ACCESS_KEY_KWARG_ENV_VAR_FORMAT, ) +from onyx.llm.well_known_providers.constants import LM_STUDIO_API_KEY_CONFIG_KEY from onyx.llm.well_known_providers.constants import OLLAMA_API_KEY_CONFIG_KEY from onyx.llm.well_known_providers.constants import VERTEX_CREDENTIALS_FILE_KWARG from onyx.llm.well_known_providers.constants import ( @@ -92,6 +93,98 @@ def _prompt_to_dicts(prompt: LanguageModelInput) -> list[dict[str, Any]]: return [prompt.model_dump(exclude_none=True)] +def _normalize_content(raw: Any) -> str: + """Normalize a message content field to a plain string. + + Content can be a string, None, or a list of content-block dicts + (e.g. [{"type": "text", "text": "..."}]). 
+ """ + if raw is None: + return "" + if isinstance(raw, str): + return raw + if isinstance(raw, list): + return "\n".join( + block.get("text", "") if isinstance(block, dict) else str(block) + for block in raw + ) + return str(raw) + + +def _strip_tool_content_from_messages( + messages: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Convert tool-related messages to plain text. + + Bedrock's Converse API requires toolConfig when messages contain + toolUse/toolResult content blocks. When no tools are provided for the + current request, we must convert any tool-related history into plain text + to avoid the "toolConfig field must be defined" error. + + This is the same approach used by _OllamaHistoryMessageFormatter. + """ + result: list[dict[str, Any]] = [] + for msg in messages: + role = msg.get("role") + tool_calls = msg.get("tool_calls") + + if role == "assistant" and tool_calls: + # Convert structured tool calls to text representation + tool_call_lines = [] + for tc in tool_calls: + func = tc.get("function", {}) + name = func.get("name", "unknown") + args = func.get("arguments", "{}") + tc_id = tc.get("id", "") + tool_call_lines.append( + f"[Tool Call] name={name} id={tc_id} args={args}" + ) + + existing_content = _normalize_content(msg.get("content")) + parts = ( + [existing_content] + tool_call_lines + if existing_content + else tool_call_lines + ) + new_msg = { + "role": "assistant", + "content": "\n".join(parts), + } + result.append(new_msg) + + elif role == "tool": + # Convert tool response to user message with text content + tool_call_id = msg.get("tool_call_id", "") + content = _normalize_content(msg.get("content")) + tool_result_text = f"[Tool Result] id={tool_call_id}\n{content}" + # Merge into previous user message if it is also a converted + # tool result to avoid consecutive user messages (Bedrock requires + # strict user/assistant alternation). 
+ if ( + result + and result[-1]["role"] == "user" + and "[Tool Result]" in result[-1].get("content", "") + ): + result[-1]["content"] += "\n\n" + tool_result_text + else: + result.append({"role": "user", "content": tool_result_text}) + + else: + result.append(msg) + + return result + + +def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool: + """Check if any messages contain tool-related content blocks.""" + for msg in messages: + if msg.get("role") == "tool": + return True + if msg.get("role") == "assistant" and msg.get("tool_calls"): + return True + return False + + def _is_vertex_model_rejecting_output_config(model_name: str) -> bool: normalized_model_name = model_name.lower() return any( @@ -157,6 +250,9 @@ def __init__( elif model_provider == LlmProviderNames.OLLAMA_CHAT: if k == OLLAMA_API_KEY_CONFIG_KEY: model_kwargs["api_key"] = v + elif model_provider == LlmProviderNames.LM_STUDIO: + if k == LM_STUDIO_API_KEY_CONFIG_KEY: + model_kwargs["api_key"] = v elif model_provider == LlmProviderNames.BEDROCK: if k == AWS_REGION_NAME_KWARG: model_kwargs[k] = v @@ -173,6 +269,19 @@ def __init__( elif k == AWS_SECRET_ACCESS_KEY_KWARG_ENV_VAR_FORMAT: model_kwargs[AWS_SECRET_ACCESS_KEY_KWARG] = v + # LM Studio: LiteLLM defaults to "fake-api-key" when no key is provided, + # which LM Studio rejects. Ensure we always pass an explicit key (or empty + # string) to prevent LiteLLM from injecting its fake default. + if model_provider == LlmProviderNames.LM_STUDIO: + model_kwargs.setdefault("api_key", "") + + # Users provide the server root (e.g. http://localhost:1234) but LiteLLM + # needs /v1 for OpenAI-compatible calls. 
+ if self._api_base is not None: + base = self._api_base.rstrip("/") + self._api_base = base if base.endswith("/v1") else f"{base}/v1" + model_kwargs["api_base"] = self._api_base + # Default vertex_location to "global" if not provided for Vertex AI # Latest gemini models are only available through the global region if ( @@ -404,13 +513,30 @@ def _completion( else nullcontext() ) with env_ctx: + messages = _prompt_to_dicts(prompt) + + # Bedrock's Converse API requires toolConfig when messages + # contain toolUse/toolResult content blocks. When no tools are + # provided for this request but the history contains tool + # content from previous turns, strip it to plain text. + is_bedrock = self._model_provider in { + LlmProviderNames.BEDROCK, + LlmProviderNames.BEDROCK_CONVERSE, + } + if ( + is_bedrock + and not tools + and _messages_contain_tool_content(messages) + ): + messages = _strip_tool_content_from_messages(messages) + response = litellm.completion( mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE, model=model, base_url=self._api_base or None, api_version=self._api_version or None, custom_llm_provider=self._custom_llm_provider or None, - messages=_prompt_to_dicts(prompt), + messages=messages, tools=tools, tool_choice=tool_choice, stream=stream, diff --git a/backend/onyx/llm/utils.py b/backend/onyx/llm/utils.py index 86eb83970e4..c19a21a3a58 100644 --- a/backend/onyx/llm/utils.py +++ b/backend/onyx/llm/utils.py @@ -322,7 +322,7 @@ def test_llm(llm: LLM) -> str | None: error_msg = None for _ in range(2): try: - llm.invoke(UserMessage(content="Do not respond")) + llm.invoke(UserMessage(content="Do not respond"), max_tokens=50) return None except Exception as e: error_msg = str(e) diff --git a/backend/onyx/llm/well_known_providers/auto_update_service.py b/backend/onyx/llm/well_known_providers/auto_update_service.py index c48cc7cfb6a..92a6bd83043 100644 --- a/backend/onyx/llm/well_known_providers/auto_update_service.py +++ 
b/backend/onyx/llm/well_known_providers/auto_update_service.py @@ -13,44 +13,38 @@ import httpx from sqlalchemy.orm import Session +from onyx.cache.factory import get_cache_backend from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL from onyx.db.llm import fetch_auto_mode_providers from onyx.db.llm import sync_auto_mode_models from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations -from onyx.redis.redis_pool import get_redis_client from onyx.utils.logger import setup_logger logger = setup_logger() -# Redis key for caching the last updated timestamp (per-tenant) -_REDIS_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at" +_CACHE_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at" +_CACHE_TTL_SECONDS = 60 * 60 * 24 # 24 hours def _get_cached_last_updated_at() -> datetime | None: - """Get the cached last_updated_at timestamp from Redis.""" try: - redis_client = get_redis_client() - value = redis_client.get(_REDIS_KEY_LAST_UPDATED_AT) - if value and isinstance(value, bytes): - # Value is bytes, decode to string then parse as ISO format + value = get_cache_backend().get(_CACHE_KEY_LAST_UPDATED_AT) + if value is not None: return datetime.fromisoformat(value.decode("utf-8")) except Exception as e: - logger.warning(f"Failed to get cached last_updated_at from Redis: {e}") + logger.warning(f"Failed to get cached last_updated_at: {e}") return None def _set_cached_last_updated_at(updated_at: datetime) -> None: - """Set the cached last_updated_at timestamp in Redis.""" try: - redis_client = get_redis_client() - # Store as ISO format string, with 24 hour expiration - redis_client.set( - _REDIS_KEY_LAST_UPDATED_AT, + get_cache_backend().set( + _CACHE_KEY_LAST_UPDATED_AT, updated_at.isoformat(), - ex=60 * 60 * 24, # 24 hours + ex=_CACHE_TTL_SECONDS, ) except Exception as e: - logger.warning(f"Failed to set cached last_updated_at in Redis: {e}") + logger.warning(f"Failed to set cached last_updated_at: {e}") def 
fetch_llm_recommendations_from_github( @@ -148,9 +142,8 @@ def sync_llm_models_from_github( def reset_cache() -> None: - """Reset the cache timestamp in Redis. Useful for testing.""" + """Reset the cache timestamp. Useful for testing.""" try: - redis_client = get_redis_client() - redis_client.delete(_REDIS_KEY_LAST_UPDATED_AT) + get_cache_backend().delete(_CACHE_KEY_LAST_UPDATED_AT) except Exception as e: - logger.warning(f"Failed to reset cache in Redis: {e}") + logger.warning(f"Failed to reset cache: {e}") diff --git a/backend/onyx/llm/well_known_providers/constants.py b/backend/onyx/llm/well_known_providers/constants.py index 2094ee4924c..fba2600b52a 100644 --- a/backend/onyx/llm/well_known_providers/constants.py +++ b/backend/onyx/llm/well_known_providers/constants.py @@ -1,52 +1,29 @@ +from onyx.llm.constants import LlmProviderNames + OPENAI_PROVIDER_NAME = "openai" -# Curated list of OpenAI models to show by default in the UI -OPENAI_VISIBLE_MODEL_NAMES = { - "gpt-5", - "gpt-5-mini", - "o1", - "o3-mini", - "gpt-4o", - "gpt-4o-mini", -} BEDROCK_PROVIDER_NAME = "bedrock" -BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0" - - -def _fallback_bedrock_regions() -> list[str]: - # Fall back to a conservative set of well-known Bedrock regions if boto3 data isn't available. 
- return [ - "us-east-1", - "us-east-2", - "us-gov-east-1", - "us-gov-west-1", - "us-west-2", - "ap-northeast-1", - "ap-south-1", - "ap-southeast-1", - "ap-southeast-2", - "ap-east-1", - "ca-central-1", - "eu-central-1", - "eu-west-2", - ] OLLAMA_PROVIDER_NAME = "ollama_chat" OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY" +LM_STUDIO_PROVIDER_NAME = "lm_studio" +LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY" + +LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy" + +# Providers that use optional Bearer auth from custom_config +PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = { + LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY, + LlmProviderNames.LM_STUDIO: LM_STUDIO_API_KEY_CONFIG_KEY, +} + # OpenRouter OPENROUTER_PROVIDER_NAME = "openrouter" ANTHROPIC_PROVIDER_NAME = "anthropic" -# Curated list of Anthropic models to show by default in the UI -ANTHROPIC_VISIBLE_MODEL_NAMES = { - "claude-opus-4-5", - "claude-sonnet-4-5", - "claude-haiku-4-5", -} - AZURE_PROVIDER_NAME = "azure" @@ -54,13 +31,6 @@ def _fallback_bedrock_regions() -> list[str]: VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials" VERTEX_CREDENTIALS_FILE_KWARG_ENV_VAR_FORMAT = "CREDENTIALS_FILE" VERTEX_LOCATION_KWARG = "vertex_location" -VERTEXAI_DEFAULT_MODEL = "gemini-2.5-flash" -# Curated list of Vertex AI models to show by default in the UI -VERTEXAI_VISIBLE_MODEL_NAMES = { - "gemini-2.5-flash", - "gemini-2.5-flash-lite", - "gemini-2.5-pro", -} AWS_REGION_NAME_KWARG = "aws_region_name" AWS_REGION_NAME_KWARG_ENV_VAR_FORMAT = "AWS_REGION_NAME" diff --git a/backend/onyx/llm/well_known_providers/llm_provider_options.py b/backend/onyx/llm/well_known_providers/llm_provider_options.py index b5b95784450..a456ff52690 100644 --- a/backend/onyx/llm/well_known_providers/llm_provider_options.py +++ b/backend/onyx/llm/well_known_providers/llm_provider_options.py @@ -15,6 +15,8 @@ from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME from onyx.llm.well_known_providers.constants 
import AZURE_PROVIDER_NAME from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME +from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME +from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME @@ -44,7 +46,9 @@ def _get_provider_to_models_map() -> dict[str, list[str]]: ANTHROPIC_PROVIDER_NAME: get_anthropic_model_names(), VERTEXAI_PROVIDER_NAME: get_vertexai_model_names(), OLLAMA_PROVIDER_NAME: [], # Dynamic - fetched from Ollama API + LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API + LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API } @@ -323,11 +327,13 @@ def get_provider_display_name(provider_name: str) -> str: _ONYX_PROVIDER_DISPLAY_NAMES: dict[str, str] = { OPENAI_PROVIDER_NAME: "ChatGPT (OpenAI)", OLLAMA_PROVIDER_NAME: "Ollama", + LM_STUDIO_PROVIDER_NAME: "LM Studio", ANTHROPIC_PROVIDER_NAME: "Claude (Anthropic)", AZURE_PROVIDER_NAME: "Azure OpenAI", BEDROCK_PROVIDER_NAME: "Amazon Bedrock", VERTEXAI_PROVIDER_NAME: "Google Vertex AI", OPENROUTER_PROVIDER_NAME: "OpenRouter", + LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy", } if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES: diff --git a/backend/onyx/llm/well_known_providers/recommended-models.json b/backend/onyx/llm/well_known_providers/recommended-models.json index ffce42f3279..ec32ab598fa 100644 --- a/backend/onyx/llm/well_known_providers/recommended-models.json +++ b/backend/onyx/llm/well_known_providers/recommended-models.json @@ -1,12 +1,12 @@ { "version": "1.1", - "updated_at": "2026-02-05T00:00:00Z", + "updated_at": "2026-03-05T00:00:00Z", "providers": { "openai": { - "default_model": { "name": "gpt-5.2" }, 
+ "default_model": { "name": "gpt-5.4" }, "additional_visible_models": [ - { "name": "gpt-5-mini" }, - { "name": "gpt-4.1" } + { "name": "gpt-5.4" }, + { "name": "gpt-5.2" } ] }, "anthropic": { @@ -16,6 +16,10 @@ "name": "claude-opus-4-6", "display_name": "Claude Opus 4.6" }, + { + "name": "claude-sonnet-4-6", + "display_name": "Claude Sonnet 4.6" + }, { "name": "claude-opus-4-5", "display_name": "Claude Opus 4.5" diff --git a/backend/onyx/main.py b/backend/onyx/main.py index 405581aa7b5..d7bd83cd7a9 100644 --- a/backend/onyx/main.py +++ b/backend/onyx/main.py @@ -32,16 +32,19 @@ from onyx.auth.users import auth_backend from onyx.auth.users import create_onyx_oauth_router from onyx.auth.users import fastapi_users +from onyx.cache.interface import CacheBackendType from onyx.configs.app_configs import APP_API_PREFIX from onyx.configs.app_configs import APP_HOST from onyx.configs.app_configs import APP_PORT from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED from onyx.configs.app_configs import AUTH_TYPE +from onyx.configs.app_configs import CACHE_BACKEND from onyx.configs.app_configs import DISABLE_VECTOR_DB from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY from onyx.configs.app_configs import OAUTH_CLIENT_ID from onyx.configs.app_configs import OAUTH_CLIENT_SECRET from onyx.configs.app_configs import OAUTH_ENABLED +from onyx.configs.app_configs import OIDC_PKCE_ENABLED from onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE from onyx.configs.app_configs import OPENID_CONFIG_URL from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW @@ -57,6 +60,7 @@ from onyx.db.engine.connection_warmup import warm_up_connections from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.engine.sql_engine import SqlEngine +from onyx.error_handling.exceptions import register_onyx_exception_handlers from onyx.file_store.file_store import get_default_file_store from onyx.server.api_key.api import router as api_key_router from 
onyx.server.auth_check import check_router_auth @@ -255,6 +259,20 @@ def include_auth_router_with_prefix( ) +def validate_cache_backend_settings() -> None: + """Validate that CACHE_BACKEND=postgres is only used with DISABLE_VECTOR_DB. + + The Postgres cache backend eliminates the Redis dependency, but only works + when Celery is not running (which requires DISABLE_VECTOR_DB=true). + """ + if CACHE_BACKEND == CacheBackendType.POSTGRES and not DISABLE_VECTOR_DB: + raise RuntimeError( + "CACHE_BACKEND=postgres requires DISABLE_VECTOR_DB=true. " + "The Postgres cache backend is only supported in no-vector-DB " + "deployments where Celery is replaced by the in-process task runner." + ) + + def validate_no_vector_db_settings() -> None: """Validate that DISABLE_VECTOR_DB is not combined with incompatible settings. @@ -286,6 +304,7 @@ def validate_no_vector_db_settings() -> None: @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 validate_no_vector_db_settings() + validate_cache_backend_settings() # Set recursion limit if SYSTEM_RECURSION_LIMIT is not None: @@ -355,8 +374,20 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 if AUTH_RATE_LIMITING_ENABLED: await setup_auth_limiter() + if DISABLE_VECTOR_DB: + from onyx.background.periodic_poller import recover_stuck_user_files + from onyx.background.periodic_poller import start_periodic_poller + + recover_stuck_user_files(POSTGRES_DEFAULT_SCHEMA) + start_periodic_poller(POSTGRES_DEFAULT_SCHEMA) + yield + if DISABLE_VECTOR_DB: + from onyx.background.periodic_poller import stop_periodic_poller + + stop_periodic_poller() + SqlEngine.reset_engine() if AUTH_RATE_LIMITING_ENABLED: @@ -415,6 +446,8 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI: status.HTTP_500_INTERNAL_SERVER_ERROR, log_http_error ) + register_onyx_exception_handlers(application) + include_router_with_global_prefix_prepended(application, password_router) 
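The lifespan changes above follow a specific ordering: validate cache-backend settings at startup, start the in-process periodic poller before `yield` when the vector DB is disabled, and stop it after. A minimal, self-contained sketch of that ordering — all names here (`events`, `start`/`stop` markers, the module-level config flags) are illustrative stand-ins, not Onyx's actual API:

```python
# Sketch of the startup/shutdown ordering in the lifespan hunk above:
# validate settings first, start background work before `yield`, stop it after.
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator

CACHE_BACKEND = "postgres"
DISABLE_VECTOR_DB = True  # stand-in config flags, not the real onyx settings

events: list[str] = []  # records lifecycle ordering for demonstration


def validate_cache_backend_settings() -> None:
    # The Postgres cache backend only works when Celery is replaced by the
    # in-process task runner, which requires the vector DB to be disabled.
    if CACHE_BACKEND == "postgres" and not DISABLE_VECTOR_DB:
        raise RuntimeError("CACHE_BACKEND=postgres requires DISABLE_VECTOR_DB=true")


@asynccontextmanager
async def lifespan() -> AsyncIterator[None]:
    validate_cache_backend_settings()
    if DISABLE_VECTOR_DB:
        events.append("poller started")  # stand-in for start_periodic_poller
    yield  # the app serves requests while suspended here
    if DISABLE_VECTOR_DB:
        events.append("poller stopped")  # stand-in for stop_periodic_poller


async def main() -> list[str]:
    async with lifespan():
        events.append("serving")
    return events


print(asyncio.run(main()))  # ['poller started', 'serving', 'poller stopped']
```

The key design point mirrored here is that cleanup after `yield` runs even in the happy path, so the poller is always shut down symmetrically with its start.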
include_router_with_global_prefix_prepended(application, chat_router) include_router_with_global_prefix_prepended(application, query_router) @@ -565,6 +598,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI: associate_by_email=True, is_verified_by_default=True, redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback", + enable_pkce=OIDC_PKCE_ENABLED, ), prefix="/auth/oidc", ) diff --git a/backend/onyx/mcp_server/tools/search.py b/backend/onyx/mcp_server/tools/search.py index b3c0a427c79..c509b8d0aff 100644 --- a/backend/onyx/mcp_server/tools/search.py +++ b/backend/onyx/mcp_server/tools/search.py @@ -10,6 +10,7 @@ from onyx.mcp_server.utils import require_access_token from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import build_api_server_url_for_http_requests +from onyx.utils.variable_functionality import global_version logger = setup_logger() @@ -26,6 +27,14 @@ async def search_indexed_documents( Use this tool for information that is not public knowledge and specific to the user, their team, their work, or their organization/company. + Note: In CE mode, this tool uses the chat endpoint internally which invokes an LLM + on every call, consuming tokens and adding latency. + Additionally, CE callers receive a truncated snippet (blurb) instead of a full document chunk, + but this should still be sufficient for most use cases. CE mode functionality should be swapped + when a dedicated CE search endpoint is implemented. + + In EE mode, the dedicated search endpoint is used instead. + To find a list of available sources, use the `indexed_sources` resource. Returns chunks of text as search results with snippets, scores, and metadata. 
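Per the CE/EE note above, the two backends return documents under different keys (and CE returns a truncated `blurb` rather than full content), so the tool normalizes both shapes into one document list. A hedged sketch of that normalization — the key names mirror the diff, but the payloads are made up:

```python
# Sketch of the CE/EE response normalization: pick the docs key and content
# field by edition, project a shared shape, then cap the list at `limit`.
from typing import Any


def normalize_docs(
    result: dict[str, Any], is_ee: bool, limit: int
) -> list[dict[str, Any]]:
    docs_key = "search_docs" if is_ee else "top_documents"
    content_field = "content" if is_ee else "blurb"
    docs = [
        {
            "semantic_identifier": d.get("semantic_identifier"),
            "content": d.get(content_field),
            "source_type": d.get("source_type"),
            "link": d.get("link"),
            "score": d.get("score"),
        }
        for d in result.get(docs_key, [])
    ]
    # `limit` only caps the returned list; search depth is controlled by the
    # backend persona defaults, so fewer results may come back than requested.
    return docs[:limit]


ce_result = {
    "top_documents": [
        {"semantic_identifier": "doc1", "blurb": "snippet", "score": 0.9}
    ]
}
print(normalize_docs(ce_result, is_ee=False, limit=5)[0]["content"])  # snippet
```

Normalizing at the boundary keeps MCP clients edition-agnostic: they always see the same five fields regardless of which endpoint served the query.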
@@ -111,48 +120,73 @@ async def search_indexed_documents( if time_cutoff_dt: filters["time_cutoff"] = time_cutoff_dt.isoformat() - # Build the search request using the new SendSearchQueryRequest format - search_request = { - "search_query": query, - "filters": filters, - "num_docs_fed_to_llm_selection": limit, - "run_query_expansion": False, - "include_content": True, - "stream": False, - } + is_ee = global_version.is_ee_version() + base_url = build_api_server_url_for_http_requests(respect_env_override_if_set=True) + auth_headers = {"Authorization": f"Bearer {access_token.token}"} + + search_request: dict[str, Any] + if is_ee: + # EE: use the dedicated search endpoint (no LLM invocation) + search_request = { + "search_query": query, + "filters": filters, + "num_docs_fed_to_llm_selection": limit, + "run_query_expansion": False, + "include_content": True, + "stream": False, + } + endpoint = f"{base_url}/search/send-search-message" + error_key = "error" + docs_key = "search_docs" + content_field = "content" + else: + # CE: fall back to the chat endpoint (invokes LLM, consumes tokens) + search_request = { + "message": query, + "stream": False, + "chat_session_info": {}, + } + if filters: + search_request["internal_search_filters"] = filters + endpoint = f"{base_url}/chat/send-chat-message" + error_key = "error_msg" + docs_key = "top_documents" + content_field = "blurb" - # Call the API server using the new send-search-message route try: response = await get_http_client().post( - f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/search/send-search-message", + endpoint, json=search_request, - headers={"Authorization": f"Bearer {access_token.token}"}, + headers=auth_headers, ) response.raise_for_status() result = response.json() # Check for error in response - if result.get("error"): + if result.get(error_key): return { "documents": [], "total_results": 0, "query": query, - "error": result.get("error"), + "error": result.get(error_key), } - # 
Return simplified format for MCP clients - fields_to_return = [ - "semantic_identifier", - "content", - "source_type", - "link", - "score", - ] documents = [ - {key: doc.get(key) for key in fields_to_return} - for doc in result.get("search_docs", []) + { + "semantic_identifier": doc.get("semantic_identifier"), + "content": doc.get(content_field), + "source_type": doc.get("source_type"), + "link": doc.get("link"), + "score": doc.get("score"), + } + for doc in result.get(docs_key, []) ] + # NOTE: search depth is controlled by the backend persona defaults, not `limit`. + # `limit` only caps the returned list; fewer results may be returned if the + # backend retrieves fewer documents than requested. + documents = documents[:limit] + logger.info( f"Onyx MCP Server: Internal search returned {len(documents)} results" ) @@ -160,7 +194,6 @@ async def search_indexed_documents( "documents": documents, "total_results": len(documents), "query": query, - "executed_queries": result.get("all_executed_queries", [query]), } except Exception as e: logger.error(f"Onyx MCP Server: Document search error: {e}", exc_info=True) diff --git a/backend/onyx/onyxbot/slack/constants.py b/backend/onyx/onyxbot/slack/constants.py index 1f2d4ed68b7..fab3f5bf3e0 100644 --- a/backend/onyx/onyxbot/slack/constants.py +++ b/backend/onyx/onyxbot/slack/constants.py @@ -1,5 +1,9 @@ +import re from enum import Enum +# Matches Slack channel references like <#C097NBWMY8Y> or <#C097NBWMY8Y|channel-name> +SLACK_CHANNEL_REF_PATTERN = re.compile(r"<#([A-Z0-9]+)(?:\|([^>]+))?>") + LIKE_BLOCK_ACTION_ID = "feedback-like" DISLIKE_BLOCK_ACTION_ID = "feedback-dislike" SHOW_EVERYONE_ACTION_ID = "show-everyone" diff --git a/backend/onyx/onyxbot/slack/formatting.py b/backend/onyx/onyxbot/slack/formatting.py index 78a7ad40184..33551b4c155 100644 --- a/backend/onyx/onyxbot/slack/formatting.py +++ b/backend/onyx/onyxbot/slack/formatting.py @@ -130,7 +130,7 @@ def format_slack_message(message: str | None) -> str: message = 
_transform_outside_code_blocks(message, _sanitize_html) message = _convert_slack_links_to_markdown(message) normalized_message = _normalize_link_destinations(message) - md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"]) + md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough", "table"]) result = md(normalized_message) # With HTMLRenderer, result is always str (not AST list) assert isinstance(result, str) @@ -146,6 +146,11 @@ class SlackRenderer(HTMLRenderer): SPECIALS: dict[str, str] = {"&": "&", "<": "<", ">": ">"} + def __init__(self) -> None: + super().__init__() + self._table_headers: list[str] = [] + self._current_row_cells: list[str] = [] + def escape_special(self, text: str) -> str: for special, replacement in self.SPECIALS.items(): text = text.replace(special, replacement) @@ -218,5 +223,48 @@ def text(self, text: str) -> str: # as literal " text since Slack doesn't recognize that entity. return self.escape_special(text) + # -- Table rendering (converts markdown tables to vertical cards) -- + + def table_cell( + self, text: str, align: str | None = None, head: bool = False # noqa: ARG002 + ) -> str: + if head: + self._table_headers.append(text.strip()) + else: + self._current_row_cells.append(text.strip()) + return "" + + def table_head(self, text: str) -> str: # noqa: ARG002 + self._current_row_cells = [] + return "" + + def table_row(self, text: str) -> str: # noqa: ARG002 + cells = self._current_row_cells + self._current_row_cells = [] + # First column becomes the bold title, remaining columns are bulleted fields + lines: list[str] = [] + if cells: + title = cells[0] + if title: + # Avoid double-wrapping if cell already contains bold markup + if title.startswith("*") and title.endswith("*") and len(title) > 1: + lines.append(title) + else: + lines.append(f"*{title}*") + for i, cell in enumerate(cells[1:], start=1): + if i < len(self._table_headers): + lines.append(f" • {self._table_headers[i]}: {cell}") + else: + 
lines.append(f" • {cell}") + return "\n".join(lines) + "\n\n" + + def table_body(self, text: str) -> str: + return text + + def table(self, text: str) -> str: + self._table_headers = [] + self._current_row_cells = [] + return text + "\n" + def paragraph(self, text: str) -> str: return f"{text}\n\n" diff --git a/backend/onyx/onyxbot/slack/handlers/handle_message.py b/backend/onyx/onyxbot/slack/handlers/handle_message.py index 971c1dca029..f16b52c3520 100644 --- a/backend/onyx/onyxbot/slack/handlers/handle_message.py +++ b/backend/onyx/onyxbot/slack/handlers/handle_message.py @@ -3,10 +3,12 @@ from slack_sdk import WebClient from slack_sdk.errors import SlackApiError +from onyx.auth.schemas import UserRole from onyx.configs.onyxbot_configs import ONYX_BOT_FEEDBACK_REMINDER from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.models import SlackChannelConfig +from onyx.db.user_preferences import activate_user from onyx.db.users import add_slack_user_if_not_exists from onyx.db.users import get_user_by_email from onyx.onyxbot.slack.blocks import get_feedback_reminder_blocks @@ -243,6 +245,44 @@ def handle_message( ) return False + elif ( + not existing_user.is_active + and existing_user.role == UserRole.SLACK_USER + ): + check_seat_fn = fetch_ee_implementation_or_noop( + "onyx.db.license", + "check_seat_availability", + None, + ) + seat_result = check_seat_fn(db_session=db_session) + if seat_result is not None and not seat_result.available: + logger.info( + f"Blocked inactive Slack user {message_info.email}: " + f"{seat_result.error_message}" + ) + respond_in_thread_or_channel( + client=client, + channel=channel, + thread_ts=message_info.msg_to_respond, + text=( + "We weren't able to respond because your organization " + "has reached its user seat limit. Your account is " + "currently deactivated and cannot be reactivated " + "until more seats are available. 
Please contact " + "your Onyx administrator." + ), + ) + return False + + activate_user(existing_user, db_session) + invalidate_license_cache_fn = fetch_ee_implementation_or_noop( + "onyx.db.license", + "invalidate_license_cache", + None, + ) + invalidate_license_cache_fn() + logger.info(f"Reactivated inactive Slack user {message_info.email}") + add_slack_user_if_not_exists(db_session, message_info.email) # first check if we need to respond with a standard answer diff --git a/backend/onyx/onyxbot/slack/handlers/handle_regular_answer.py b/backend/onyx/onyxbot/slack/handlers/handle_regular_answer.py index 320ee9d9e3f..3b08063ac94 100644 --- a/backend/onyx/onyxbot/slack/handlers/handle_regular_answer.py +++ b/backend/onyx/onyxbot/slack/handlers/handle_regular_answer.py @@ -18,15 +18,18 @@ from onyx.configs.onyxbot_configs import ONYX_BOT_NUM_RETRIES from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI from onyx.context.search.models import BaseFilters +from onyx.context.search.models import Tag from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.models import SlackChannelConfig from onyx.db.models import User from onyx.db.persona import get_persona_by_id from onyx.db.users import get_user_by_email from onyx.onyxbot.slack.blocks import build_slack_response_blocks +from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN from onyx.onyxbot.slack.handlers.utils import send_team_member_message from onyx.onyxbot.slack.models import SlackMessageInfo from onyx.onyxbot.slack.models import ThreadMessage +from onyx.onyxbot.slack.utils import get_channel_from_id from onyx.onyxbot.slack.utils import get_channel_name_from_id from onyx.onyxbot.slack.utils import respond_in_thread_or_channel from onyx.onyxbot.slack.utils import SlackRateLimiter @@ -41,6 +44,51 @@ RT = TypeVar("RT") # return type +def resolve_channel_references( + message: str, + client: WebClient, + logger: OnyxLoggingAdapter, +) -> tuple[str, list[Tag]]: + """Parse 
Slack channel references from a message, resolve IDs to names, + replace the raw markup with readable #channel-name, and return channel tags + for search filtering.""" + tags: list[Tag] = [] + channel_matches = SLACK_CHANNEL_REF_PATTERN.findall(message) + seen_channel_ids: set[str] = set() + + for channel_id, channel_name_from_markup in channel_matches: + if channel_id in seen_channel_ids: + continue + seen_channel_ids.add(channel_id) + + channel_name = channel_name_from_markup or None + + if not channel_name: + try: + channel_info = get_channel_from_id(client=client, channel_id=channel_id) + channel_name = channel_info.get("name") or None + except Exception: + logger.warning(f"Failed to resolve channel name for ID: {channel_id}") + + if not channel_name: + continue + + # Replace raw Slack markup with readable channel name + if channel_name_from_markup: + message = message.replace( + f"<#{channel_id}|{channel_name_from_markup}>", + f"#{channel_name}", + ) + else: + message = message.replace( + f"<#{channel_id}>", + f"#{channel_name}", + ) + tags.append(Tag(tag_key="Channel", tag_value=channel_name)) + + return message, tags + + def rate_limits( client: WebClient, channel: str, thread_ts: Optional[str] ) -> Callable[[Callable[..., RT]], Callable[..., RT]]: @@ -157,6 +205,20 @@ def handle_regular_answer( user_message = messages[-1] history_messages = messages[:-1] + # Resolve any <#CHANNEL_ID> references in the user message to readable + # channel names and extract channel tags for search filtering + resolved_message, channel_tags = resolve_channel_references( + message=user_message.message, + client=client, + logger=logger, + ) + + user_message = ThreadMessage( + message=resolved_message, + sender=user_message.sender, + role=user_message.role, + ) + channel_name, _ = get_channel_name_from_id( client=client, channel_id=channel, @@ -207,6 +269,7 @@ def _get_slack_answer( source_type=None, document_set=document_set_names, time_cutoff=None, + tags=channel_tags if 
channel_tags else None, ) new_message_request = SendMessageRequest( @@ -231,6 +294,16 @@ def _get_slack_answer( slack_context_str=slack_context_str, ) + # If a channel filter was applied but no results were found, override + # the LLM response to avoid hallucinated answers about unindexed channels + if channel_tags and not answer.citation_info and not answer.top_documents: + channel_names = ", ".join(f"#{tag.tag_value}" for tag in channel_tags) + answer.answer = ( + f"No indexed data found for {channel_names}. " + "This channel may not be indexed, or there may be no messages " + "matching your query within it." + ) + except Exception as e: logger.exception( f"Unable to process message - did not successfully answer " @@ -285,6 +358,7 @@ def _get_slack_answer( only_respond_if_citations and not answer.citation_info and not message_info.bypass_filters + and not channel_tags ): logger.error( f"Unable to find citations to answer: '{answer.answer}' - not answering!" diff --git a/backend/onyx/redis/redis_hierarchy.py b/backend/onyx/redis/redis_hierarchy.py index d086c82abf6..f23ef1d02d4 100644 --- a/backend/onyx/redis/redis_hierarchy.py +++ b/backend/onyx/redis/redis_hierarchy.py @@ -16,6 +16,7 @@ using only the SOURCE-type node as the ancestor """ +from typing import cast from typing import TYPE_CHECKING from pydantic import BaseModel @@ -204,6 +205,30 @@ def cache_hierarchy_nodes_batch( redis_client.expire(raw_id_key, HIERARCHY_CACHE_TTL_SECONDS) +def evict_hierarchy_nodes_from_cache( + redis_client: Redis, + source: DocumentSource, + raw_node_ids: list[str], +) -> None: + """Remove specific hierarchy nodes from the Redis cache. + + Deletes entries from both the parent-chain hash and the raw_id→node_id hash. 
+ """ + if not raw_node_ids: + return + + cache_key = _cache_key(source) + raw_id_key = _raw_id_cache_key(source) + + # Look up node_ids so we can remove them from the parent-chain hash + raw_values = cast(list[str | None], redis_client.hmget(raw_id_key, raw_node_ids)) + node_id_strs = [v for v in raw_values if v is not None] + + if node_id_strs: + redis_client.hdel(cache_key, *node_id_strs) + redis_client.hdel(raw_id_key, *raw_node_ids) + + def get_node_id_from_raw_id( redis_client: Redis, source: DocumentSource, diff --git a/backend/onyx/server/documents/connector.py b/backend/onyx/server/documents/connector.py index 20fb2a07a41..fc321285bbb 100644 --- a/backend/onyx/server/documents/connector.py +++ b/backend/onyx/server/documents/connector.py @@ -92,6 +92,7 @@ from onyx.db.connector_credential_pair import ( get_connector_credential_pairs_for_user_parallel, ) +from onyx.db.connector_credential_pair import verify_user_has_access_to_cc_pair from onyx.db.credentials import cleanup_gmail_credentials from onyx.db.credentials import cleanup_google_drive_credentials from onyx.db.credentials import create_credential @@ -572,6 +573,43 @@ def _normalize_file_names_for_backwards_compatibility( return file_names + file_locations[len(file_names) :] +def _fetch_and_check_file_connector_cc_pair_permissions( + connector_id: int, + user: User, + db_session: Session, + require_editable: bool, +) -> ConnectorCredentialPair: + cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id) + if cc_pair is None: + raise HTTPException( + status_code=404, + detail="No Connector-Credential Pair found for this connector", + ) + + has_requested_access = verify_user_has_access_to_cc_pair( + cc_pair_id=cc_pair.id, + db_session=db_session, + user=user, + get_editable=require_editable, + ) + if has_requested_access: + return cc_pair + + # Special case: global curators should be able to manage files + # for public file connectors even when they are not the creator. 
+ if ( + require_editable + and user.role == UserRole.GLOBAL_CURATOR + and cc_pair.access_type == AccessType.PUBLIC + ): + return cc_pair + + raise HTTPException( + status_code=403, + detail="Access denied. User cannot manage files for this connector.", + ) + + @router.post("/admin/connector/file/upload", tags=PUBLIC_API_TAGS) def upload_files_api( files: list[UploadFile], @@ -583,7 +621,7 @@ def upload_files_api( @router.get("/admin/connector/{connector_id}/files", tags=PUBLIC_API_TAGS) def list_connector_files( connector_id: int, - user: User = Depends(current_curator_or_admin_user), # noqa: ARG001 + user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> ConnectorFilesResponse: """List all files in a file connector.""" @@ -596,6 +634,13 @@ def list_connector_files( status_code=400, detail="This endpoint only works with file connectors" ) + _ = _fetch_and_check_file_connector_cc_pair_permissions( + connector_id=connector_id, + user=user, + db_session=db_session, + require_editable=False, + ) + file_locations = connector.connector_specific_config.get("file_locations", []) file_names = connector.connector_specific_config.get("file_names", []) @@ -645,7 +690,7 @@ def update_connector_files( connector_id: int, files: list[UploadFile] | None = File(None), file_ids_to_remove: str = Form("[]"), - user: User = Depends(current_curator_or_admin_user), # noqa: ARG001 + user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> FileUploadResponse: """ @@ -663,12 +708,13 @@ def update_connector_files( ) # Get the connector-credential pair for indexing/pruning triggers - cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id) - if cc_pair is None: - raise HTTPException( - status_code=404, - detail="No Connector-Credential Pair found for this connector", - ) + # and validate user permissions for file management. 
+ cc_pair = _fetch_and_check_file_connector_cc_pair_permissions( + connector_id=connector_id, + user=user, + db_session=db_session, + require_editable=True, + ) # Parse file IDs to remove try: @@ -1859,7 +1905,7 @@ def get_connector_by_id( @router.post("/connector-request") def submit_connector_request( request_data: ConnectorRequestSubmission, - user: User | None = Depends(current_user), + user: User = Depends(current_user), ) -> StatusResponse: """ Submit a connector request for Cloud deployments. @@ -1872,7 +1918,7 @@ def submit_connector_request( raise HTTPException(status_code=400, detail="Connector name cannot be empty") # Get user identifier for telemetry - user_email = user.email if user else None + user_email = user.email distinct_id = user_email or tenant_id # Track connector request via PostHog telemetry (Cloud only) diff --git a/backend/onyx/server/features/build/api/api.py b/backend/onyx/server/features/build/api/api.py index ad8b8cec919..038a4b82252 100644 --- a/backend/onyx/server/features/build/api/api.py +++ b/backend/onyx/server/features/build/api/api.py @@ -1,3 +1,4 @@ +import re from collections.abc import Iterator from pathlib import Path from uuid import UUID @@ -40,6 +41,9 @@ logger = setup_logger() +_TEMPLATES_DIR = Path(__file__).parent / "templates" +_WEBAPP_HMR_FIXER_TEMPLATE = (_TEMPLATES_DIR / "webapp_hmr_fixer.js").read_text() def require_onyx_craft_enabled(user: User = Depends(current_user)) -> User: """ @@ -239,18 +243,62 @@ def _stream_response(response: httpx.Response) -> Iterator[bytes]: yield chunk +def _inject_hmr_fixer(content: bytes, session_id: str) -> bytes: + """Inject a script that stubs root-scoped Next HMR websocket connections.""" + base = f"/api/build/sessions/{session_id}/webapp" + script = f"<script>{_WEBAPP_HMR_FIXER_TEMPLATE.replace('__WEBAPP_BASE__', base)}</script>" + text = content.decode("utf-8") + text = re.sub( + r"(<head[^>]*>)", + lambda m: m.group(0) + script, + text, + count=1, + flags=re.IGNORECASE, + ) + return text.encode("utf-8") + + def _rewrite_asset_paths(content: bytes, 
session_id: str) -> bytes: """Rewrite Next.js asset paths to go through the proxy.""" - import re - - # Base path includes session_id for routing webapp_base_path = f"/api/build/sessions/{session_id}/webapp" + escaped_webapp_base_path = webapp_base_path.replace("/", r"\/") + hmr_paths = ("/_next/webpack-hmr", "/_next/hmr") text = content.decode("utf-8") - # Rewrite /_next/ paths to go through our proxy - text = text.replace("/_next/", f"{webapp_base_path}/_next/") - # Rewrite JSON data file fetch paths (e.g., /data.json, /data/tickets.json) - # Matches paths like "/filename.json" or "/path/to/file.json" + # Anchor on delimiter so already-prefixed URLs (from assetPrefix) aren't double-rewritten. + for delim in ('"', "'", "("): + text = text.replace(f"{delim}/_next/", f"{delim}{webapp_base_path}/_next/") + text = re.sub( + rf"{re.escape(delim)}https?://[^/\"')]+/_next/", + f"{delim}{webapp_base_path}/_next/", + text, + ) + text = re.sub( + rf"{re.escape(delim)}wss?://[^/\"')]+/_next/", + f"{delim}{webapp_base_path}/_next/", + text, + ) + text = text.replace(r"\/_next\/", rf"{escaped_webapp_base_path}\/_next\/") + text = re.sub( + r"https?:\\\/\\\/[^\"']+?\\\/_next\\\/", + rf"{escaped_webapp_base_path}\/_next\/", + text, + ) + text = re.sub( + r"wss?:\\\/\\\/[^\"']+?\\\/_next\\\/", + rf"{escaped_webapp_base_path}\/_next\/", + text, + ) + for hmr_path in hmr_paths: + escaped_hmr_path = hmr_path.replace("/", r"\/") + text = text.replace( + f"{webapp_base_path}{hmr_path}", + hmr_path, + ) + text = text.replace( + f"{escaped_webapp_base_path}{escaped_hmr_path}", + escaped_hmr_path, + ) text = re.sub( r'"(/(?:[a-zA-Z0-9_-]+/)*[a-zA-Z0-9_-]+\.json)"', f'"{webapp_base_path}\\1"', @@ -261,11 +309,29 @@ def _rewrite_asset_paths(content: bytes, session_id: str) -> bytes: f"'{webapp_base_path}\\1'", text, ) - # Rewrite favicon text = text.replace('"/favicon.ico', f'"{webapp_base_path}/favicon.ico') return text.encode("utf-8") +def _rewrite_proxy_response_headers( + headers: 
dict[str, str], session_id: str +) -> dict[str, str]: + """Rewrite response headers that can leak root-scoped asset URLs.""" + link = headers.get("link") + if link: + webapp_base_path = f"/api/build/sessions/{session_id}/webapp" + rewritten_link = re.sub( + r"<https?://[^>]+/_next/", + f"<{webapp_base_path}/_next/", + link, + ) + rewritten_link = rewritten_link.replace( + " Response: @@ -399,6 +470,7 @@ def _offline_html_response() -> Response: Design mirrors the default Craft web template (outputs/web/app/page.tsx): terminal window aesthetic with Minecraft-themed typing animation. + """ html = _OFFLINE_HTML_PATH.read_text() return Response(content=html, status_code=503, media_type="text/html") diff --git a/backend/onyx/server/features/build/api/messages_api.py b/backend/onyx/server/features/build/api/messages_api.py index 50164d109ca..ac4b194bd17 100644 --- a/backend/onyx/server/features/build/api/messages_api.py +++ b/backend/onyx/server/features/build/api/messages_api.py @@ -57,9 +57,6 @@ def list_messages( db_session: Session = Depends(get_session), ) -> MessageListResponse: """Get all messages for a build session.""" - if user is None: - raise HTTPException(status_code=401, detail="Authentication required") - session_manager = SessionManager(db_session) messages = session_manager.list_messages(session_id, user.id) diff --git a/backend/onyx/server/features/build/api/sessions_api.py b/backend/onyx/server/features/build/api/sessions_api.py index ecdbf1682e3..2ec10da95f4 100644 --- a/backend/onyx/server/features/build/api/sessions_api.py +++ b/backend/onyx/server/features/build/api/sessions_api.py @@ -732,7 +732,7 @@ def get_webapp_info( return WebappInfo(**webapp_info) -@router.get("/{session_id}/webapp/download") +@router.get("/{session_id}/webapp-download") def download_webapp( session_id: UUID, user: User = Depends(current_user), diff --git a/backend/onyx/server/features/build/api/templates/webapp_hmr_fixer.js 
b/backend/onyx/server/features/build/api/templates/webapp_hmr_fixer.js new file mode 100644 index 00000000000..a45f67aedbc --- /dev/null +++ b/backend/onyx/server/features/build/api/templates/webapp_hmr_fixer.js @@ -0,0 +1,135 @@ +(function () { + var WEBAPP_BASE = "__WEBAPP_BASE__"; + var PROXIED_NEXT_PREFIX = WEBAPP_BASE + "/_next/"; + var PROXIED_HMR_PREFIX = WEBAPP_BASE + "/_next/webpack-hmr"; + var PROXIED_ALT_HMR_PREFIX = WEBAPP_BASE + "/_next/hmr"; + + function isHmrWebSocketUrl(url) { + if (!url) return false; + try { + var parsedUrl = new URL(String(url), window.location.href); + return ( + parsedUrl.pathname.indexOf("/_next/webpack-hmr") === 0 || + parsedUrl.pathname.indexOf("/_next/hmr") === 0 || + parsedUrl.pathname.indexOf(PROXIED_HMR_PREFIX) === 0 || + parsedUrl.pathname.indexOf(PROXIED_ALT_HMR_PREFIX) === 0 + ); + } catch (e) {} + if (typeof url === "string") { + return ( + url.indexOf("/_next/webpack-hmr") === 0 || + url.indexOf("/_next/hmr") === 0 || + url.indexOf(PROXIED_HMR_PREFIX) === 0 || + url.indexOf(PROXIED_ALT_HMR_PREFIX) === 0 + ); + } + return false; + } + + function rewriteNextAssetUrl(url) { + if (!url) return url; + try { + var parsedUrl = new URL(String(url), window.location.href); + if (parsedUrl.pathname.indexOf(PROXIED_NEXT_PREFIX) === 0) { + return parsedUrl.pathname + parsedUrl.search + parsedUrl.hash; + } + if (parsedUrl.pathname.indexOf("/_next/") === 0) { + return ( + WEBAPP_BASE + parsedUrl.pathname + parsedUrl.search + parsedUrl.hash + ); + } + } catch (e) {} + if (typeof url === "string") { + if (url.indexOf(PROXIED_NEXT_PREFIX) === 0) { + return url; + } + if (url.indexOf("/_next/") === 0) { + return WEBAPP_BASE + url; + } + } + return url; + } + + function createEvent(eventType) { + return typeof Event === "function" + ? 
new Event(eventType) + : { type: eventType }; + } + + function MockHmrWebSocket(url) { + this.url = String(url); + this.readyState = 1; + this.bufferedAmount = 0; + this.extensions = ""; + this.protocol = ""; + this.binaryType = "blob"; + this.onopen = null; + this.onmessage = null; + this.onerror = null; + this.onclose = null; + this._l = {}; + var socket = this; + setTimeout(function () { + socket._d("open", createEvent("open")); + }, 0); + } + + MockHmrWebSocket.CONNECTING = 0; + MockHmrWebSocket.OPEN = 1; + MockHmrWebSocket.CLOSING = 2; + MockHmrWebSocket.CLOSED = 3; + + MockHmrWebSocket.prototype.addEventListener = function (eventType, callback) { + (this._l[eventType] || (this._l[eventType] = [])).push(callback); + }; + + MockHmrWebSocket.prototype.removeEventListener = function ( + eventType, + callback, + ) { + var listeners = this._l[eventType] || []; + this._l[eventType] = listeners.filter(function (listener) { + return listener !== callback; + }); + }; + + MockHmrWebSocket.prototype._d = function (eventType, eventValue) { + var listeners = this._l[eventType] || []; + for (var i = 0; i < listeners.length; i++) { + listeners[i].call(this, eventValue); + } + var handler = this["on" + eventType]; + if (typeof handler === "function") { + handler.call(this, eventValue); + } + }; + + MockHmrWebSocket.prototype.send = function () {}; + + MockHmrWebSocket.prototype.close = function (code, reason) { + if (this.readyState >= 2) return; + this.readyState = 3; + var closeEvent = createEvent("close"); + closeEvent.code = code === undefined ? 1000 : code; + closeEvent.reason = reason || ""; + closeEvent.wasClean = true; + this._d("close", closeEvent); + }; + + if (window.WebSocket) { + var OriginalWebSocket = window.WebSocket; + window.WebSocket = function (url, protocols) { + if (isHmrWebSocketUrl(url)) { + return new MockHmrWebSocket(rewriteNextAssetUrl(url)); + } + return protocols === undefined + ? 
new OriginalWebSocket(url) + : new OriginalWebSocket(url, protocols); + }; + window.WebSocket.prototype = OriginalWebSocket.prototype; + Object.setPrototypeOf(window.WebSocket, OriginalWebSocket); + ["CONNECTING", "OPEN", "CLOSING", "CLOSED"].forEach(function (stateKey) { + window.WebSocket[stateKey] = OriginalWebSocket[stateKey]; + }); + } +})(); diff --git a/backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web/package-lock.json b/backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web/package-lock.json index 1cbc5a3a508..fbe6753d3d9 100644 --- a/backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web/package-lock.json +++ b/backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web/package-lock.json @@ -961,9 +961,9 @@ "license": "MIT" }, "node_modules/@hono/node-server": { - "version": "1.19.9", - "resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.9.tgz", - "integrity": "sha512-vHL6w3ecZsky+8P5MD+eFfaGTyCeOHUIFYMGpQGbrBTSmNNoxv0if69rEZ5giu36weC5saFuznL411gRX7bJDw==", + "version": "1.19.10", + "resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.10.tgz", + "integrity": "sha512-hZ7nOssGqRgyV3FVVQdfi+U4q02uB23bpnYpdvNXkYTRRyWx84b7yf1ans+dnJ/7h41sGL3CeQTfO+ZGxuO+Iw==", "license": "MIT", "engines": { "node": ">=18.14.1" @@ -1573,27 +1573,6 @@ } } }, - "node_modules/@isaacs/balanced-match": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/@isaacs/balanced-match/-/balanced-match-4.0.1.tgz", - "integrity": "sha512-yzMTt9lEb8Gv7zRioUilSglI0c0smZ9k5D65677DLWLtWJaXIS3CqcGyUFByYKlnUj6TkjLVs54fBl6+TiGQDQ==", - "license": "MIT", - "engines": { - "node": "20 || >=22" - } - }, - "node_modules/@isaacs/brace-expansion": { - "version": "5.0.1", - "resolved": "https://registry.npmjs.org/@isaacs/brace-expansion/-/brace-expansion-5.0.1.tgz", - "integrity": 
"sha512-WMz71T1JS624nWj2n2fnYAuPovhv7EUhk69R6i9dsVyzxt5eM3bjwvgk9L+APE1TRscGysAVMANkB0jh0LQZrQ==", - "license": "MIT", - "dependencies": { - "@isaacs/balanced-match": "^4.0.1" - }, - "engines": { - "node": "20 || >=22" - } - }, "node_modules/@jridgewell/gen-mapping": { "version": "0.3.13", "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz", @@ -1680,9 +1659,9 @@ } }, "node_modules/@modelcontextprotocol/sdk/node_modules/ajv": { - "version": "8.17.1", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", - "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", + "version": "8.18.0", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz", + "integrity": "sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==", "license": "MIT", "dependencies": { "fast-deep-equal": "^3.1.3", @@ -3855,6 +3834,27 @@ "path-browserify": "^1.0.1" } }, + "node_modules/@ts-morph/common/node_modules/balanced-match": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-4.0.4.tgz", + "integrity": "sha512-BLrgEcRTwX2o6gGxGOCNyMvGSp35YofuYzw9h1IMTRmKqttAZZVU67bdb9Pr2vUHA8+j3i2tJfjO6C6+4myGTA==", + "license": "MIT", + "engines": { + "node": "18 || 20 || >=22" + } + }, + "node_modules/@ts-morph/common/node_modules/brace-expansion": { + "version": "5.0.3", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.3.tgz", + "integrity": "sha512-fy6KJm2RawA5RcHkLa1z/ScpBeA762UF9KmZQxwIbDtRJrgLzM10depAiEQ+CXYcoiqW1/m96OAAoke2nE9EeA==", + "license": "MIT", + "dependencies": { + "balanced-match": "^4.0.2" + }, + "engines": { + "node": "18 || 20 || >=22" + } + }, "node_modules/@ts-morph/common/node_modules/fast-glob": { "version": "3.3.3", "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz", @@ -3884,15 +3884,15 @@ } }, 
"node_modules/@ts-morph/common/node_modules/minimatch": { - "version": "10.1.1", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-10.1.1.tgz", - "integrity": "sha512-enIvLvRAFZYXJzkCYG5RKmPfrFArdLv+R+lbQ53BmIMLIry74bjKzX6iHAm8WYamJkhSSEabrWN5D97XnKObjQ==", + "version": "10.2.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-10.2.4.tgz", + "integrity": "sha512-oRjTw/97aTBN0RHbYCdtF1MQfvusSIBQM0IZEgzl6426+8jSC0nF1a/GmnVLpfB9yyr6g6FTqWqiZVbxrtaCIg==", "license": "BlueOak-1.0.0", "dependencies": { - "@isaacs/brace-expansion": "^5.0.0" + "brace-expansion": "^5.0.2" }, "engines": { - "node": "20 || >=22" + "node": "18 || 20 || >=22" }, "funding": { "url": "https://github.com/sponsors/isaacs" @@ -4234,13 +4234,13 @@ } }, "node_modules/@typescript-eslint/typescript-estree/node_modules/minimatch": { - "version": "9.0.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", - "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "version": "9.0.9", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.9.tgz", + "integrity": "sha512-OBwBN9AL4dqmETlpS2zasx+vTeWclWzkblfZk7KTA5j3jeOONz/tRCnZomUyvNg83wL5Zv9Ss6HMJXAgL8R2Yg==", "dev": true, "license": "ISC", "dependencies": { - "brace-expansion": "^2.0.1" + "brace-expansion": "^2.0.2" }, "engines": { "node": ">=16 || 14 >=14.17" @@ -4619,9 +4619,9 @@ } }, "node_modules/ajv": { - "version": "6.12.6", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", - "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "version": "6.14.0", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.14.0.tgz", + "integrity": "sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw==", "dev": true, "license": "MIT", "dependencies": { @@ -4653,9 +4653,9 @@ } }, "node_modules/ajv-formats/node_modules/ajv": { - 
"version": "8.17.1", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", - "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", + "version": "8.18.0", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz", + "integrity": "sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==", "license": "MIT", "dependencies": { "fast-deep-equal": "^3.1.3", @@ -6758,12 +6758,12 @@ } }, "node_modules/express-rate-limit": { - "version": "8.2.1", - "resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.2.1.tgz", - "integrity": "sha512-PCZEIEIxqwhzw4KF0n7QF4QqruVTcF73O5kFKUnGOyjbCCgizBBiFaYpd/fnBLUMPw/BWw9OsiN7GgrNYr7j6g==", + "version": "8.3.0", + "resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.3.0.tgz", + "integrity": "sha512-KJzBawY6fB9FiZGdE/0aftepZ91YlaGIrV8vgblRM3J8X+dHx/aiowJWwkx6LIGyuqGiANsjSwwrbb8mifOJ4Q==", "license": "MIT", "dependencies": { - "ip-address": "10.0.1" + "ip-address": "10.1.0" }, "engines": { "node": ">= 16" @@ -7424,9 +7424,9 @@ } }, "node_modules/hono": { - "version": "4.11.7", - "resolved": "https://registry.npmjs.org/hono/-/hono-4.11.7.tgz", - "integrity": "sha512-l7qMiNee7t82bH3SeyUCt9UF15EVmaBvsppY2zQtrbIhl/yzBTny+YUxsVjSjQ6gaqaeVtZmGocom8TzBlA4Yw==", + "version": "4.12.7", + "resolved": "https://registry.npmjs.org/hono/-/hono-4.12.7.tgz", + "integrity": "sha512-jq9l1DM0zVIvsm3lv9Nw9nlJnMNPOcAtsbsgiUhWcFzPE99Gvo6yRTlszSLLYacMeQ6quHD6hMfId8crVHvexw==", "license": "MIT", "engines": { "node": ">=16.9.0" @@ -7556,9 +7556,9 @@ } }, "node_modules/ip-address": { - "version": "10.0.1", - "resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.0.1.tgz", - "integrity": "sha512-NWv9YLW4PoW2B7xtzaS3NCot75m6nK7Icdv0o3lfMceJVRfSoQwqD4wEH5rLwoKJwUiZ/rfpiVBhnaF0FK4HoA==", + "version": "10.1.0", + "resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.1.0.tgz", + 
"integrity": "sha512-XXADHxXmvT9+CRxhXg56LJovE+bmWnEWB78LB83VZTprKTmaC5QfruXocxzTZ2Kl0DNwKuBdlIhjL8LeY8Sf8Q==", "license": "MIT", "engines": { "node": ">= 12" @@ -8831,9 +8831,9 @@ } }, "node_modules/minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz", + "integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==", "dev": true, "license": "ISC", "dependencies": { @@ -9699,9 +9699,9 @@ } }, "node_modules/qs": { - "version": "6.14.1", - "resolved": "https://registry.npmjs.org/qs/-/qs-6.14.1.tgz", - "integrity": "sha512-4EK3+xJl8Ts67nLYNwqw/dsFVnCf+qR7RgXSK9jEEm9unao3njwMDdmsdvoKBKHzxd7tCYz5e5M+SnMjdtXGQQ==", + "version": "6.14.2", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.14.2.tgz", + "integrity": "sha512-V/yCWTTF7VJ9hIh18Ugr2zhJMP01MY7c5kh4J870L7imm6/DIzBsNLTXzMwUA3yZ5b/KBqLx8Kp3uRvd7xSe3Q==", "license": "BSD-3-Clause", "dependencies": { "side-channel": "^1.1.0" diff --git a/backend/onyx/server/features/build/sandbox/kubernetes/kubernetes_sandbox_manager.py b/backend/onyx/server/features/build/sandbox/kubernetes/kubernetes_sandbox_manager.py index c6a20290ff5..8fe3e9ef3bd 100644 --- a/backend/onyx/server/features/build/sandbox/kubernetes/kubernetes_sandbox_manager.py +++ b/backend/onyx/server/features/build/sandbox/kubernetes/kubernetes_sandbox_manager.py @@ -1133,7 +1133,8 @@ def _cleanup_kubernetes_resources( # Already deleted service_deleted = True else: - logger.warning(f"Error deleting Service {service_name}: {e}") + logger.error(f"Error deleting Service {service_name}: {e}") + raise pod_deleted = False try: @@ -1148,7 +1149,8 @@ def _cleanup_kubernetes_resources( # Already deleted pod_deleted = True else: - logger.warning(f"Error 
deleting Pod {pod_name}: {e}") + logger.error(f"Error deleting Pod {pod_name}: {e}") + raise # Wait for resources to be fully deleted to prevent 409 conflicts # on immediate re-provisioning diff --git a/backend/onyx/server/features/build/sandbox/tasks/tasks.py b/backend/onyx/server/features/build/sandbox/tasks/tasks.py index 1b335146486..b62fccacd9d 100644 --- a/backend/onyx/server/features/build/sandbox/tasks/tasks.py +++ b/backend/onyx/server/features/build/sandbox/tasks/tasks.py @@ -80,7 +80,7 @@ def cleanup_idle_sandboxes_task(self: Task, *, tenant_id: str) -> None: # noqa: # Prevent overlapping runs of this task if not lock.acquire(blocking=False): - task_logger.debug("cleanup_idle_sandboxes_task - lock not acquired, skipping") + task_logger.info("cleanup_idle_sandboxes_task - lock not acquired, skipping") return try: diff --git a/backend/onyx/server/features/document_set/api.py b/backend/onyx/server/features/document_set/api.py index 262e1ba649b..8e71cf689d9 100644 --- a/backend/onyx/server/features/document_set/api.py +++ b/backend/onyx/server/features/document_set/api.py @@ -11,6 +11,7 @@ from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryTask from onyx.db.document_set import check_document_sets_are_public +from onyx.db.document_set import delete_document_set as db_delete_document_set from onyx.db.document_set import fetch_all_document_sets_for_user from onyx.db.document_set import get_document_set_by_id from onyx.db.document_set import insert_document_set @@ -142,7 +143,10 @@ def delete_document_set( except Exception as e: raise HTTPException(status_code=400, detail=str(e)) - if not DISABLE_VECTOR_DB: + if DISABLE_VECTOR_DB: + db_session.refresh(document_set) + db_delete_document_set(document_set, db_session) + else: client_app.send_task( OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK, kwargs={"tenant_id": tenant_id}, diff --git a/backend/onyx/server/features/hierarchy/api.py 
b/backend/onyx/server/features/hierarchy/api.py index 4163810e732..12382b0070b 100644 --- a/backend/onyx/server/features/hierarchy/api.py +++ b/backend/onyx/server/features/hierarchy/api.py @@ -54,18 +54,14 @@ def _require_opensearch(db_session: Session) -> None: ) -def _get_user_access_info( - user: User | None, db_session: Session -) -> tuple[str | None, list[str]]: - if not user: - return None, [] +def _get_user_access_info(user: User, db_session: Session) -> tuple[str, list[str]]: return user.email, get_user_external_group_ids(db_session, user) @router.get(HIERARCHY_NODES_LIST_PATH) def list_accessible_hierarchy_nodes( source: DocumentSource, - user: User | None = Depends(current_user), + user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> HierarchyNodesResponse: _require_opensearch(db_session) @@ -92,7 +88,7 @@ def list_accessible_hierarchy_nodes( @router.post(HIERARCHY_NODE_DOCUMENTS_PATH) def list_accessible_hierarchy_node_documents( documents_request: HierarchyNodeDocumentsRequest, - user: User | None = Depends(current_user), + user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> HierarchyNodeDocumentsResponse: _require_opensearch(db_session) diff --git a/backend/onyx/server/features/mcp/api.py b/backend/onyx/server/features/mcp/api.py index cc92cfb1dde..7c6db29d05f 100644 --- a/backend/onyx/server/features/mcp/api.py +++ b/backend/onyx/server/features/mcp/api.py @@ -1013,7 +1013,7 @@ def get_mcp_servers_for_assistant( @router.get("/servers", response_model=MCPServersResponse) def get_mcp_servers_for_user( db: Session = Depends(get_session), - user: User | None = Depends(current_user), + user: User = Depends(current_user), ) -> MCPServersResponse: """List all MCP servers for use in agent configuration and chat UI. 
diff --git a/backend/onyx/server/features/projects/api.py b/backend/onyx/server/features/projects/api.py index 5ad7a7a6089..bcd55003433 100644 --- a/backend/onyx/server/features/projects/api.py +++ b/backend/onyx/server/features/projects/api.py @@ -2,6 +2,7 @@ from uuid import UUID from fastapi import APIRouter +from fastapi import BackgroundTasks from fastapi import Depends from fastapi import File from fastapi import Form @@ -12,13 +13,7 @@ from sqlalchemy.orm import Session from onyx.auth.users import current_user -from onyx.background.celery.tasks.user_file_processing.tasks import ( - enqueue_user_file_project_sync_task, -) -from onyx.background.celery.tasks.user_file_processing.tasks import ( - get_user_file_project_sync_queue_depth, -) -from onyx.background.celery.versioned_apps.client import app as client_app +from onyx.configs.app_configs import DISABLE_VECTOR_DB from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask @@ -34,7 +29,6 @@ from onyx.db.persona import get_personas_by_ids from onyx.db.projects import get_project_token_count from onyx.db.projects import upload_files_to_user_files_with_indexing -from onyx.redis.redis_pool import get_redis_client from onyx.server.features.projects.models import CategorizedFilesSnapshot from onyx.server.features.projects.models import ChatSessionRequest from onyx.server.features.projects.models import TokenCountResponse @@ -55,7 +49,27 @@ class UserFileDeleteResult(BaseModel): assistant_names: list[str] = [] -def _trigger_user_file_project_sync(user_file_id: UUID, tenant_id: str) -> None: +def _trigger_user_file_project_sync( + user_file_id: UUID, + tenant_id: str, + background_tasks: BackgroundTasks | None = None, +) -> None: + if DISABLE_VECTOR_DB and background_tasks is not None: + from onyx.background.task_utils import drain_project_sync_loop + + background_tasks.add_task(drain_project_sync_loop, tenant_id) + 
logger.info(f"Queued in-process project sync for user_file_id={user_file_id}") + return + + from onyx.background.celery.tasks.user_file_processing.tasks import ( + enqueue_user_file_project_sync_task, + ) + from onyx.background.celery.tasks.user_file_processing.tasks import ( + get_user_file_project_sync_queue_depth, + ) + from onyx.background.celery.versioned_apps.client import app as client_app + from onyx.redis.redis_pool import get_redis_client + queue_depth = get_user_file_project_sync_queue_depth(client_app) if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH: logger.warning( @@ -111,6 +125,7 @@ def create_project( @router.post("/file/upload", tags=PUBLIC_API_TAGS) def upload_user_files( + bg_tasks: BackgroundTasks, files: list[UploadFile] = File(...), project_id: int | None = Form(None), temp_id_map: str | None = Form(None), # JSON string mapping hashed key -> temp_id @@ -137,12 +152,12 @@ def upload_user_files( user=user, temp_id_map=parsed_temp_id_map, db_session=db_session, + background_tasks=bg_tasks if DISABLE_VECTOR_DB else None, ) return CategorizedFilesSnapshot.from_result(categorized_files_result) except Exception as e: - # Log error with type, message, and stack for easier debugging logger.exception(f"Error uploading files - {type(e).__name__}: {str(e)}") raise HTTPException( status_code=500, @@ -192,6 +207,7 @@ def get_files_in_project( def unlink_user_file_from_project( project_id: int, file_id: UUID, + bg_tasks: BackgroundTasks, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> Response: @@ -208,7 +224,6 @@ def unlink_user_file_from_project( if project is None: raise HTTPException(status_code=404, detail="Project not found") - user_id = user.id user_file = ( db_session.query(UserFile) .filter(UserFile.id == file_id, UserFile.user_id == user_id) @@ -224,7 +239,7 @@ def unlink_user_file_from_project( db_session.commit() tenant_id = get_current_tenant_id() - _trigger_user_file_project_sync(user_file.id, 
tenant_id) + _trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks) return Response(status_code=204) @@ -237,6 +252,7 @@ def unlink_user_file_from_project( def link_user_file_to_project( project_id: int, file_id: UUID, + bg_tasks: BackgroundTasks, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> UserFileSnapshot: @@ -268,7 +284,7 @@ def link_user_file_to_project( db_session.commit() tenant_id = get_current_tenant_id() - _trigger_user_file_project_sync(user_file.id, tenant_id) + _trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks) return UserFileSnapshot.from_model(user_file) @@ -424,6 +440,7 @@ def delete_project( @router.delete("/file/{file_id}", tags=PUBLIC_API_TAGS) def delete_user_file( file_id: UUID, + bg_tasks: BackgroundTasks, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> UserFileDeleteResult: @@ -456,15 +473,25 @@ def delete_user_file( db_session.commit() tenant_id = get_current_tenant_id() - task = client_app.send_task( - OnyxCeleryTask.DELETE_SINGLE_USER_FILE, - kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id}, - queue=OnyxCeleryQueues.USER_FILE_DELETE, - priority=OnyxCeleryPriority.HIGH, - ) - logger.info( - f"Triggered delete for user_file_id={user_file.id} with task_id={task.id}" - ) + if DISABLE_VECTOR_DB: + from onyx.background.task_utils import drain_delete_loop + + bg_tasks.add_task(drain_delete_loop, tenant_id) + logger.info(f"Queued in-process delete for user_file_id={user_file.id}") + else: + from onyx.background.celery.versioned_apps.client import app as client_app + + task = client_app.send_task( + OnyxCeleryTask.DELETE_SINGLE_USER_FILE, + kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id}, + queue=OnyxCeleryQueues.USER_FILE_DELETE, + priority=OnyxCeleryPriority.HIGH, + ) + logger.info( + f"Triggered delete for user_file_id={user_file.id} " + f"with task_id={task.id}" + ) + return UserFileDeleteResult( 
has_associations=False, project_names=[], assistant_names=[] ) diff --git a/backend/onyx/server/features/projects/projects_file_utils.py b/backend/onyx/server/features/projects/projects_file_utils.py index 9237348461d..1eaaa429f79 100644 --- a/backend/onyx/server/features/projects/projects_file_utils.py +++ b/backend/onyx/server/features/projects/projects_file_utils.py @@ -7,13 +7,16 @@ from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field +from sqlalchemy.orm import Session from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD +from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES +from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB +from onyx.db.llm import fetch_default_llm_model from onyx.file_processing.extract_file_text import extract_file_text from onyx.file_processing.extract_file_text import get_file_ext from onyx.file_processing.file_types import OnyxFileExtensions from onyx.file_processing.password_validation import is_file_password_protected -from onyx.llm.factory import get_default_llm from onyx.natural_language_processing.utils import get_tokenizer from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT @@ -34,6 +37,38 @@ def get_safe_filename(upload: UploadFile) -> str: return upload.filename +def get_upload_size_bytes(upload: UploadFile) -> int | None: + """Best-effort file size in bytes without consuming the stream.""" + if upload.size is not None: + return upload.size + + try: + current_pos = upload.file.tell() + upload.file.seek(0, 2) + size = upload.file.tell() + upload.file.seek(current_pos) + return size + except Exception as e: + logger.warning( + "Could not determine upload size via stream seek " + f"(filename='{get_safe_filename(upload)}', " + f"error_type={type(e).__name__}, error={e})" + ) + return None + + +def is_upload_too_large(upload: UploadFile, max_bytes: int) -> bool: + """Return True when upload size is known and exceeds 
max_bytes.""" + size_bytes = get_upload_size_bytes(upload) + if size_bytes is None: + logger.warning( + "Could not determine upload size; skipping size-limit check for " + f"'{get_safe_filename(upload)}'" + ) + return False + return size_bytes > max_bytes + + # Guard against extremely large images Image.MAX_IMAGE_PIXELS = 12000 * 12000 @@ -116,23 +151,28 @@ def estimate_image_tokens_for_upload( pass -def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles: +def categorize_uploaded_files( + files: list[UploadFile], db_session: Session +) -> CategorizedFiles: """ Categorize uploaded files based on text extractability and tokenized length. - - Extracts text using extract_file_text for supported plain/document extensions. + - Images are estimated for token cost via a patch-based heuristic. + - All other files are run through extract_file_text, which handles known + document formats (.pdf, .docx, …) and falls back to a text-detection + heuristic for unknown extensions (.py, .js, .rs, …). - Uses default tokenizer to compute token length. - - If token length > 100,000, reject file (unless threshold skip is enabled). - - If extension unsupported or text cannot be extracted, reject file. + - If token length > threshold, reject file (unless threshold skip is enabled). + - If text cannot be extracted, reject file. - Otherwise marked as acceptable. 
""" results = CategorizedFiles() - llm = get_default_llm() + default_model = fetch_default_llm_model(db_session) - tokenizer = get_tokenizer( - model_name=llm.config.model_name, provider_type=llm.config.model_provider - ) + model_name = default_model.name if default_model else None + provider_type = default_model.llm_provider.provider if default_model else None + tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type) # Check if threshold checks should be skipped skip_threshold = False @@ -156,6 +196,18 @@ def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles: for upload in files: try: filename = get_safe_filename(upload) + + # Size limit is a hard safety cap and is enforced even when token + # threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings. + if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES): + results.rejected.append( + RejectedFile( + filename=filename, + reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit", + ) + ) + continue + extension = get_file_ext(filename) # If image, estimate tokens via dedicated method first @@ -168,8 +220,7 @@ def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles: ) results.rejected.append( RejectedFile( - filename=filename, - reason=f"Unsupported file type: {extension}", + filename=filename, reason="Unsupported file contents" ) ) continue @@ -186,8 +237,10 @@ def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles: results.acceptable_file_to_token_count[filename] = token_count continue - # Otherwise, handle as text/document: extract text and count tokens - elif extension in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS: + # Handle as text/document: attempt text extraction and count tokens. + # This accepts any file that extract_file_text can handle, including + # code files (.py, .js, .rs, etc.) via its is_text_file() fallback. 
+ else: if is_file_password_protected( file=upload.file, file_name=filename, @@ -210,7 +263,10 @@ def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles: if not text_content: logger.warning(f"No text content extracted from '{filename}'") results.rejected.append( - RejectedFile(filename=filename, reason="Could not read file") + RejectedFile( + filename=filename, + reason=f"Unsupported file type: {extension}", + ) ) continue @@ -233,17 +289,6 @@ def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles: logger.warning( f"Failed to reset file pointer for '{filename}': {str(e)}" ) - continue - - # If not recognized as supported types above, mark unsupported - logger.warning( - f"Unsupported file extension '{extension}' for file '{filename}'" - ) - results.rejected.append( - RejectedFile( - filename=filename, reason=f"Unsupported file type: {extension}" - ) - ) except Exception as e: logger.warning( f"Failed to process uploaded file '{get_safe_filename(upload)}' (error_type={type(e).__name__}, error={str(e)})" diff --git a/backend/onyx/server/features/release_notes/utils.py b/backend/onyx/server/features/release_notes/utils.py index a6092045aff..bb99a427c84 100644 --- a/backend/onyx/server/features/release_notes/utils.py +++ b/backend/onyx/server/features/release_notes/utils.py @@ -8,10 +8,10 @@ from sqlalchemy.orm import Session from onyx import __version__ +from onyx.cache.factory import get_shared_cache_backend from onyx.configs.app_configs import INSTANCE_TYPE from onyx.configs.constants import OnyxRedisLocks from onyx.db.release_notes import create_release_notifications_for_versions -from onyx.redis.redis_pool import get_shared_redis_client from onyx.server.features.release_notes.constants import AUTO_REFRESH_THRESHOLD_SECONDS from onyx.server.features.release_notes.constants import FETCH_TIMEOUT from onyx.server.features.release_notes.constants import GITHUB_CHANGELOG_RAW_URL @@ -113,60 +113,46 @@ def
parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry def get_cached_etag() -> str | None: - """Get the cached GitHub ETag from Redis.""" - redis_client = get_shared_redis_client() + cache = get_shared_cache_backend() try: - etag = redis_client.get(REDIS_KEY_ETAG) + etag = cache.get(REDIS_KEY_ETAG) if etag: - return etag.decode("utf-8") if isinstance(etag, bytes) else str(etag) + return etag.decode("utf-8") return None except Exception as e: - logger.error(f"Failed to get cached etag from Redis: {e}") + logger.error(f"Failed to get cached etag: {e}") return None def get_last_fetch_time() -> datetime | None: - """Get the last fetch timestamp from Redis.""" - redis_client = get_shared_redis_client() + cache = get_shared_cache_backend() try: - fetched_at_str = redis_client.get(REDIS_KEY_FETCHED_AT) - if not fetched_at_str: + raw = cache.get(REDIS_KEY_FETCHED_AT) + if not raw: return None - decoded = ( - fetched_at_str.decode("utf-8") - if isinstance(fetched_at_str, bytes) - else str(fetched_at_str) - ) - - last_fetch = datetime.fromisoformat(decoded) - - # Defensively ensure timezone awareness - # fromisoformat() returns naive datetime if input lacks timezone + last_fetch = datetime.fromisoformat(raw.decode("utf-8")) if last_fetch.tzinfo is None: - # Assume UTC for naive datetimes last_fetch = last_fetch.replace(tzinfo=timezone.utc) else: - # Convert to UTC if timezone-aware last_fetch = last_fetch.astimezone(timezone.utc) return last_fetch except Exception as e: - logger.error(f"Failed to get last fetch time from Redis: {e}") + logger.error(f"Failed to get last fetch time from cache: {e}") return None def save_fetch_metadata(etag: str | None) -> None: - """Save ETag and fetch timestamp to Redis.""" - redis_client = get_shared_redis_client() + cache = get_shared_cache_backend() now = datetime.now(timezone.utc) try: - redis_client.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL) + cache.set(REDIS_KEY_FETCHED_AT, now.isoformat(), 
ex=REDIS_CACHE_TTL) if etag: - redis_client.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL) + cache.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL) except Exception as e: - logger.error(f"Failed to save fetch metadata to Redis: {e}") + logger.error(f"Failed to save fetch metadata to cache: {e}") def is_cache_stale() -> bool: @@ -196,11 +182,10 @@ def ensure_release_notes_fresh_and_notify(db_session: Session) -> None: if not is_cache_stale(): return - # Acquire lock to prevent concurrent fetches - redis_client = get_shared_redis_client() - lock = redis_client.lock( + cache = get_shared_cache_backend() + lock = cache.lock( OnyxRedisLocks.RELEASE_NOTES_FETCH_LOCK, - timeout=90, # 90 second timeout for the lock + timeout=90, ) # Non-blocking acquire - if we can't get the lock, another request is handling it diff --git a/backend/onyx/server/manage/embedding/api.py b/backend/onyx/server/manage/embedding/api.py index ef659a4fac5..55688d204d0 100644 --- a/backend/onyx/server/manage/embedding/api.py +++ b/backend/onyx/server/manage/embedding/api.py @@ -1,6 +1,5 @@ from fastapi import APIRouter from fastapi import Depends -from fastapi import HTTPException from sqlalchemy.orm import Session from onyx.auth.users import current_admin_user @@ -11,6 +10,8 @@ from onyx.db.models import User from onyx.db.search_settings import get_all_search_settings from onyx.db.search_settings import get_current_db_embedding_provider +from onyx.error_handling.error_codes import OnyxErrorCode +from onyx.error_handling.exceptions import OnyxError from onyx.indexing.models import EmbeddingModelDetail from onyx.natural_language_processing.search_nlp_models import EmbeddingModel from onyx.server.manage.embedding.models import CloudEmbeddingProvider @@ -59,7 +60,7 @@ def test_embedding_configuration( except Exception as e: error_msg = "An error occurred while testing your embedding model. Please check your configuration." 
logger.error(f"{error_msg} Error message: {e}", exc_info=True) - raise HTTPException(status_code=400, detail=error_msg) + raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg) @admin_router.get("", response_model=list[EmbeddingModelDetail]) @@ -93,8 +94,9 @@ def delete_embedding_provider( embedding_provider is not None and provider_type == embedding_provider.provider_type ): - raise HTTPException( - status_code=400, detail="You can't delete a currently active model" + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "You can't delete a currently active model", ) remove_embedding_provider(db_session, provider_type=provider_type) diff --git a/backend/onyx/server/manage/llm/api.py b/backend/onyx/server/manage/llm/api.py index 0fe5bad013f..693826c844d 100644 --- a/backend/onyx/server/manage/llm/api.py +++ b/backend/onyx/server/manage/llm/api.py @@ -11,7 +11,6 @@ from botocore.exceptions import NoCredentialsError from fastapi import APIRouter from fastapi import Depends -from fastapi import HTTPException from fastapi import Query from pydantic import ValidationError from sqlalchemy.orm import Session @@ -38,6 +37,8 @@ from onyx.db.llm import validate_persona_ids_exist from onyx.db.models import User from onyx.db.persona import user_can_access_persona +from onyx.error_handling.error_codes import OnyxErrorCode +from onyx.error_handling.exceptions import OnyxError from onyx.llm.factory import get_default_llm from onyx.llm.factory import get_llm from onyx.llm.factory import get_max_input_tokens_from_llm_provider @@ -47,6 +48,7 @@ from onyx.llm.well_known_providers.auto_update_service import ( fetch_llm_recommendations_from_github, ) +from onyx.llm.well_known_providers.constants import LM_STUDIO_API_KEY_CONFIG_KEY from onyx.llm.well_known_providers.llm_provider_options import ( fetch_available_well_known_llms, ) @@ -56,22 +58,30 @@ from onyx.server.manage.llm.models import BedrockFinalModelResponse from onyx.server.manage.llm.models import BedrockModelsRequest from 
onyx.server.manage.llm.models import DefaultModel +from onyx.server.manage.llm.models import LitellmFinalModelResponse +from onyx.server.manage.llm.models import LitellmModelDetails +from onyx.server.manage.llm.models import LitellmModelsRequest from onyx.server.manage.llm.models import LLMCost from onyx.server.manage.llm.models import LLMProviderDescriptor from onyx.server.manage.llm.models import LLMProviderResponse from onyx.server.manage.llm.models import LLMProviderUpsertRequest from onyx.server.manage.llm.models import LLMProviderView +from onyx.server.manage.llm.models import LMStudioFinalModelResponse +from onyx.server.manage.llm.models import LMStudioModelsRequest +from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest from onyx.server.manage.llm.models import OllamaFinalModelResponse from onyx.server.manage.llm.models import OllamaModelDetails from onyx.server.manage.llm.models import OllamaModelsRequest from onyx.server.manage.llm.models import OpenRouterFinalModelResponse from onyx.server.manage.llm.models import OpenRouterModelDetails from onyx.server.manage.llm.models import OpenRouterModelsRequest +from onyx.server.manage.llm.models import SyncModelEntry from onyx.server.manage.llm.models import TestLLMRequest from onyx.server.manage.llm.models import VisionProviderResponse from onyx.server.manage.llm.utils import generate_bedrock_display_name from onyx.server.manage.llm.utils import generate_ollama_display_name from onyx.server.manage.llm.utils import infer_vision_support +from onyx.server.manage.llm.utils import is_reasoning_model from onyx.server.manage.llm.utils import is_valid_bedrock_model from onyx.server.manage.llm.utils import ModelMetadata from onyx.server.manage.llm.utils import strip_openrouter_vendor_prefix @@ -92,6 +102,34 @@ def _mask_string(value: str) -> str: return value[:4] + "****" + value[-4:] +def _sync_fetched_models( + db_session: Session, + provider_name: str, + models: list[SyncModelEntry], + source_label: 
str, +) -> None: + """Sync fetched models to DB for the given provider. + + Args: + db_session: Database session + provider_name: Name of the LLM provider + models: List of SyncModelEntry objects describing the fetched models + source_label: Human-readable label for log messages (e.g. "Bedrock", "LiteLLM") + """ + try: + new_count = sync_model_configurations( + db_session=db_session, + provider_name=provider_name, + models=models, + ) + if new_count > 0: + logger.info( + f"Added {new_count} new {source_label} models to provider '{provider_name}'" + ) + except ValueError as e: + logger.warning(f"Failed to sync {source_label} models to DB: {e}") + + # Keys in custom_config that contain sensitive credentials _SENSITIVE_CONFIG_KEYS = { "vertex_credentials", @@ -186,7 +224,7 @@ def _validate_llm_provider_change( Only enforced in MULTI_TENANT mode. Raises: - HTTPException: If api_base or custom_config changed without changing API key + OnyxError: If api_base or custom_config changed without changing API key """ if not MULTI_TENANT or api_key_changed: return @@ -200,9 +238,9 @@ def _validate_llm_provider_change( ) if api_base_changed or custom_config_changed: - raise HTTPException( - status_code=400, - detail="API base and/or custom config cannot be changed without changing the API key", + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "API base and/or custom config cannot be changed without changing the API key", ) @@ -222,7 +260,7 @@ def fetch_llm_provider_options( for well_known_llm in well_known_llms: if well_known_llm.name == provider_name: return well_known_llm - raise HTTPException(status_code=404, detail=f"Provider {provider_name} not found") + raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Provider {provider_name} not found") @admin_router.post("/test") @@ -281,7 +319,7 @@ def test_llm_configuration( error_msg = test_llm(llm) if error_msg: - raise HTTPException(status_code=400, detail=error_msg) + raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg) 
@admin_router.post("/test/default") @@ -292,11 +330,11 @@ def test_default_provider( llm = get_default_llm() except ValueError: logger.exception("Failed to fetch default LLM Provider") - raise HTTPException(status_code=400, detail="No LLM Provider setup") + raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No LLM Provider setup") error = test_llm(llm) if error: - raise HTTPException(status_code=400, detail=str(error)) + raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(error)) @admin_router.get("/provider") @@ -362,35 +400,31 @@ def put_llm_provider( # Check name constraints # TODO: Once port from name to id is complete, unique name will no longer be required if existing_provider and llm_provider_upsert_request.name != existing_provider.name: - raise HTTPException( - status_code=400, - detail="Renaming providers is not currently supported", + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "Renaming providers is not currently supported", ) found_provider = fetch_existing_llm_provider( name=llm_provider_upsert_request.name, db_session=db_session ) if found_provider is not None and found_provider is not existing_provider: - raise HTTPException( - status_code=400, - detail=f"Provider with name={llm_provider_upsert_request.name} already exists", + raise OnyxError( + OnyxErrorCode.DUPLICATE_RESOURCE, + f"Provider with name={llm_provider_upsert_request.name} already exists", ) if existing_provider and is_creation: - raise HTTPException( - status_code=400, - detail=( - f"LLM Provider with name {llm_provider_upsert_request.name} and " - f"id={llm_provider_upsert_request.id} already exists" - ), + raise OnyxError( + OnyxErrorCode.DUPLICATE_RESOURCE, + f"LLM Provider with name {llm_provider_upsert_request.name} and " + f"id={llm_provider_upsert_request.id} already exists", ) elif not existing_provider and not is_creation: - raise HTTPException( - status_code=400, - detail=( - f"LLM Provider with name {llm_provider_upsert_request.name} and " - 
f"id={llm_provider_upsert_request.id} does not exist" - ), + raise OnyxError( + OnyxErrorCode.NOT_FOUND, + f"LLM Provider with name {llm_provider_upsert_request.name} and " + f"id={llm_provider_upsert_request.id} does not exist", ) # SSRF Protection: Validate api_base and custom_config match stored values @@ -415,9 +449,9 @@ def put_llm_provider( db_session, persona_ids ) if missing_personas: - raise HTTPException( - status_code=400, - detail=f"Invalid persona IDs: {', '.join(map(str, missing_personas))}", + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + f"Invalid persona IDs: {', '.join(map(str, missing_personas))}", ) # Remove duplicates while preserving order seen: set[int] = set() @@ -444,6 +478,18 @@ def put_llm_provider( not existing_provider or not existing_provider.is_auto_mode ) + # When transitioning to auto mode, preserve existing model configurations + # so the upsert doesn't try to delete them (which would trip the default + # model protection guard). sync_auto_mode_models will handle the model + # lifecycle afterward — adding new models, hiding removed ones, and + # updating the default. This is safe even if sync fails: the provider + # keeps its old models and default rather than losing them. 
+ if transitioning_to_auto_mode and existing_provider: + llm_provider_upsert_request.model_configurations = [ + ModelConfigurationUpsertRequest.from_model(mc) + for mc in existing_provider.model_configurations + ] + try: result = upsert_llm_provider( llm_provider_upsert_request=llm_provider_upsert_request, @@ -456,7 +502,6 @@ def put_llm_provider( config = fetch_llm_recommendations_from_github() if config and llm_provider_upsert_request.provider in config.providers: - # Refetch the provider to get the updated model updated_provider = fetch_existing_llm_provider_by_id( id=result.id, db_session=db_session ) @@ -473,19 +518,29 @@ def put_llm_provider( return result except ValueError as e: logger.exception("Failed to upsert LLM Provider") - raise HTTPException(status_code=400, detail=str(e)) + raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e)) @admin_router.delete("/provider/{provider_id}") def delete_llm_provider( provider_id: int, + force: bool = Query(False), _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> None: + if not force: + model = fetch_default_llm_model(db_session) + + if model and model.llm_provider_id == provider_id: + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "Cannot delete the default LLM provider", + ) + try: remove_llm_provider(db_session, provider_id) except ValueError as e: - raise HTTPException(status_code=404, detail=str(e)) + raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e)) @admin_router.post("/default") @@ -525,9 +580,9 @@ def get_auto_config( """ config = fetch_llm_recommendations_from_github() if not config: - raise HTTPException( - status_code=502, - detail="Failed to fetch configuration from GitHub", + raise OnyxError( + OnyxErrorCode.BAD_GATEWAY, + "Failed to fetch configuration from GitHub", ) return config.model_dump() @@ -603,9 +658,9 @@ def list_llm_provider_basics( for provider in all_providers: # Use centralized access control logic with persona=None since we're # listing 
providers without a specific persona context. This correctly: - # - Includes all public providers + # - Includes public providers WITHOUT persona restrictions # - Includes providers user can access via group membership - # - Excludes persona-only restricted providers (requires specific persona) + # - Excludes providers with persona restrictions (requires specific persona) # - Excludes non-public providers with no restrictions (admin-only) if can_user_access_llm_provider( provider, user_group_ids, persona=None, is_admin=is_admin @@ -638,7 +693,7 @@ def get_valid_model_names_for_persona( Returns a list of model names (e.g., ["gpt-4o", "claude-3-5-sonnet"]) that are available to the user when using this persona, respecting all RBAC restrictions. - Public providers are always included. + Public providers are included unless they have persona restrictions that exclude this persona. """ persona = fetch_persona_with_groups(db_session, persona_id) if not persona: @@ -652,7 +707,7 @@ def get_valid_model_names_for_persona( valid_models = [] for llm_provider_model in all_providers: - # Public providers always included, restricted checked via RBAC + # Check access with persona context — respects all RBAC restrictions if can_user_access_llm_provider( llm_provider_model, user_group_ids, persona, is_admin=is_admin ): @@ -673,7 +728,7 @@ def list_llm_providers_for_persona( """Get LLM providers for a specific persona. 
Returns providers that the user can access when using this persona: - - All public providers (is_public=True) - ALWAYS included + - Public providers (respecting persona restrictions if set) - Restricted providers user can access via group/persona restrictions This endpoint is used for background fetching of restricted providers @@ -684,13 +739,13 @@ def list_llm_providers_for_persona( persona = fetch_persona_with_groups(db_session, persona_id) if not persona: - raise HTTPException(status_code=404, detail="Persona not found") + raise OnyxError(OnyxErrorCode.PERSONA_NOT_FOUND, "Persona not found") # Verify user has access to this persona if not user_can_access_persona(db_session, persona_id, user, get_editable=False): - raise HTTPException( - status_code=403, - detail="You don't have access to this assistant", + raise OnyxError( + OnyxErrorCode.INSUFFICIENT_PERMISSIONS, + "You don't have access to this assistant", ) is_admin = user.role == UserRole.ADMIN @@ -702,7 +757,7 @@ def list_llm_providers_for_persona( llm_provider_list: list[LLMProviderDescriptor] = [] for llm_provider_model in all_providers: - # Use simplified access check - public providers always included + # Check access with persona context — respects persona restrictions if can_user_access_llm_provider( llm_provider_model, user_group_ids, persona, is_admin=is_admin ): @@ -844,9 +899,9 @@ def get_bedrock_available_models( try: bedrock = session.client("bedrock") except Exception as e: - raise HTTPException( - status_code=400, - detail=f"Failed to create Bedrock client: {e}. Check AWS credentials and region.", + raise OnyxError( + OnyxErrorCode.CREDENTIAL_INVALID, + f"Failed to create Bedrock client: {e}. 
Check AWS credentials and region.", ) # Build model info dict from foundation models (modelId -> metadata) @@ -940,39 +995,32 @@ def get_bedrock_available_models( # Sync new models to DB if provider_name is specified if request.provider_name: - try: - models_to_sync = [ - { - "name": r.name, - "display_name": r.display_name, - "max_input_tokens": r.max_input_tokens, - "supports_image_input": r.supports_image_input, - } - for r in results - ] - new_count = sync_model_configurations( - db_session=db_session, - provider_name=request.provider_name, - models=models_to_sync, - ) - if new_count > 0: - logger.info( - f"Added {new_count} new Bedrock models to provider '{request.provider_name}'" + _sync_fetched_models( + db_session=db_session, + provider_name=request.provider_name, + models=[ + SyncModelEntry( + name=r.name, + display_name=r.display_name, + max_input_tokens=r.max_input_tokens, + supports_image_input=r.supports_image_input, ) - except ValueError as e: - logger.warning(f"Failed to sync Bedrock models to DB: {e}") + for r in results + ], + source_label="Bedrock", + ) return results except (ClientError, NoCredentialsError, BotoCoreError) as e: - raise HTTPException( - status_code=400, - detail=f"Failed to connect to AWS Bedrock: {e}", + raise OnyxError( + OnyxErrorCode.CREDENTIAL_INVALID, + f"Failed to connect to AWS Bedrock: {e}", ) except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Unexpected error fetching Bedrock models: {e}", + raise OnyxError( + OnyxErrorCode.INTERNAL_ERROR, + f"Unexpected error fetching Bedrock models: {e}", ) @@ -984,9 +1032,9 @@ def _get_ollama_available_model_names(api_base: str) -> set[str]: response.raise_for_status() response_json = response.json() except Exception as e: - raise HTTPException( - status_code=400, - detail=f"Failed to fetch Ollama models: {e}", + raise OnyxError( + OnyxErrorCode.BAD_GATEWAY, + f"Failed to fetch Ollama models: {e}", ) models = response_json.get("models", []) @@ -1003,9 +1051,9 
@@ def get_ollama_available_models( cleaned_api_base = request.api_base.strip().rstrip("/") if not cleaned_api_base: - raise HTTPException( - status_code=400, - detail="API base URL is required to fetch Ollama models.", + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "API base URL is required to fetch Ollama models.", ) # NOTE: most people run Ollama locally, so we don't disallow internal URLs @@ -1014,9 +1062,9 @@ def get_ollama_available_models( # with the same response format model_names = _get_ollama_available_model_names(cleaned_api_base) if not model_names: - raise HTTPException( - status_code=400, - detail="No models found from your Ollama server", + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "No models found from your Ollama server", ) all_models_with_context_size_and_vision: list[OllamaFinalModelResponse] = [] @@ -1078,27 +1126,20 @@ def get_ollama_available_models( # Sync new models to DB if provider_name is specified if request.provider_name: - try: - models_to_sync = [ - { - "name": r.name, - "display_name": r.display_name, - "max_input_tokens": r.max_input_tokens, - "supports_image_input": r.supports_image_input, - } - for r in sorted_results - ] - new_count = sync_model_configurations( - db_session=db_session, - provider_name=request.provider_name, - models=models_to_sync, - ) - if new_count > 0: - logger.info( - f"Added {new_count} new Ollama models to provider '{request.provider_name}'" + _sync_fetched_models( + db_session=db_session, + provider_name=request.provider_name, + models=[ + SyncModelEntry( + name=r.name, + display_name=r.display_name, + max_input_tokens=r.max_input_tokens, + supports_image_input=r.supports_image_input, ) - except ValueError as e: - logger.warning(f"Failed to sync Ollama models to DB: {e}") + for r in sorted_results + ], + source_label="Ollama", + ) return sorted_results @@ -1118,9 +1159,9 @@ def _get_openrouter_models_response(api_base: str, api_key: str) -> dict: response.raise_for_status() return 
response.json() except Exception as e: - raise HTTPException( - status_code=400, - detail=f"Failed to fetch OpenRouter models: {e}", + raise OnyxError( + OnyxErrorCode.BAD_GATEWAY, + f"Failed to fetch OpenRouter models: {e}", ) @@ -1141,9 +1182,9 @@ def get_openrouter_available_models( data = response_json.get("data", []) if not isinstance(data, list) or len(data) == 0: - raise HTTPException( - status_code=400, - detail="No models found from your OpenRouter endpoint", + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "No models found from your OpenRouter endpoint", ) results: list[OpenRouterFinalModelResponse] = [] @@ -1178,34 +1219,235 @@ def get_openrouter_available_models( ) if not results: - raise HTTPException( - status_code=400, detail="No compatible models found from OpenRouter" + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "No compatible models found from OpenRouter", ) sorted_results = sorted(results, key=lambda m: m.name.lower()) # Sync new models to DB if provider_name is specified if request.provider_name: - try: - models_to_sync = [ - { - "name": r.name, - "display_name": r.display_name, - "max_input_tokens": r.max_input_tokens, - "supports_image_input": r.supports_image_input, - } + _sync_fetched_models( + db_session=db_session, + provider_name=request.provider_name, + models=[ + SyncModelEntry( + name=r.name, + display_name=r.display_name, + max_input_tokens=r.max_input_tokens, + supports_image_input=r.supports_image_input, + ) for r in sorted_results - ] - new_count = sync_model_configurations( - db_session=db_session, - provider_name=request.provider_name, - models=models_to_sync, + ], + source_label="OpenRouter", + ) + + return sorted_results + + +@admin_router.post("/lm-studio/available-models") +def get_lm_studio_available_models( + request: LMStudioModelsRequest, + _: User = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[LMStudioFinalModelResponse]: + """Fetch available models from an LM 
Studio server. + + Uses the LM Studio-native /api/v1/models endpoint which exposes + rich metadata including capabilities (vision, reasoning), + display names, and context lengths. + """ + cleaned_api_base = request.api_base.strip().rstrip("/") + # Strip /v1 suffix that users may copy from OpenAI-compatible tool configs; + # the native metadata endpoint lives at /api/v1/models, not /v1/api/v1/models. + cleaned_api_base = cleaned_api_base.removesuffix("/v1") + if not cleaned_api_base: + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "API base URL is required to fetch LM Studio models.", + ) + + # If provider_name is given and the api_key hasn't been changed by the user, + # fall back to the stored API key from the database (the form value is masked). + api_key = request.api_key + if request.provider_name and not request.api_key_changed: + existing_provider = fetch_existing_llm_provider( + name=request.provider_name, db_session=db_session + ) + if existing_provider and existing_provider.custom_config: + api_key = existing_provider.custom_config.get(LM_STUDIO_API_KEY_CONFIG_KEY) + + url = f"{cleaned_api_base}/api/v1/models" + headers: dict[str, str] = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + try: + response = httpx.get(url, headers=headers, timeout=10.0) + response.raise_for_status() + response_json = response.json() + except Exception as e: + raise OnyxError( + OnyxErrorCode.BAD_GATEWAY, + f"Failed to fetch LM Studio models: {e}", + ) + + models = response_json.get("models", []) + if not isinstance(models, list) or len(models) == 0: + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "No models found from your LM Studio server.", + ) + + results: list[LMStudioFinalModelResponse] = [] + for item in models: + # Filter to LLM-type models only (skip embeddings, etc.) 
+ if item.get("type") != "llm": + continue + + model_key = item.get("key") + if not model_key: + continue + + display_name = item.get("display_name") or model_key + max_context_length = item.get("max_context_length") + capabilities = item.get("capabilities") or {} + + results.append( + LMStudioFinalModelResponse( + name=model_key, + display_name=display_name, + max_input_tokens=max_context_length, + supports_image_input=capabilities.get("vision", False), + supports_reasoning=capabilities.get("reasoning", False) + or is_reasoning_model(model_key, display_name), ) - if new_count > 0: - logger.info( - f"Added {new_count} new OpenRouter models to provider '{request.provider_name}'" + ) + + if not results: + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "No compatible models found from LM Studio server.", + ) + + sorted_results = sorted(results, key=lambda m: m.name.lower()) + + # Sync new models to DB if provider_name is specified + if request.provider_name: + _sync_fetched_models( + db_session=db_session, + provider_name=request.provider_name, + models=[ + SyncModelEntry( + name=r.name, + display_name=r.display_name, + max_input_tokens=r.max_input_tokens, + supports_image_input=r.supports_image_input, ) - except ValueError as e: - logger.warning(f"Failed to sync OpenRouter models to DB: {e}") + for r in sorted_results + ], + source_label="LM Studio", + ) return sorted_results + + +@admin_router.post("/litellm/available-models") +def get_litellm_available_models( + request: LitellmModelsRequest, + _: User = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[LitellmFinalModelResponse]: + """Fetch available models from Litellm proxy /v1/models endpoint.""" + response_json = _get_litellm_models_response( + api_key=request.api_key, api_base=request.api_base + ) + + models = response_json.get("data", []) + if not isinstance(models, list) or len(models) == 0: + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "No models found 
from your Litellm endpoint", + ) + + results: list[LitellmFinalModelResponse] = [] + for model in models: + try: + model_details = LitellmModelDetails.model_validate(model) + + results.append( + LitellmFinalModelResponse( + provider_name=model_details.owned_by, + model_name=model_details.id, + ) + ) + except Exception as e: + logger.warning( + "Failed to parse Litellm model entry", + extra={"error": str(e), "item": str(model)[:1000]}, + ) + + if not results: + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "No compatible models found from Litellm", + ) + + sorted_results = sorted(results, key=lambda m: m.model_name.lower()) + + # Sync new models to DB if provider_name is specified + if request.provider_name: + _sync_fetched_models( + db_session=db_session, + provider_name=request.provider_name, + models=[ + SyncModelEntry( + name=r.model_name, + display_name=r.model_name, + ) + for r in sorted_results + ], + source_label="LiteLLM", + ) + + return sorted_results + + +def _get_litellm_models_response(api_key: str, api_base: str) -> dict: + """Perform GET to Litellm proxy /v1/models and return parsed JSON.""" + cleaned_api_base = api_base.strip().rstrip("/") + url = f"{cleaned_api_base}/v1/models" + + headers = { + "Authorization": f"Bearer {api_key}", + "HTTP-Referer": "https://onyx.app", + "X-Title": "Onyx", + } + + try: + response = httpx.get(url, headers=headers, timeout=10.0) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + if e.response.status_code == 401: + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + "Authentication failed: invalid or missing API key for LiteLLM proxy.", + ) + elif e.response.status_code == 404: + raise OnyxError( + OnyxErrorCode.VALIDATION_ERROR, + f"LiteLLM models endpoint not found at {url}. 
" + "Please verify the API base URL.", + ) + else: + raise OnyxError( + OnyxErrorCode.BAD_GATEWAY, + f"Failed to fetch LiteLLM models: {e}", + ) + except Exception as e: + raise OnyxError( + OnyxErrorCode.BAD_GATEWAY, + f"Failed to fetch LiteLLM models: {e}", + ) diff --git a/backend/onyx/server/manage/llm/models.py b/backend/onyx/server/manage/llm/models.py index 623e0654b31..cf92a72ee77 100644 --- a/backend/onyx/server/manage/llm/models.py +++ b/backend/onyx/server/manage/llm/models.py @@ -371,6 +371,22 @@ class OpenRouterFinalModelResponse(BaseModel): supports_image_input: bool +# LM Studio dynamic models fetch +class LMStudioModelsRequest(BaseModel): + api_base: str + api_key: str | None = None + api_key_changed: bool = False + provider_name: str | None = None # Optional: to save models to existing provider + + +class LMStudioFinalModelResponse(BaseModel): + name: str # Model ID from LM Studio (e.g., "lmstudio-community/Meta-Llama-3-8B") + display_name: str # Human-readable name + max_input_tokens: int | None # From LM Studio API or None if unavailable + supports_image_input: bool + supports_reasoning: bool + + class DefaultModel(BaseModel): provider_id: int model_name: str @@ -404,3 +420,32 @@ def from_models( default_text=default_text, default_vision=default_vision, ) + + +class SyncModelEntry(BaseModel): + """Typed model for syncing fetched models to the DB.""" + + name: str + display_name: str + max_input_tokens: int | None = None + supports_image_input: bool = False + + +class LitellmModelsRequest(BaseModel): + api_key: str + api_base: str + provider_name: str | None = None # Optional: to save models to existing provider + + +class LitellmModelDetails(BaseModel): + """Response model for Litellm proxy /api/v1/models endpoint""" + + id: str # Model ID (e.g. "gpt-4o") + object: str # "model" + created: int # Unix timestamp in seconds + owned_by: str # Provider name (e.g. 
"openai") + + +class LitellmFinalModelResponse(BaseModel): + provider_name: str # Provider name (e.g. "openai") + model_name: str # Model ID (e.g. "gpt-4o") diff --git a/backend/onyx/server/manage/llm/utils.py b/backend/onyx/server/manage/llm/utils.py index 1237da5b9f1..bdef898cabc 100644 --- a/backend/onyx/server/manage/llm/utils.py +++ b/backend/onyx/server/manage/llm/utils.py @@ -12,6 +12,7 @@ from onyx.llm.constants import BEDROCK_MODEL_NAME_MAPPINGS from onyx.llm.constants import LlmProviderNames +from onyx.llm.constants import MODEL_PREFIX_TO_VENDOR from onyx.llm.constants import OLLAMA_MODEL_NAME_MAPPINGS from onyx.llm.constants import OLLAMA_MODEL_TO_VENDOR from onyx.llm.constants import PROVIDER_DISPLAY_NAMES @@ -23,6 +24,7 @@ LlmProviderNames.OPENROUTER, LlmProviderNames.BEDROCK, LlmProviderNames.OLLAMA_CHAT, + LlmProviderNames.LM_STUDIO, } ) @@ -348,4 +350,19 @@ def extract_vendor_from_model_name(model_name: str, provider: str) -> str | None # Fallback: capitalize the base name as vendor return base_name.split("-")[0].title() + elif provider == LlmProviderNames.LM_STUDIO: + # LM Studio model IDs can be paths like "publisher/model-name" + # or simple names. Use MODEL_PREFIX_TO_VENDOR for matching. 
+ + model_lower = model_name.lower() + # Check for slash-separated vendor prefix first + if "/" in model_lower: + vendor_key = model_lower.split("/")[0] + return PROVIDER_DISPLAY_NAMES.get(vendor_key, vendor_key.title()) + # Fallback to model prefix matching + for prefix, vendor in MODEL_PREFIX_TO_VENDOR.items(): + if model_lower.startswith(prefix): + return PROVIDER_DISPLAY_NAMES.get(vendor, vendor.title()) + return None + return None diff --git a/backend/onyx/server/manage/search_settings.py b/backend/onyx/server/manage/search_settings.py index 7f3181f5021..a7cf129195d 100644 --- a/backend/onyx/server/manage/search_settings.py +++ b/backend/onyx/server/manage/search_settings.py @@ -6,8 +6,11 @@ from onyx.auth.users import current_admin_user from onyx.auth.users import current_user +from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP from onyx.context.search.models import SavedSearchSettings from onyx.context.search.models import SearchSettingsCreationRequest +from onyx.db.connector_credential_pair import get_connector_credential_pairs +from onyx.db.connector_credential_pair import resync_cc_pair from onyx.db.engine.sql_engine import get_session from onyx.db.index_attempt import expire_index_attempts from onyx.db.llm import fetch_existing_llm_provider @@ -15,20 +18,25 @@ from onyx.db.llm import update_no_default_contextual_rag_provider from onyx.db.models import IndexModelStatus from onyx.db.models import User +from onyx.db.search_settings import create_search_settings from onyx.db.search_settings import delete_search_settings from onyx.db.search_settings import get_current_search_settings +from onyx.db.search_settings import get_embedding_provider_from_provider_type from onyx.db.search_settings import get_secondary_search_settings from onyx.db.search_settings import update_current_search_settings from onyx.db.search_settings import update_search_settings_status +from onyx.document_index.factory import get_all_document_indices from 
onyx.document_index.factory import get_default_document_index from onyx.file_processing.unstructured import delete_unstructured_api_key from onyx.file_processing.unstructured import get_unstructured_api_key from onyx.file_processing.unstructured import update_unstructured_api_key +from onyx.natural_language_processing.search_nlp_models import clean_model_name from onyx.server.manage.embedding.models import SearchSettingsDeleteRequest from onyx.server.manage.models import FullModelVersionResponse from onyx.server.models import IdReturn from onyx.server.utils_vector_db import require_vector_db from onyx.utils.logger import setup_logger +from shared_configs.configs import ALT_INDEX_SUFFIX from shared_configs.configs import MULTI_TENANT router = APIRouter(prefix="/search-settings") @@ -41,110 +49,99 @@ def set_new_search_settings( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), # noqa: ARG001 ) -> IdReturn: - """Creates a new EmbeddingModel row and cancels the previous secondary indexing if any - Gives an error if the same model name is used as the current or secondary index """ - # TODO(andrei): Re-enable. - # NOTE Enable integration external dependency tests in test_search_settings.py - # when this is reenabled. They are currently skipped - logger.error("Setting new search settings is temporarily disabled.") - raise HTTPException( - status_code=status.HTTP_501_NOT_IMPLEMENTED, - detail="Setting new search settings is temporarily disabled.", + Creates a new SearchSettings row and cancels the previous secondary indexing + if any exists. + """ + if search_settings_new.index_name: + logger.warning("Index name was specified by request, this is not suggested") + + # Disallow contextual RAG for cloud deployments. 
+ if MULTI_TENANT and search_settings_new.enable_contextual_rag: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Contextual RAG disabled in Onyx Cloud", + ) + + # Validate cloud provider exists or create new LiteLLM provider. + if search_settings_new.provider_type is not None: + cloud_provider = get_embedding_provider_from_provider_type( + db_session, provider_type=search_settings_new.provider_type + ) + + if cloud_provider is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}", + ) + + validate_contextual_rag_model( + provider_name=search_settings_new.contextual_rag_llm_provider, + model_name=search_settings_new.contextual_rag_llm_name, + db_session=db_session, ) - # if search_settings_new.index_name: - # logger.warning("Index name was specified by request, this is not suggested") - - # # Disallow contextual RAG for cloud deployments - # if MULTI_TENANT and search_settings_new.enable_contextual_rag: - # raise HTTPException( - # status_code=status.HTTP_400_BAD_REQUEST, - # detail="Contextual RAG disabled in Onyx Cloud", - # ) - - # # Validate cloud provider exists or create new LiteLLM provider - # if search_settings_new.provider_type is not None: - # cloud_provider = get_embedding_provider_from_provider_type( - # db_session, provider_type=search_settings_new.provider_type - # ) - - # if cloud_provider is None: - # raise HTTPException( - # status_code=status.HTTP_400_BAD_REQUEST, - # detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}", - # ) - - # validate_contextual_rag_model( - # provider_name=search_settings_new.contextual_rag_llm_provider, - # model_name=search_settings_new.contextual_rag_llm_name, - # db_session=db_session, - # ) - - # search_settings = get_current_search_settings(db_session) - - # if search_settings_new.index_name is None: - # # We define index 
name here - # index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}" - # if ( - # search_settings_new.model_name == search_settings.model_name - # and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX) - # ): - # index_name += ALT_INDEX_SUFFIX - # search_values = search_settings_new.model_dump() - # search_values["index_name"] = index_name - # new_search_settings_request = SavedSearchSettings(**search_values) - # else: - # new_search_settings_request = SavedSearchSettings( - # **search_settings_new.model_dump() - # ) - - # secondary_search_settings = get_secondary_search_settings(db_session) - - # if secondary_search_settings: - # # Cancel any background indexing jobs - # expire_index_attempts( - # search_settings_id=secondary_search_settings.id, db_session=db_session - # ) - - # # Mark previous model as a past model directly - # update_search_settings_status( - # search_settings=secondary_search_settings, - # new_status=IndexModelStatus.PAST, - # db_session=db_session, - # ) - - # new_search_settings = create_search_settings( - # search_settings=new_search_settings_request, db_session=db_session - # ) - - # # Ensure Vespa has the new index immediately - # get_multipass_config(search_settings) - # get_multipass_config(new_search_settings) - # document_index = get_default_document_index( - # search_settings, new_search_settings, db_session - # ) - - # document_index.ensure_indices_exist( - # primary_embedding_dim=search_settings.final_embedding_dim, - # primary_embedding_precision=search_settings.embedding_precision, - # secondary_index_embedding_dim=new_search_settings.final_embedding_dim, - # secondary_index_embedding_precision=new_search_settings.embedding_precision, - # ) - - # # Pause index attempts for the currently in use index to preserve resources - # if DISABLE_INDEX_UPDATE_ON_SWAP: - # expire_index_attempts( - # search_settings_id=search_settings.id, db_session=db_session - # ) - # for cc_pair in 
get_connector_credential_pairs(db_session): - # resync_cc_pair( - # cc_pair=cc_pair, - # search_settings_id=new_search_settings.id, - # db_session=db_session, - # ) - - # db_session.commit() - # return IdReturn(id=new_search_settings.id) + + search_settings = get_current_search_settings(db_session) + + if search_settings_new.index_name is None: + # We define index name here. + index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}" + if ( + search_settings_new.model_name == search_settings.model_name + and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX) + ): + index_name += ALT_INDEX_SUFFIX + search_values = search_settings_new.model_dump() + search_values["index_name"] = index_name + new_search_settings_request = SavedSearchSettings(**search_values) + else: + new_search_settings_request = SavedSearchSettings( + **search_settings_new.model_dump() + ) + + secondary_search_settings = get_secondary_search_settings(db_session) + + if secondary_search_settings: + # Cancel any background indexing jobs. + expire_index_attempts( + search_settings_id=secondary_search_settings.id, db_session=db_session + ) + + # Mark previous model as a past model directly. + update_search_settings_status( + search_settings=secondary_search_settings, + new_status=IndexModelStatus.PAST, + db_session=db_session, + ) + + new_search_settings = create_search_settings( + search_settings=new_search_settings_request, db_session=db_session + ) + + # Ensure the document indices have the new index immediately. 
+ document_indices = get_all_document_indices(search_settings, new_search_settings) + for document_index in document_indices: + document_index.ensure_indices_exist( + primary_embedding_dim=search_settings.final_embedding_dim, + primary_embedding_precision=search_settings.embedding_precision, + secondary_index_embedding_dim=new_search_settings.final_embedding_dim, + secondary_index_embedding_precision=new_search_settings.embedding_precision, + ) + + # Pause index attempts for the currently in-use index to preserve resources. + if DISABLE_INDEX_UPDATE_ON_SWAP: + expire_index_attempts( + search_settings_id=search_settings.id, db_session=db_session + ) + for cc_pair in get_connector_credential_pairs(db_session): + resync_cc_pair( + cc_pair=cc_pair, + search_settings_id=new_search_settings.id, + db_session=db_session, + ) + + db_session.commit() + return IdReturn(id=new_search_settings.id) @router.post("/cancel-new-embedding", dependencies=[Depends(require_vector_db)]) diff --git a/backend/onyx/server/manage/users.py b/backend/onyx/server/manage/users.py index 9278ddc990a..fb9d941c61b 100644 --- a/backend/onyx/server/manage/users.py +++ b/backend/onyx/server/manage/users.py @@ -5,6 +5,7 @@ from datetime import timedelta from datetime import timezone from typing import cast +from uuid import UUID import jwt from email_validator import EmailNotValidError @@ -18,6 +19,7 @@ from fastapi import Request from fastapi.responses import StreamingResponse from pydantic import BaseModel +from sqlalchemy import select from sqlalchemy.orm import Session from onyx.auth.anonymous_user import fetch_anonymous_user_info @@ -67,11 +69,14 @@ from onyx.db.user_preferences import update_user_shortcut_enabled from onyx.db.user_preferences import update_user_temperature_override_enabled from onyx.db.user_preferences import update_user_theme_preference +from onyx.db.users import batch_get_user_groups from onyx.db.users import delete_user_from_db +from onyx.db.users import get_all_accepted_users 
from onyx.db.users import get_all_users from onyx.db.users import get_page_of_filtered_users from onyx.db.users import get_total_filtered_users_count from onyx.db.users import get_user_by_email +from onyx.db.users import get_user_counts_by_role_and_status from onyx.db.users import validate_user_role_update from onyx.key_value_store.factory import get_kv_store from onyx.redis.redis_pool import get_raw_redis_client @@ -98,6 +103,7 @@ from onyx.server.models import FullUserSnapshot from onyx.server.models import InvitedUserSnapshot from onyx.server.models import MinimalUserSnapshot +from onyx.server.models import UserGroupInfo from onyx.server.usage_limits import is_tenant_on_trial_fn from onyx.server.utils import BasicAuthenticationError from onyx.utils.logger import setup_logger @@ -203,14 +209,91 @@ def list_accepted_users( total_items=0, ) + user_ids = [user.id for user in filtered_accepted_users] + groups_by_user = batch_get_user_groups(db_session, user_ids) + + # Batch-fetch SCIM mappings to mark synced users + scim_synced_ids: set[UUID] = set() + try: + from onyx.db.models import ScimUserMapping + + scim_mappings = db_session.scalars( + select(ScimUserMapping.user_id).where(ScimUserMapping.user_id.in_(user_ids)) + ).all() + scim_synced_ids = set(scim_mappings) + except Exception: + logger.warning( + "Failed to fetch SCIM mappings; marking all users as non-synced", + exc_info=True, + ) + return PaginatedReturn( items=[ - FullUserSnapshot.from_user_model(user) for user in filtered_accepted_users + FullUserSnapshot.from_user_model( + user, + groups=[ + UserGroupInfo(id=gid, name=gname) + for gid, gname in groups_by_user.get(user.id, []) + ], + is_scim_synced=user.id in scim_synced_ids, + ) + for user in filtered_accepted_users ], total_items=total_accepted_users_count, ) +@router.get("/manage/users/accepted/all", tags=PUBLIC_API_TAGS) +def list_all_accepted_users( + _: User = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> 
list[FullUserSnapshot]: + """Returns all accepted users without pagination. + Used by the admin Users page for client-side filtering/sorting.""" + users = get_all_accepted_users(db_session=db_session) + + if not users: + return [] + + user_ids = [user.id for user in users] + groups_by_user = batch_get_user_groups(db_session, user_ids) + + # Batch-fetch SCIM mappings to mark synced users + scim_synced_ids: set[UUID] = set() + try: + from onyx.db.models import ScimUserMapping + + scim_mappings = db_session.scalars( + select(ScimUserMapping.user_id).where(ScimUserMapping.user_id.in_(user_ids)) + ).all() + scim_synced_ids = set(scim_mappings) + except Exception: + logger.warning( + "Failed to fetch SCIM mappings; marking all users as non-synced", + exc_info=True, + ) + + return [ + FullUserSnapshot.from_user_model( + user, + groups=[ + UserGroupInfo(id=gid, name=gname) + for gid, gname in groups_by_user.get(user.id, []) + ], + is_scim_synced=user.id in scim_synced_ids, + ) + for user in users + ] + + +@router.get("/manage/users/counts") +def get_user_counts( + _: User = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> dict[str, dict[str, int]]: + return get_user_counts_by_role_and_status(db_session) + + @router.get("/manage/users/invited", tags=PUBLIC_API_TAGS) def list_invited_users( _: User = Depends(current_admin_user), @@ -269,24 +352,10 @@ def list_all_users( if accepted_page is None or invited_page is None or slack_users_page is None: return AllUsersResponse( accepted=[ - FullUserSnapshot( - id=user.id, - email=user.email, - role=user.role, - is_active=user.is_active, - password_configured=user.password_configured, - ) - for user in accepted_users + FullUserSnapshot.from_user_model(user) for user in accepted_users ], slack_users=[ - FullUserSnapshot( - id=user.id, - email=user.email, - role=user.role, - is_active=user.is_active, - password_configured=user.password_configured, - ) - for user in slack_users + 
FullUserSnapshot.from_user_model(user) for user in slack_users ], invited=[InvitedUserSnapshot(email=email) for email in invited_emails], accepted_pages=1, @@ -296,26 +365,10 @@ def list_all_users( # Otherwise, return paginated results return AllUsersResponse( - accepted=[ - FullUserSnapshot( - id=user.id, - email=user.email, - role=user.role, - is_active=user.is_active, - password_configured=user.password_configured, - ) - for user in accepted_users - ][accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE], - slack_users=[ - FullUserSnapshot( - id=user.id, - email=user.email, - role=user.role, - is_active=user.is_active, - password_configured=user.password_configured, - ) - for user in slack_users - ][ + accepted=[FullUserSnapshot.from_user_model(user) for user in accepted_users][ + accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE + ], + slack_users=[FullUserSnapshot.from_user_model(user) for user in slack_users][ slack_users_page * USERS_PAGE_SIZE : (slack_users_page + 1) * USERS_PAGE_SIZE diff --git a/backend/onyx/server/models.py b/backend/onyx/server/models.py index 3e1b4a3ec07..88fdf927e6d 100644 --- a/backend/onyx/server/models.py +++ b/backend/onyx/server/models.py @@ -1,3 +1,4 @@ +import datetime from typing import Generic from typing import Optional from typing import TypeVar @@ -31,21 +32,41 @@ class MinimalUserSnapshot(BaseModel): email: str +class UserGroupInfo(BaseModel): + id: int + name: str + + class FullUserSnapshot(BaseModel): id: UUID email: str role: UserRole is_active: bool password_configured: bool + personal_name: str | None + created_at: datetime.datetime + updated_at: datetime.datetime + groups: list[UserGroupInfo] + is_scim_synced: bool @classmethod - def from_user_model(cls, user: User) -> "FullUserSnapshot": + def from_user_model( + cls, + user: User, + groups: list[UserGroupInfo] | None = None, + is_scim_synced: bool = False, + ) -> "FullUserSnapshot": return cls( id=user.id, email=user.email, 
role=user.role, is_active=user.is_active, password_configured=user.password_configured, + personal_name=user.personal_name, + created_at=user.created_at, + updated_at=user.updated_at, + groups=groups or [], + is_scim_synced=is_scim_synced, ) diff --git a/backend/onyx/server/query_and_chat/chat_backend.py b/backend/onyx/server/query_and_chat/chat_backend.py index d1729f111d9..83162a00352 100644 --- a/backend/onyx/server/query_and_chat/chat_backend.py +++ b/backend/onyx/server/query_and_chat/chat_backend.py @@ -1,6 +1,5 @@ import datetime import json -import os from collections.abc import Generator from datetime import timedelta from uuid import UUID @@ -13,13 +12,13 @@ from fastapi import Response from fastapi.responses import StreamingResponse from pydantic import BaseModel -from redis.client import Redis from sqlalchemy.orm import Session from onyx.auth.api_key import get_hashed_api_key_from_request from onyx.auth.pat import get_hashed_pat_from_request from onyx.auth.users import current_chat_accessible_user from onyx.auth.users import current_user +from onyx.cache.factory import get_cache_backend from onyx.chat.chat_processing_checker import is_chat_session_processing from onyx.chat.chat_state import ChatStateContainer from onyx.chat.chat_utils import convert_chat_history_basic @@ -61,13 +60,11 @@ from onyx.db.usage import increment_usage from onyx.db.usage import UsageType from onyx.db.user_file import get_file_id_by_user_file_id -from onyx.file_processing.extract_file_text import docx_to_txt_filename from onyx.file_store.file_store import get_default_file_store from onyx.llm.constants import LlmProviderNames from onyx.llm.factory import get_default_llm from onyx.llm.factory import get_llm_for_persona from onyx.llm.factory import get_llm_token_counter -from onyx.redis.redis_pool import get_redis_client from onyx.secondary_llm_flows.chat_session_naming import generate_chat_session_name from onyx.server.api_key_usage import check_api_key_usage from 
onyx.server.query_and_chat.models import ChatFeedbackRequest @@ -152,10 +149,20 @@ def get_user_chat_sessions( project_id: int | None = None, only_non_project_chats: bool = True, include_failed_chats: bool = False, + page_size: int = Query(default=50, ge=1, le=100), + before: str | None = Query(default=None), ) -> ChatSessionsResponse: user_id = user.id try: + before_dt = ( + datetime.datetime.fromisoformat(before) if before is not None else None + ) + except ValueError: + raise HTTPException(status_code=422, detail="Invalid 'before' timestamp format") + + try: + # Fetch one extra to determine if there are more results chat_sessions = get_chat_sessions_by_user( user_id=user_id, deleted=False, @@ -163,11 +170,16 @@ def get_user_chat_sessions( project_id=project_id, only_non_project_chats=only_non_project_chats, include_failed_chats=include_failed_chats, + limit=page_size + 1, + before=before_dt, ) except ValueError: raise ValueError("Chat session does not exist or has been deleted") + has_more = len(chat_sessions) > page_size + chat_sessions = chat_sessions[:page_size] + return ChatSessionsResponse( sessions=[ ChatSessionDetails( @@ -181,7 +193,8 @@ def get_user_chat_sessions( current_temperature_override=chat.temperature_override, ) for chat in chat_sessions - ] + ], + has_more=has_more, ) @@ -314,7 +327,7 @@ def get_chat_session( ] try: - is_processing = is_chat_session_processing(session_id, get_redis_client()) + is_processing = is_chat_session_processing(session_id, get_cache_backend()) # Edit the last message to indicate loading (Overriding default message value) if is_processing and chat_message_details: last_msg = chat_message_details[-1] @@ -797,18 +810,6 @@ def fetch_chat_file( if not file_record: raise HTTPException(status_code=404, detail="File not found") - original_file_name = file_record.display_name - if file_record.file_type.startswith( - "application/vnd.openxmlformats-officedocument.wordprocessingml.document" - ): - # Check if a converted text file 
exists for .docx files - txt_file_name = docx_to_txt_filename(original_file_name) - txt_file_id = os.path.join(os.path.dirname(file_id), txt_file_name) - txt_file_record = file_store.read_file_record(txt_file_id) - if txt_file_record: - file_record = txt_file_record - file_id = txt_file_id - media_type = file_record.file_type file_io = file_store.read_file(file_id, mode="b") @@ -911,11 +912,10 @@ async def search_chats( def stop_chat_session( chat_session_id: UUID, user: User = Depends(current_user), # noqa: ARG001 - redis_client: Redis = Depends(get_redis_client), ) -> dict[str, str]: """ - Stop a chat session by setting a stop signal in Redis. + Stop a chat session by setting a stop signal. This endpoint is called by the frontend when the user clicks the stop button. """ - set_fence(chat_session_id, redis_client, True) + set_fence(chat_session_id, get_cache_backend(), True) return {"message": "Chat session stopped"} diff --git a/backend/onyx/server/query_and_chat/models.py b/backend/onyx/server/query_and_chat/models.py index 776dbbaa79a..76a4afd4ae0 100644 --- a/backend/onyx/server/query_and_chat/models.py +++ b/backend/onyx/server/query_and_chat/models.py @@ -192,6 +192,7 @@ def from_model(cls, model: ChatSession) -> "ChatSessionDetails": class ChatSessionsResponse(BaseModel): sessions: list[ChatSessionDetails] + has_more: bool = False class ChatMessageDetail(BaseModel): diff --git a/backend/onyx/server/query_and_chat/session_loading.py b/backend/onyx/server/query_and_chat/session_loading.py index 12fe95cc1fb..63538c11c40 100644 --- a/backend/onyx/server/query_and_chat/session_loading.py +++ b/backend/onyx/server/query_and_chat/session_loading.py @@ -1,9 +1,11 @@ from __future__ import annotations import json +from typing import Any from typing import cast from typing import Literal +from pydantic import ValidationError from sqlalchemy.orm import Session from onyx.chat.citation_utils import extract_citation_order_from_text @@ -20,7 +22,9 @@ from 
onyx.server.query_and_chat.streaming_models import AgentResponseDelta from onyx.server.query_and_chat.streaming_models import AgentResponseStart from onyx.server.query_and_chat.streaming_models import CitationInfo +from onyx.server.query_and_chat.streaming_models import CustomToolArgs from onyx.server.query_and_chat.streaming_models import CustomToolDelta +from onyx.server.query_and_chat.streaming_models import CustomToolErrorInfo from onyx.server.query_and_chat.streaming_models import CustomToolStart from onyx.server.query_and_chat.streaming_models import FileReaderResult from onyx.server.query_and_chat.streaming_models import FileReaderStart @@ -180,24 +184,37 @@ def create_custom_tool_packets( tab_index: int = 0, data: dict | list | str | int | float | bool | None = None, file_ids: list[str] | None = None, + error: CustomToolErrorInfo | None = None, + tool_args: dict[str, Any] | None = None, + tool_id: int | None = None, ) -> list[Packet]: packets: list[Packet] = [] packets.append( Packet( placement=Placement(turn_index=turn_index, tab_index=tab_index), - obj=CustomToolStart(tool_name=tool_name), + obj=CustomToolStart(tool_name=tool_name, tool_id=tool_id), ) ) + if tool_args: + packets.append( + Packet( + placement=Placement(turn_index=turn_index, tab_index=tab_index), + obj=CustomToolArgs(tool_name=tool_name, tool_args=tool_args), + ) + ) + packets.append( Packet( placement=Placement(turn_index=turn_index, tab_index=tab_index), obj=CustomToolDelta( tool_name=tool_name, + tool_id=tool_id, response_type=response_type, data=data, file_ids=file_ids, + error=error, ), ), ) @@ -657,13 +674,55 @@ def translate_assistant_message_to_packets( else: # Custom tool or unknown tool + # Try to parse as structured CustomToolCallSummary JSON + custom_data: dict | list | str | int | float | bool | None = ( + tool_call.tool_call_response + ) + custom_error: CustomToolErrorInfo | None = None + custom_response_type = "text" + + try: + parsed = 
json.loads(tool_call.tool_call_response) + if isinstance(parsed, dict) and "tool_name" in parsed: + custom_data = parsed.get("tool_result") + custom_response_type = parsed.get( + "response_type", "text" + ) + if parsed.get("error"): + custom_error = CustomToolErrorInfo( + **parsed["error"] + ) + except ( + json.JSONDecodeError, + KeyError, + TypeError, + ValidationError, + ): + pass + + custom_file_ids: list[str] | None = None + if custom_response_type in ("image", "csv") and isinstance( + custom_data, dict + ): + custom_file_ids = custom_data.get("file_ids") + custom_data = None + + custom_args = { + k: v + for k, v in (tool_call.tool_call_arguments or {}).items() + if k != "requestBody" + } turn_tool_packets.extend( create_custom_tool_packets( tool_name=tool.display_name or tool.name, - response_type="text", + response_type=custom_response_type, turn_index=turn_num, tab_index=tool_call.tab_index, - data=tool_call.tool_call_response, + data=custom_data, + file_ids=custom_file_ids, + error=custom_error, + tool_args=custom_args if custom_args else None, + tool_id=tool_call.tool_id, ) ) diff --git a/backend/onyx/server/query_and_chat/streaming_models.py b/backend/onyx/server/query_and_chat/streaming_models.py index 4ab5f2eda5a..6015efa0fa8 100644 --- a/backend/onyx/server/query_and_chat/streaming_models.py +++ b/backend/onyx/server/query_and_chat/streaming_models.py @@ -33,6 +33,7 @@ class StreamingType(Enum): PYTHON_TOOL_START = "python_tool_start" PYTHON_TOOL_DELTA = "python_tool_delta" CUSTOM_TOOL_START = "custom_tool_start" + CUSTOM_TOOL_ARGS = "custom_tool_args" CUSTOM_TOOL_DELTA = "custom_tool_delta" FILE_READER_START = "file_reader_start" FILE_READER_RESULT = "file_reader_result" @@ -41,6 +42,7 @@ class StreamingType(Enum): REASONING_DONE = "reasoning_done" CITATION_INFO = "citation_info" TOOL_CALL_DEBUG = "tool_call_debug" + TOOL_CALL_ARGUMENT_DELTA = "tool_call_argument_delta" MEMORY_TOOL_START = "memory_tool_start" MEMORY_TOOL_DELTA = "memory_tool_delta" @@ 
-245,6 +247,20 @@ class CustomToolStart(BaseObj): type: Literal["custom_tool_start"] = StreamingType.CUSTOM_TOOL_START.value tool_name: str + tool_id: int | None = None + + +class CustomToolArgs(BaseObj): + type: Literal["custom_tool_args"] = StreamingType.CUSTOM_TOOL_ARGS.value + + tool_name: str + tool_args: dict[str, Any] + + +class CustomToolErrorInfo(BaseModel): + is_auth_error: bool = False + status_code: int + message: str # The allowed streamed packets for a custom tool @@ -252,11 +268,22 @@ class CustomToolDelta(BaseObj): type: Literal["custom_tool_delta"] = StreamingType.CUSTOM_TOOL_DELTA.value tool_name: str + tool_id: int | None = None response_type: str # For non-file responses data: dict | list | str | int | float | bool | None = None # For file-based responses like image/csv file_ids: list[str] | None = None + error: CustomToolErrorInfo | None = None + + +class ToolCallArgumentDelta(BaseObj): + type: Literal["tool_call_argument_delta"] = ( + StreamingType.TOOL_CALL_ARGUMENT_DELTA.value + ) + + tool_type: str + argument_deltas: dict[str, Any] ################################################ @@ -366,6 +393,7 @@ class IntermediateReportCitedDocs(BaseObj): PythonToolStart, PythonToolDelta, CustomToolStart, + CustomToolArgs, CustomToolDelta, FileReaderStart, FileReaderResult, @@ -379,6 +407,7 @@ class IntermediateReportCitedDocs(BaseObj): # Citation Packets CitationInfo, ToolCallDebug, + ToolCallArgumentDelta, # Deep Research Packets DeepResearchPlanStart, DeepResearchPlanDelta, diff --git a/backend/onyx/server/query_and_chat/streaming_utils.py b/backend/onyx/server/query_and_chat/streaming_utils.py index 2b93a84cddc..c5f8e72412a 100644 --- a/backend/onyx/server/query_and_chat/streaming_utils.py +++ b/backend/onyx/server/query_and_chat/streaming_utils.py @@ -8,8 +8,6 @@ from onyx.server.query_and_chat.streaming_models import AgentResponseDelta from onyx.server.query_and_chat.streaming_models import AgentResponseStart from 
onyx.server.query_and_chat.streaming_models import CitationInfo -from onyx.server.query_and_chat.streaming_models import CustomToolDelta -from onyx.server.query_and_chat.streaming_models import CustomToolStart from onyx.server.query_and_chat.streaming_models import GeneratedImage from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart @@ -165,39 +163,6 @@ def create_image_generation_packets( return packets -def create_custom_tool_packets( - tool_name: str, - response_type: str, - turn_index: int, - data: dict | list | str | int | float | bool | None = None, - file_ids: list[str] | None = None, -) -> list[Packet]: - packets: list[Packet] = [] - - packets.append( - Packet( - placement=Placement(turn_index=turn_index), - obj=CustomToolStart(tool_name=tool_name), - ) - ) - - packets.append( - Packet( - placement=Placement(turn_index=turn_index), - obj=CustomToolDelta( - tool_name=tool_name, - response_type=response_type, - data=data, - file_ids=file_ids, - ), - ), - ) - - packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd())) - - return packets - - def create_fetch_packets( fetch_docs: list[SavedSearchDoc], urls: list[str], diff --git a/backend/onyx/server/settings/models.py b/backend/onyx/server/settings/models.py index 3eabefb2254..32870c030b4 100644 --- a/backend/onyx/server/settings/models.py +++ b/backend/onyx/server/settings/models.py @@ -60,9 +60,11 @@ class Settings(BaseModel): deep_research_enabled: bool | None = None search_ui_enabled: bool | None = None - # Enterprise features flag - set by license enforcement at runtime - # When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status - # When LICENSE_ENFORCEMENT_ENABLED=false, defaults to False + # Whether EE features are unlocked for use. 
+ # Depends on license status: True when the user has a valid license + # (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER), False when there's no license + # or the license is expired (GATED_ACCESS). + # This controls UI visibility of EE features (user groups, analytics, RBAC, etc.). ee_features_enabled: bool = False temperature_override_enabled: bool | None = False @@ -76,6 +78,7 @@ class Settings(BaseModel): # User Knowledge settings user_knowledge_enabled: bool | None = True + user_file_max_upload_size_mb: int | None = None # Connector settings show_extra_connectors: bool | None = True diff --git a/backend/onyx/server/settings/store.py b/backend/onyx/server/settings/store.py index 5c6f4a4aaba..5b058b32172 100644 --- a/backend/onyx/server/settings/store.py +++ b/backend/onyx/server/settings/store.py @@ -1,16 +1,15 @@ +from onyx.cache.factory import get_cache_backend from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS +from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB from onyx.configs.constants import KV_SETTINGS_KEY from onyx.configs.constants import OnyxRedisLocks from onyx.key_value_store.factory import get_kv_store from onyx.key_value_store.interface import KvKeyNotFoundError -from onyx.redis.redis_pool import get_redis_client from onyx.server.settings.models import Settings from onyx.utils.logger import setup_logger -from shared_configs.configs import MULTI_TENANT -from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() @@ -33,44 +32,36 @@ def load_settings() -> Settings: logger.error(f"Error loading settings from KV store: {str(e)}") settings = Settings() - tenant_id = get_current_tenant_id() if MULTI_TENANT else None - redis_client = get_redis_client(tenant_id=tenant_id) + cache = get_cache_backend() try: - value = 
redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED) + value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED) if value is not None: - assert isinstance(value, bytes) anonymous_user_enabled = int(value.decode("utf-8")) == 1 else: - # Default to False anonymous_user_enabled = False - # Optionally store the default back to Redis - redis_client.set( - OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL - ) + cache.set(OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL) except Exception as e: - # Log the error and reset to default - logger.error(f"Error loading anonymous user setting from Redis: {str(e)}") + logger.error(f"Error loading anonymous user setting from cache: {str(e)}") anonymous_user_enabled = False settings.anonymous_user_enabled = anonymous_user_enabled settings.query_history_type = ONYX_QUERY_HISTORY_TYPE - # Override user knowledge setting if disabled via environment variable if DISABLE_USER_KNOWLEDGE: settings.user_knowledge_enabled = False + settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX return settings def store_settings(settings: Settings) -> None: - tenant_id = get_current_tenant_id() if MULTI_TENANT else None - redis_client = get_redis_client(tenant_id=tenant_id) + cache = get_cache_backend() if settings.anonymous_user_enabled is not None: - redis_client.set( + cache.set( OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "1" if settings.anonymous_user_enabled else "0", ex=SETTINGS_TTL, diff --git a/backend/onyx/setup.py b/backend/onyx/setup.py index ee3af79e0f3..085ae6fb000 100644 --- a/backend/onyx/setup.py +++ b/backend/onyx/setup.py @@ -275,9 +275,13 @@ def setup_postgres(db_session: Session) -> None: ], api_key_changed=True, ) - new_llm_provider = upsert_llm_provider( - llm_provider_upsert_request=model_req, db_session=db_session - ) + try: + new_llm_provider = upsert_llm_provider( + 
llm_provider_upsert_request=model_req, db_session=db_session + ) + except ValueError as e: + logger.warning("Failed to upsert LLM provider during setup: %s", e) + return update_default_provider( provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session ) diff --git a/backend/onyx/tools/built_in_tools.py b/backend/onyx/tools/built_in_tools.py index ca7bcbfa669..bf1ba9725a1 100644 --- a/backend/onyx/tools/built_in_tools.py +++ b/backend/onyx/tools/built_in_tools.py @@ -56,3 +56,23 @@ def get_built_in_tool_ids() -> list[str]: def get_built_in_tool_by_id(in_code_tool_id: str) -> Type[BUILT_IN_TOOL_TYPES]: return BUILT_IN_TOOL_MAP[in_code_tool_id] + + +def _build_tool_name_to_class() -> dict[str, Type[BUILT_IN_TOOL_TYPES]]: + """Build a mapping from LLM-facing tool name to tool class.""" + result: dict[str, Type[BUILT_IN_TOOL_TYPES]] = {} + for cls in BUILT_IN_TOOL_MAP.values(): + name_attr = cls.__dict__.get("name") + if isinstance(name_attr, property) and name_attr.fget is not None: + tool_name = name_attr.fget(cls) + elif isinstance(name_attr, str): + tool_name = name_attr + else: + raise ValueError( + f"Built-in tool {cls.__name__} must define a valid LLM-facing tool name" + ) + result[tool_name] = cls + return result + + +TOOL_NAME_TO_CLASS: dict[str, Type[BUILT_IN_TOOL_TYPES]] = _build_tool_name_to_class() diff --git a/backend/onyx/tools/interface.py b/backend/onyx/tools/interface.py index 924833e65ce..05bcb005e90 100644 --- a/backend/onyx/tools/interface.py +++ b/backend/onyx/tools/interface.py @@ -92,3 +92,7 @@ def run( **llm_kwargs: Any, ) -> ToolResponse: raise NotImplementedError + + @classmethod + def should_emit_argument_deltas(cls) -> bool: + return False diff --git a/backend/onyx/tools/models.py b/backend/onyx/tools/models.py index 686b9bf2a69..632e8eb05a2 100644 --- a/backend/onyx/tools/models.py +++ b/backend/onyx/tools/models.py @@ -18,6 +18,7 @@ from onyx.context.search.models import SearchDocsResponse from onyx.db.memory import 
UserMemoryContext from onyx.server.query_and_chat.placement import Placement +from onyx.server.query_and_chat.streaming_models import CustomToolErrorInfo from onyx.server.query_and_chat.streaming_models import GeneratedImage from onyx.tools.tool_implementations.images.models import FinalImageGenerationResponse from onyx.tools.tool_implementations.memory.models import MemoryToolResponse @@ -61,6 +62,7 @@ class CustomToolCallSummary(BaseModel): tool_name: str response_type: str # e.g., 'json', 'image', 'csv', 'graph' tool_result: Any # The response data + error: CustomToolErrorInfo | None = None class ToolCallKickoff(BaseModel): @@ -93,6 +95,8 @@ class ToolResponse(BaseModel): # | WebContentResponse # This comes from custom tools, tool result needs to be saved | CustomToolCallSummary + # This comes from code interpreter, carries generated files + | PythonToolRichResponse # If the rich response is a string, this is what's saved to the tool call in the DB | str | None # If nothing needs to be persisted outside of the string value passed to the LLM @@ -193,6 +197,12 @@ class ChatFile(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) +class PythonToolRichResponse(BaseModel): + """Rich response from the Python tool carrying generated files.""" + + generated_files: list[PythonExecutionFile] = [] + + class PythonToolOverrideKwargs(BaseModel): """Override kwargs for the Python/Code Interpreter tool.""" @@ -245,6 +255,7 @@ class ToolCallInfo(BaseModel): tool_call_response: str search_docs: list[SearchDoc] | None = None generated_images: list[GeneratedImage] | None = None + generated_files: list[PythonExecutionFile] | None = None CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID" diff --git a/backend/onyx/tools/tool_implementations/custom/custom_tool.py b/backend/onyx/tools/tool_implementations/custom/custom_tool.py index 30e49b8becd..c1e95287e0d 100644 --- a/backend/onyx/tools/tool_implementations/custom/custom_tool.py +++ 
b/backend/onyx/tools/tool_implementations/custom/custom_tool.py @@ -15,7 +15,9 @@ from onyx.configs.constants import FileOrigin from onyx.file_store.file_store import get_default_file_store from onyx.server.query_and_chat.placement import Placement +from onyx.server.query_and_chat.streaming_models import CustomToolArgs from onyx.server.query_and_chat.streaming_models import CustomToolDelta +from onyx.server.query_and_chat.streaming_models import CustomToolErrorInfo from onyx.server.query_and_chat.streaming_models import CustomToolStart from onyx.server.query_and_chat.streaming_models import Packet from onyx.tools.interface import Tool @@ -139,7 +141,7 @@ def emit_start(self, placement: Placement) -> None: self.emitter.emit( Packet( placement=placement, - obj=CustomToolStart(tool_name=self._name), + obj=CustomToolStart(tool_name=self._name, tool_id=self._id), ) ) @@ -149,10 +151,8 @@ def run( override_kwargs: None = None, # noqa: ARG002 **llm_kwargs: Any, ) -> ToolResponse: - request_body = llm_kwargs.get(REQUEST_BODY) - + # Build path params path_params = {} - for path_param_schema in self._method_spec.get_path_param_schemas(): param_name = path_param_schema["name"] if param_name not in llm_kwargs: @@ -165,6 +165,7 @@ def run( ) path_params[param_name] = llm_kwargs[param_name] + # Build query params query_params = {} for query_param_schema in self._method_spec.get_query_param_schemas(): if query_param_schema["name"] in llm_kwargs: @@ -172,6 +173,20 @@ def run( query_param_schema["name"] ] + # Emit args packet (path + query params only, no request body) + tool_args = {**path_params, **query_params} + if tool_args: + self.emitter.emit( + Packet( + placement=placement, + obj=CustomToolArgs( + tool_name=self._name, + tool_args=tool_args, + ), + ) + ) + + request_body = llm_kwargs.get(REQUEST_BODY) url = self._method_spec.build_url(self._base_url, path_params, query_params) method = self._method_spec.method @@ -180,6 +195,18 @@ def run( ) content_type = 
response.headers.get("Content-Type", "") + # Detect HTTP errors — only 401/403 are flagged as auth errors + error_info: CustomToolErrorInfo | None = None + if response.status_code in (401, 403): + error_info = CustomToolErrorInfo( + is_auth_error=True, + status_code=response.status_code, + message=f"{self._name} action failed because of authentication error", + ) + logger.warning( + f"Auth error from custom tool '{self._name}': HTTP {response.status_code}" + ) + tool_result: Any response_type: str file_ids: List[str] | None = None @@ -222,9 +249,11 @@ def run( placement=placement, obj=CustomToolDelta( tool_name=self._name, + tool_id=self._id, response_type=response_type, data=data, file_ids=file_ids, + error=error_info, ), ) ) @@ -236,6 +265,7 @@ def run( tool_name=self._name, response_type=response_type, tool_result=tool_result, + error=error_info, ), llm_facing_response=llm_facing_response, ) diff --git a/backend/onyx/tools/tool_implementations/open_url/snippet_matcher.py b/backend/onyx/tools/tool_implementations/open_url/snippet_matcher.py index b9d1fe0c00f..7c8c078ae03 100644 --- a/backend/onyx/tools/tool_implementations/open_url/snippet_matcher.py +++ b/backend/onyx/tools/tool_implementations/open_url/snippet_matcher.py @@ -111,19 +111,26 @@ def _normalize_text_with_mapping(text: str) -> tuple[str, list[int]]: # Step 1: NFC normalization with position mapping nfc_text = unicodedata.normalize("NFC", text) - # Build mapping from NFC positions to original start positions + # Map NFD positions → original positions. + # NFD only decomposes, so each original char produces 1+ NFD chars. + nfd_to_orig: list[int] = [] + for orig_idx, orig_char in enumerate(original_text): + nfd_of_char = unicodedata.normalize("NFD", orig_char) + for _ in nfd_of_char: + nfd_to_orig.append(orig_idx) + + # Map NFC positions → NFD positions. + # Each NFC char, when decomposed, tells us exactly how many NFD + # chars it was composed from. 
nfc_to_orig: list[int] = [] - orig_idx = 0 + nfd_idx = 0 for nfc_char in nfc_text: - nfc_to_orig.append(orig_idx) - # Find how many original chars contributed to this NFC char - for length in range(1, len(original_text) - orig_idx + 1): - substr = original_text[orig_idx : orig_idx + length] - if unicodedata.normalize("NFC", substr) == nfc_char: - orig_idx += length - break + if nfd_idx < len(nfd_to_orig): + nfc_to_orig.append(nfd_to_orig[nfd_idx]) else: - orig_idx += 1 # Fallback + nfc_to_orig.append(len(original_text) - 1) + nfd_of_nfc = unicodedata.normalize("NFD", nfc_char) + nfd_idx += len(nfd_of_nfc) # Work with NFC text from here text = nfc_text diff --git a/backend/onyx/tools/tool_implementations/python/code_interpreter_client.py b/backend/onyx/tools/tool_implementations/python/code_interpreter_client.py index e2de2e4b4d3..442b76a9f20 100644 --- a/backend/onyx/tools/tool_implementations/python/code_interpreter_client.py +++ b/backend/onyx/tools/tool_implementations/python/code_interpreter_client.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import json +import time from collections.abc import Generator from typing import Literal from typing import TypedDict @@ -12,6 +15,9 @@ logger = setup_logger() +_HEALTH_CACHE_TTL_SECONDS = 30 +_health_cache: dict[str, tuple[float, bool]] = {} + class FileInput(TypedDict): """Input file to be staged in execution workspace""" @@ -80,6 +86,19 @@ def __init__(self, base_url: str | None = CODE_INTERPRETER_BASE_URL): raise ValueError("CODE_INTERPRETER_BASE_URL not configured") self.base_url = base_url.rstrip("/") self.session = requests.Session() + self._closed = False + + def __enter__(self) -> CodeInterpreterClient: + return self + + def __exit__(self, *args: object) -> None: + self.close() + + def close(self) -> None: + if self._closed: + return + self.session.close() + self._closed = True def _build_payload( self, @@ -98,16 +117,32 @@ def _build_payload( payload["files"] = files return payload - def health(self) 
-> bool: - """Check if the Code Interpreter service is healthy""" + def health(self, use_cache: bool = False) -> bool: + """Check if the Code Interpreter service is healthy + + Args: + use_cache: When True, return a cached result if available and + within the TTL window. The cache is always populated + after a live request regardless of this flag. + """ + if use_cache: + cached = _health_cache.get(self.base_url) + if cached is not None: + cached_at, cached_result = cached + if time.monotonic() - cached_at < _HEALTH_CACHE_TTL_SECONDS: + return cached_result + url = f"{self.base_url}/health" try: response = self.session.get(url, timeout=5) response.raise_for_status() - return response.json().get("status") == "ok" + result = response.json().get("status") == "ok" except Exception as e: logger.warning(f"Exception caught when checking health, e={e}") - return False + result = False + + _health_cache[self.base_url] = (time.monotonic(), result) + return result def execute( self, @@ -157,8 +192,11 @@ def execute_streaming( yield from self._batch_as_stream(code, stdin, timeout_ms, files) return - response.raise_for_status() - yield from self._parse_sse(response) + try: + response.raise_for_status() + yield from self._parse_sse(response) + finally: + response.close() def _parse_sse( self, response: requests.Response diff --git a/backend/onyx/tools/tool_implementations/python/python_tool.py b/backend/onyx/tools/tool_implementations/python/python_tool.py index 5f0fafb5624..8be2fc1817f 100644 --- a/backend/onyx/tools/tool_implementations/python/python_tool.py +++ b/backend/onyx/tools/tool_implementations/python/python_tool.py @@ -23,6 +23,7 @@ from onyx.tools.models import LlmPythonExecutionResult from onyx.tools.models import PythonExecutionFile from onyx.tools.models import PythonToolOverrideKwargs +from onyx.tools.models import PythonToolRichResponse from onyx.tools.models import ToolCallException from onyx.tools.models import ToolResponse from 
onyx.tools.tool_implementations.python.code_interpreter_client import ( @@ -107,7 +108,11 @@ def is_available(cls, db_session: Session) -> bool: if not CODE_INTERPRETER_BASE_URL: return False server = fetch_code_interpreter_server(db_session) - return server.server_enabled + if not server.server_enabled: + return False + + with CodeInterpreterClient() as client: + return client.health(use_cache=True) def tool_definition(self) -> dict: return { @@ -171,194 +176,208 @@ def run( ) ) - # Create Code Interpreter client - client = CodeInterpreterClient() + # Create Code Interpreter client — context manager ensures + # session.close() is called on every exit path. + with CodeInterpreterClient() as client: + # Stage chat files for execution + files_to_stage: list[FileInput] = [] + for ind, chat_file in enumerate(chat_files): + file_name = chat_file.filename or f"file_{ind}" + try: + # Upload to Code Interpreter + ci_file_id = client.upload_file(chat_file.content, file_name) - # Stage chat files for execution - files_to_stage: list[FileInput] = [] - for ind, chat_file in enumerate(chat_files): - file_name = chat_file.filename or f"file_{ind}" - try: - # Upload to Code Interpreter - ci_file_id = client.upload_file(chat_file.content, file_name) + # Stage for execution + files_to_stage.append({"path": file_name, "file_id": ci_file_id}) - # Stage for execution - files_to_stage.append({"path": file_name, "file_id": ci_file_id}) + logger.info(f"Staged file for Python execution: {file_name}") - logger.info(f"Staged file for Python execution: {file_name}") + except Exception as e: + logger.warning(f"Failed to stage file {file_name}: {e}") - except Exception as e: - logger.warning(f"Failed to stage file {file_name}: {e}") - - try: - logger.debug(f"Executing code: {code}") - - # Execute code with streaming (falls back to batch if unavailable) - stdout_parts: list[str] = [] - stderr_parts: list[str] = [] - result_event: StreamResultEvent | None = None - - for event in 
client.execute_streaming( - code=code, - timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS, - files=files_to_stage or None, - ): - if isinstance(event, StreamOutputEvent): - if event.stream == "stdout": - stdout_parts.append(event.data) - else: - stderr_parts.append(event.data) - # Emit incremental delta to frontend - self.emitter.emit( - Packet( - placement=placement, - obj=PythonToolDelta( - stdout=event.data if event.stream == "stdout" else "", - stderr=event.data if event.stream == "stderr" else "", - ), + try: + logger.debug(f"Executing code: {code}") + + # Execute code with streaming (falls back to batch if unavailable) + stdout_parts: list[str] = [] + stderr_parts: list[str] = [] + result_event: StreamResultEvent | None = None + + for event in client.execute_streaming( + code=code, + timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS, + files=files_to_stage or None, + ): + if isinstance(event, StreamOutputEvent): + if event.stream == "stdout": + stdout_parts.append(event.data) + else: + stderr_parts.append(event.data) + # Emit incremental delta to frontend + self.emitter.emit( + Packet( + placement=placement, + obj=PythonToolDelta( + stdout=( + event.data if event.stream == "stdout" else "" + ), + stderr=( + event.data if event.stream == "stderr" else "" + ), + ), + ) ) + elif isinstance(event, StreamResultEvent): + result_event = event + elif isinstance(event, StreamErrorEvent): + raise RuntimeError(f"Code interpreter error: {event.message}") + + if result_event is None: + raise RuntimeError( + "Code interpreter stream ended without a result event" ) - elif isinstance(event, StreamResultEvent): - result_event = event - elif isinstance(event, StreamErrorEvent): - raise RuntimeError(f"Code interpreter error: {event.message}") - - if result_event is None: - raise RuntimeError( - "Code interpreter stream ended without a result event" + + full_stdout = "".join(stdout_parts) + full_stderr = "".join(stderr_parts) + + # Truncate output for LLM consumption + 
truncated_stdout = _truncate_output( + full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout" + ) + truncated_stderr = _truncate_output( + full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr" ) - full_stdout = "".join(stdout_parts) - full_stderr = "".join(stderr_parts) + # Handle generated files + generated_files: list[PythonExecutionFile] = [] + generated_file_ids: list[str] = [] + file_ids_to_cleanup: list[str] = [] + file_store = get_default_file_store() + + for workspace_file in result_event.files: + if workspace_file.kind != "file" or not workspace_file.file_id: + continue + + try: + # Download file from Code Interpreter + file_content = client.download_file(workspace_file.file_id) + + # Determine MIME type from file extension + filename = workspace_file.path.split("/")[-1] + mime_type, _ = mimetypes.guess_type(filename) + # Default to binary if we can't determine the type + mime_type = mime_type or "application/octet-stream" + + # Save to Onyx file store + onyx_file_id = file_store.save_file( + content=BytesIO(file_content), + display_name=filename, + file_origin=FileOrigin.CHAT_UPLOAD, + file_type=mime_type, + ) - # Truncate output for LLM consumption - truncated_stdout = _truncate_output( - full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout" - ) - truncated_stderr = _truncate_output( - full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr" - ) + generated_files.append( + PythonExecutionFile( + filename=filename, + file_link=build_full_frontend_file_url(onyx_file_id), + ) + ) + generated_file_ids.append(onyx_file_id) - # Handle generated files - generated_files: list[PythonExecutionFile] = [] - generated_file_ids: list[str] = [] - file_ids_to_cleanup: list[str] = [] - file_store = get_default_file_store() + # Mark for cleanup + file_ids_to_cleanup.append(workspace_file.file_id) - for workspace_file in result_event.files: - if workspace_file.kind != "file" or not workspace_file.file_id: - continue + except Exception as e: + logger.error( + 
f"Failed to handle generated file " + f"{workspace_file.path}: {e}" + ) - try: - # Download file from Code Interpreter - file_content = client.download_file(workspace_file.file_id) - - # Determine MIME type from file extension - filename = workspace_file.path.split("/")[-1] - mime_type, _ = mimetypes.guess_type(filename) - # Default to binary if we can't determine the type - mime_type = mime_type or "application/octet-stream" - - # Save to Onyx file store - onyx_file_id = file_store.save_file( - content=BytesIO(file_content), - display_name=filename, - file_origin=FileOrigin.CHAT_UPLOAD, - file_type=mime_type, - ) + # Cleanup Code Interpreter files (generated files) + for ci_file_id in file_ids_to_cleanup: + try: + client.delete_file(ci_file_id) + except Exception as e: + logger.error( + f"Failed to delete Code Interpreter generated " + f"file {ci_file_id}: {e}" + ) + + # Cleanup staged input files + for file_mapping in files_to_stage: + try: + client.delete_file(file_mapping["file_id"]) + except Exception as e: + logger.error( + f"Failed to delete Code Interpreter staged " + f"file {file_mapping['file_id']}: {e}" + ) - generated_files.append( - PythonExecutionFile( - filename=filename, - file_link=build_full_frontend_file_url(onyx_file_id), + # Emit file_ids once files are processed + if generated_file_ids: + self.emitter.emit( + Packet( + placement=placement, + obj=PythonToolDelta(file_ids=generated_file_ids), ) ) - generated_file_ids.append(onyx_file_id) - # Mark for cleanup - file_ids_to_cleanup.append(workspace_file.file_id) + # Build result + result = LlmPythonExecutionResult( + stdout=truncated_stdout, + stderr=truncated_stderr, + exit_code=result_event.exit_code, + timed_out=result_event.timed_out, + generated_files=generated_files, + error=(None if result_event.exit_code == 0 else truncated_stderr), + ) - except Exception as e: - logger.error( - f"Failed to handle generated file {workspace_file.path}: {e}" - ) + # Serialize result for LLM + adapter = 
TypeAdapter(LlmPythonExecutionResult) + llm_response = adapter.dump_json(result).decode() - # Cleanup Code Interpreter files (generated files) - for ci_file_id in file_ids_to_cleanup: - try: - client.delete_file(ci_file_id) - except Exception as e: - logger.error( - f"Failed to delete Code Interpreter generated file {ci_file_id}: {e}" - ) + return ToolResponse( + rich_response=PythonToolRichResponse( + generated_files=generated_files, + ), + llm_facing_response=llm_response, + ) - # Cleanup staged input files - for file_mapping in files_to_stage: - try: - client.delete_file(file_mapping["file_id"]) - except Exception as e: - logger.error( - f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}" - ) + except Exception as e: + logger.error(f"Python execution failed: {e}") + error_msg = str(e) - # Emit file_ids once files are processed - if generated_file_ids: + # Emit error delta self.emitter.emit( Packet( placement=placement, - obj=PythonToolDelta(file_ids=generated_file_ids), + obj=PythonToolDelta( + stdout="", + stderr=error_msg, + file_ids=[], + ), ) ) - # Build result - result = LlmPythonExecutionResult( - stdout=truncated_stdout, - stderr=truncated_stderr, - exit_code=result_event.exit_code, - timed_out=result_event.timed_out, - generated_files=generated_files, - error=None if result_event.exit_code == 0 else truncated_stderr, - ) - - # Serialize result for LLM - adapter = TypeAdapter(LlmPythonExecutionResult) - llm_response = adapter.dump_json(result).decode() - - return ToolResponse( - rich_response=None, # No rich response needed for Python tool - llm_facing_response=llm_response, - ) - - except Exception as e: - logger.error(f"Python execution failed: {e}") - error_msg = str(e) - - # Emit error delta - self.emitter.emit( - Packet( - placement=placement, - obj=PythonToolDelta( - stdout="", - stderr=error_msg, - file_ids=[], - ), + # Return error result + result = LlmPythonExecutionResult( + stdout="", + stderr=error_msg, + 
exit_code=-1, + timed_out=False, + generated_files=[], + error=error_msg, ) - ) - # Return error result - result = LlmPythonExecutionResult( - stdout="", - stderr=error_msg, - exit_code=-1, - timed_out=False, - generated_files=[], - error=error_msg, - ) + adapter = TypeAdapter(LlmPythonExecutionResult) + llm_response = adapter.dump_json(result).decode() - adapter = TypeAdapter(LlmPythonExecutionResult) - llm_response = adapter.dump_json(result).decode() + return ToolResponse( + rich_response=None, + llm_facing_response=llm_response, + ) - return ToolResponse( - rich_response=None, - llm_facing_response=llm_response, - ) + @classmethod + @override + def should_emit_argument_deltas(cls) -> bool: + return True diff --git a/backend/onyx/tools/tool_implementations/web_search/web_search_tool.py b/backend/onyx/tools/tool_implementations/web_search/web_search_tool.py index 6c1adfc1923..cb06368d253 100644 --- a/backend/onyx/tools/tool_implementations/web_search/web_search_tool.py +++ b/backend/onyx/tools/tool_implementations/web_search/web_search_tool.py @@ -1,6 +1,5 @@ import json from typing import Any -from typing import cast from sqlalchemy.orm import Session from typing_extensions import override @@ -57,6 +56,30 @@ def _sanitize_query(query: str) -> str: return " ".join(sanitized.split()) +def _normalize_queries_input(raw: Any) -> list[str]: + """Coerce LLM output to a list of sanitized query strings. + + Accepts a bare string or a list (possibly with non-string elements). + Sanitizes each query (strip control chars, normalize whitespace) and + drops empty or whitespace-only entries. 
+ """ + if isinstance(raw, str): + raw = raw.strip() + if not raw: + return [] + raw = [raw] + elif not isinstance(raw, list): + return [] + result: list[str] = [] + for q in raw: + if q is None: + continue + sanitized = _sanitize_query(str(q)) + if sanitized: + result.append(sanitized) + return result + + class WebSearchTool(Tool[WebSearchToolOverrideKwargs]): NAME = "web_search" DESCRIPTION = "Search the web for information." @@ -189,13 +212,7 @@ def run( f'like: {{"queries": ["your search query here"]}}' ), ) - raw_queries = cast(list[str], llm_kwargs[QUERIES_FIELD]) - - # Normalize queries: - # - remove control characters (null bytes, etc.) that LLMs sometimes produce - # - collapse whitespace and strip - # - drop empty/whitespace-only queries - queries = [sanitized for q in raw_queries if (sanitized := _sanitize_query(q))] + queries = _normalize_queries_input(llm_kwargs[QUERIES_FIELD]) if not queries: raise ToolCallException( message=( diff --git a/backend/onyx/utils/encryption.py b/backend/onyx/utils/encryption.py index 299b2326b82..d7385528207 100644 --- a/backend/onyx/utils/encryption.py +++ b/backend/onyx/utils/encryption.py @@ -11,16 +11,20 @@ # IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation -def _encrypt_string(input_str: str) -> bytes: +def _encrypt_string(input_str: str, key: str | None = None) -> bytes: # noqa: ARG001 if ENCRYPTION_KEY_SECRET: logger.warning("MIT version of Onyx does not support encryption of secrets.") + elif key is not None: + logger.debug("MIT encrypt called with explicit key — key ignored.") return input_str.encode() # IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation -def _decrypt_bytes(input_bytes: bytes) -> str: - # No need to double warn. 
If you wish to learn more about encryption features - # refer to the Onyx EE code +def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str: # noqa: ARG001 + if ENCRYPTION_KEY_SECRET: + logger.warning("MIT version of Onyx does not support decryption of secrets.") + elif key is not None: + logger.debug("MIT decrypt called with explicit key — key ignored.") return input_bytes.decode() @@ -86,15 +90,15 @@ def _mask_list(items: list[Any]) -> list[Any]: return masked -def encrypt_string_to_bytes(intput_str: str) -> bytes: +def encrypt_string_to_bytes(input_str: str, key: str | None = None) -> bytes: versioned_encryption_fn = fetch_versioned_implementation( "onyx.utils.encryption", "_encrypt_string" ) - return versioned_encryption_fn(intput_str) + return versioned_encryption_fn(input_str, key=key) -def decrypt_bytes_to_string(intput_bytes: bytes) -> str: +def decrypt_bytes_to_string(input_bytes: bytes, key: str | None = None) -> str: versioned_decryption_fn = fetch_versioned_implementation( "onyx.utils.encryption", "_decrypt_bytes" ) - return versioned_decryption_fn(intput_bytes) + return versioned_decryption_fn(input_bytes, key=key) diff --git a/backend/onyx/utils/jsonriver/__init__.py b/backend/onyx/utils/jsonriver/__init__.py new file mode 100644 index 00000000000..02acb52816f --- /dev/null +++ b/backend/onyx/utils/jsonriver/__init__.py @@ -0,0 +1,17 @@ +""" +jsonriver - A streaming JSON parser for Python + +Parse JSON incrementally as it streams in, e.g. from a network request or a language model. +Gives you a sequence of increasingly complete values.
+ +Copyright (c) 2023 Google LLC (original TypeScript implementation) +Copyright (c) 2024 jsonriver-python contributors (Python port) +SPDX-License-Identifier: BSD-3-Clause +""" + +from .parse import _Parser as Parser +from .parse import JsonObject +from .parse import JsonValue + +__all__ = ["Parser", "JsonValue", "JsonObject"] +__version__ = "0.0.1" diff --git a/backend/onyx/utils/jsonriver/parse.py b/backend/onyx/utils/jsonriver/parse.py new file mode 100644 index 00000000000..ee65e79c69a --- /dev/null +++ b/backend/onyx/utils/jsonriver/parse.py @@ -0,0 +1,427 @@ +""" +JSON parser for streaming incremental parsing + +Copyright (c) 2023 Google LLC (original TypeScript implementation) +Copyright (c) 2024 jsonriver-python contributors (Python port) +SPDX-License-Identifier: BSD-3-Clause +""" + +from __future__ import annotations + +import copy +from enum import IntEnum +from typing import cast +from typing import Union + +from .tokenize import _Input +from .tokenize import json_token_type_to_string +from .tokenize import JsonTokenType +from .tokenize import Tokenizer + + +# Type definitions for JSON values +JsonValue = Union[None, bool, float, str, list["JsonValue"], dict[str, "JsonValue"]] +JsonObject = dict[str, JsonValue] + + +class _StateEnum(IntEnum): + """Parser state machine states""" + + Initial = 0 + InString = 1 + InArray = 2 + InObjectExpectingKey = 3 + InObjectExpectingValue = 4 + + +class _State: + """Base class for parser states""" + + type: _StateEnum + value: JsonValue | tuple[str, JsonObject] | None + + +class _InitialState(_State): + """Initial state before any parsing""" + + def __init__(self) -> None: + self.type = _StateEnum.Initial + self.value = None + + +class _InStringState(_State): + """State while parsing a string""" + + def __init__(self) -> None: + self.type = _StateEnum.InString + self.value = "" + + +class _InArrayState(_State): + """State while parsing an array""" + + def __init__(self) -> None: + self.type = _StateEnum.InArray + 
self.value: list[JsonValue] = [] + + +class _InObjectExpectingKeyState(_State): + """State while parsing an object, expecting a key""" + + def __init__(self) -> None: + self.type = _StateEnum.InObjectExpectingKey + self.value: JsonObject = {} + + +class _InObjectExpectingValueState(_State): + """State while parsing an object, expecting a value""" + + def __init__(self, key: str, obj: JsonObject) -> None: + self.type = _StateEnum.InObjectExpectingValue + self.value = (key, obj) + + +# Sentinel value to distinguish "not set" from "set to None/null" +class _Unset: + pass + + +_UNSET = _Unset() + + +class _Parser: + """ + Incremental JSON parser + + Feed chunks of JSON text via feed() and get back progressively + more complete JSON values. + """ + + def __init__(self) -> None: + self._state_stack: list[_State] = [_InitialState()] + self._toplevel_value: JsonValue | _Unset = _UNSET + self._input = _Input() + self.tokenizer = Tokenizer(self._input, self) + self._finished = False + self._progressed = False + self._prev_snapshot: JsonValue | _Unset = _UNSET + + def feed(self, chunk: str) -> list[JsonValue]: + """ + Feed a chunk of JSON text and return deltas from the previous state. + + Each element in the returned list represents what changed since the + last yielded value. For dicts, only changed/new keys are included, + with string values containing only the newly appended characters. 
+ """ + if self._finished: + return [] + + self._input.feed(chunk) + return self._collect_deltas() + + @staticmethod + def _compute_delta(prev: JsonValue | None, current: JsonValue) -> JsonValue | None: + if prev is None: + return current + + if isinstance(current, dict) and isinstance(prev, dict): + result: JsonObject = {} + for key in current: + cur_val = current[key] + prev_val = prev.get(key) + if key not in prev: + result[key] = cur_val + elif isinstance(cur_val, str) and isinstance(prev_val, str): + if cur_val != prev_val: + result[key] = cur_val[len(prev_val) :] + elif isinstance(cur_val, list) and isinstance(prev_val, list): + if cur_val != prev_val: + new_items = cur_val[len(prev_val) :] + # check if the last existing element was updated + if ( + prev_val + and len(cur_val) >= len(prev_val) + and cur_val[len(prev_val) - 1] != prev_val[-1] + ): + result[key] = [cur_val[len(prev_val) - 1]] + new_items + elif new_items: + result[key] = new_items + elif cur_val != prev_val: + result[key] = cur_val + return result if result else None + + if isinstance(current, str) and isinstance(prev, str): + delta = current[len(prev) :] + return delta if delta else None + + if isinstance(current, list) and isinstance(prev, list): + if current != prev: + new_items = current[len(prev) :] + if ( + prev + and len(current) >= len(prev) + and current[len(prev) - 1] != prev[-1] + ): + return [current[len(prev) - 1]] + new_items + return new_items if new_items else None + return None + + if current != prev: + return current + return None + + def finish(self) -> list[JsonValue]: + """Signal that no more chunks will be fed. Validates trailing content. + + Returns any final deltas produced by flushing pending tokens (e.g. + numbers, which have no terminator and wait for more input). + """ + self._input.mark_complete() + # Pump once more so the tokenizer can emit tokens that were waiting + # for more input (e.g. numbers need buffer_complete to finalize). 
+ results = self._collect_deltas() + self._input.expect_end_of_content() + return results + + def _collect_deltas(self) -> list[JsonValue]: + """Run one pump cycle and return any deltas produced.""" + results: list[JsonValue] = [] + while True: + self._progressed = False + self.tokenizer.pump() + + if self._progressed: + if self._toplevel_value is _UNSET: + raise RuntimeError( + "Internal error: toplevel_value should not be unset " + "after progressing" + ) + current = copy.deepcopy(cast(JsonValue, self._toplevel_value)) + if isinstance(self._prev_snapshot, _Unset): + results.append(current) + else: + delta = self._compute_delta(self._prev_snapshot, current) + if delta is not None: + results.append(delta) + self._prev_snapshot = current + else: + if not self._state_stack: + self._finished = True + break + return results + + # TokenHandler protocol implementation + + def handle_null(self) -> None: + """Handle null token""" + self._handle_value_token(JsonTokenType.Null, None) + + def handle_boolean(self, value: bool) -> None: + """Handle boolean token""" + self._handle_value_token(JsonTokenType.Boolean, value) + + def handle_number(self, value: float) -> None: + """Handle number token""" + self._handle_value_token(JsonTokenType.Number, value) + + def handle_string_start(self) -> None: + """Handle string start token""" + state = self._current_state() + if not self._progressed and state.type != _StateEnum.InObjectExpectingKey: + self._progressed = True + + if state.type == _StateEnum.Initial: + self._state_stack.pop() + self._toplevel_value = self._progress_value(JsonTokenType.StringStart, None) + + elif state.type == _StateEnum.InArray: + v = self._progress_value(JsonTokenType.StringStart, None) + arr = cast(list[JsonValue], state.value) + arr.append(v) + + elif state.type == _StateEnum.InObjectExpectingKey: + self._state_stack.append(_InStringState()) + + elif state.type == _StateEnum.InObjectExpectingValue: + key, obj = cast(tuple[str, JsonObject], state.value) + sv 
= self._progress_value(JsonTokenType.StringStart, None) + obj[key] = sv + + elif state.type == _StateEnum.InString: + raise ValueError( + f"Unexpected {json_token_type_to_string(JsonTokenType.StringStart)} " + f"token in the middle of string" + ) + + def handle_string_middle(self, value: str) -> None: + """Handle string middle token""" + state = self._current_state() + + if not self._progressed: + if len(self._state_stack) >= 2: + prev = self._state_stack[-2] + if prev.type != _StateEnum.InObjectExpectingKey: + self._progressed = True + else: + self._progressed = True + + if state.type != _StateEnum.InString: + raise ValueError( + f"Unexpected {json_token_type_to_string(JsonTokenType.StringMiddle)} " + f"token when not in string" + ) + + assert isinstance(state.value, str) + state.value += value + + parent_state = self._state_stack[-2] if len(self._state_stack) >= 2 else None + self._update_string_parent(state.value, parent_state) + + def handle_string_end(self) -> None: + """Handle string end token""" + state = self._current_state() + + if state.type != _StateEnum.InString: + raise ValueError( + f"Unexpected {json_token_type_to_string(JsonTokenType.StringEnd)} " + f"token when not in string" + ) + + self._state_stack.pop() + parent_state = self._state_stack[-1] if self._state_stack else None + assert isinstance(state.value, str) + self._update_string_parent(state.value, parent_state) + + def handle_array_start(self) -> None: + """Handle array start token""" + self._handle_value_token(JsonTokenType.ArrayStart, None) + + def handle_array_end(self) -> None: + """Handle array end token""" + state = self._current_state() + if state.type != _StateEnum.InArray: + raise ValueError( + f"Unexpected {json_token_type_to_string(JsonTokenType.ArrayEnd)} token" + ) + self._state_stack.pop() + + def handle_object_start(self) -> None: + """Handle object start token""" + self._handle_value_token(JsonTokenType.ObjectStart, None) + + def handle_object_end(self) -> None: + """Handle 
object end token""" + state = self._current_state() + + if state.type in ( + _StateEnum.InObjectExpectingKey, + _StateEnum.InObjectExpectingValue, + ): + self._state_stack.pop() + else: + raise ValueError( + f"Unexpected {json_token_type_to_string(JsonTokenType.ObjectEnd)} token" + ) + + # Private helper methods + + def _current_state(self) -> _State: + """Get current parser state""" + if not self._state_stack: + raise ValueError("Unexpected trailing input") + return self._state_stack[-1] + + def _handle_value_token(self, token_type: JsonTokenType, value: JsonValue) -> None: + """Handle a complete value token""" + state = self._current_state() + + if not self._progressed: + self._progressed = True + + if state.type == _StateEnum.Initial: + self._state_stack.pop() + self._toplevel_value = self._progress_value(token_type, value) + + elif state.type == _StateEnum.InArray: + v = self._progress_value(token_type, value) + arr = cast(list[JsonValue], state.value) + arr.append(v) + + elif state.type == _StateEnum.InObjectExpectingValue: + key, obj = cast(tuple[str, JsonObject], state.value) + if token_type != JsonTokenType.StringStart: + self._state_stack.pop() + new_state = _InObjectExpectingKeyState() + new_state.value = obj + self._state_stack.append(new_state) + + v = self._progress_value(token_type, value) + obj[key] = v + + elif state.type == _StateEnum.InString: + raise ValueError( + f"Unexpected {json_token_type_to_string(token_type)} " + f"token in the middle of string" + ) + + elif state.type == _StateEnum.InObjectExpectingKey: + raise ValueError( + f"Unexpected {json_token_type_to_string(token_type)} " + f"token in the middle of object expecting key" + ) + + def _update_string_parent(self, updated: str, parent_state: _State | None) -> None: + """Update parent container with updated string value""" + if parent_state is None: + self._toplevel_value = updated + + elif parent_state.type == _StateEnum.InArray: + arr = cast(list[JsonValue], parent_state.value) + 
arr[-1] = updated + + elif parent_state.type == _StateEnum.InObjectExpectingValue: + key, obj = cast(tuple[str, JsonObject], parent_state.value) + obj[key] = updated + if self._state_stack and self._state_stack[-1] == parent_state: + self._state_stack.pop() + new_state = _InObjectExpectingKeyState() + new_state.value = obj + self._state_stack.append(new_state) + + elif parent_state.type == _StateEnum.InObjectExpectingKey: + if self._state_stack and self._state_stack[-1] == parent_state: + self._state_stack.pop() + obj = cast(JsonObject, parent_state.value) + self._state_stack.append(_InObjectExpectingValueState(updated, obj)) + + def _progress_value(self, token_type: JsonTokenType, value: JsonValue) -> JsonValue: + """Create initial value for a token and push appropriate state""" + if token_type == JsonTokenType.Null: + return None + + elif token_type == JsonTokenType.Boolean: + return value + + elif token_type == JsonTokenType.Number: + return value + + elif token_type == JsonTokenType.StringStart: + string_state = _InStringState() + self._state_stack.append(string_state) + return "" + + elif token_type == JsonTokenType.ArrayStart: + array_state = _InArrayState() + self._state_stack.append(array_state) + return array_state.value + + elif token_type == JsonTokenType.ObjectStart: + object_state = _InObjectExpectingKeyState() + self._state_stack.append(object_state) + return object_state.value + + else: + raise ValueError( + f"Unexpected token type: {json_token_type_to_string(token_type)}" + ) diff --git a/backend/onyx/utils/jsonriver/tokenize.py b/backend/onyx/utils/jsonriver/tokenize.py new file mode 100644 index 00000000000..fec2fe7ddde --- /dev/null +++ b/backend/onyx/utils/jsonriver/tokenize.py @@ -0,0 +1,514 @@ +""" +JSON tokenizer for streaming incremental parsing + +Copyright (c) 2023 Google LLC (original TypeScript implementation) +Copyright (c) 2024 jsonriver-python contributors (Python port) +SPDX-License-Identifier: BSD-3-Clause +""" + +from __future__ 
import annotations + +import re +from enum import IntEnum +from typing import Protocol + + +class TokenHandler(Protocol): + """Protocol for handling JSON tokens""" + + def handle_null(self) -> None: ... + def handle_boolean(self, value: bool) -> None: ... + def handle_number(self, value: float) -> None: ... + def handle_string_start(self) -> None: ... + def handle_string_middle(self, value: str) -> None: ... + def handle_string_end(self) -> None: ... + def handle_array_start(self) -> None: ... + def handle_array_end(self) -> None: ... + def handle_object_start(self) -> None: ... + def handle_object_end(self) -> None: ... + + +class JsonTokenType(IntEnum): + """Types of JSON tokens""" + + Null = 0 + Boolean = 1 + Number = 2 + StringStart = 3 + StringMiddle = 4 + StringEnd = 5 + ArrayStart = 6 + ArrayEnd = 7 + ObjectStart = 8 + ObjectEnd = 9 + + +def json_token_type_to_string(token_type: JsonTokenType) -> str: + """Convert token type to readable string""" + names = { + JsonTokenType.Null: "null", + JsonTokenType.Boolean: "boolean", + JsonTokenType.Number: "number", + JsonTokenType.StringStart: "string start", + JsonTokenType.StringMiddle: "string middle", + JsonTokenType.StringEnd: "string end", + JsonTokenType.ArrayStart: "array start", + JsonTokenType.ArrayEnd: "array end", + JsonTokenType.ObjectStart: "object start", + JsonTokenType.ObjectEnd: "object end", + } + return names[token_type] + + +class _State(IntEnum): + """Internal tokenizer states""" + + ExpectingValue = 0 + InString = 1 + StartArray = 2 + AfterArrayValue = 3 + StartObject = 4 + AfterObjectKey = 5 + AfterObjectValue = 6 + BeforeObjectKey = 7 + + +# Regex for validating JSON numbers +_JSON_NUMBER_PATTERN = re.compile(r"^-?(0|[1-9]\d*)(\.\d+)?([eE][+-]?\d+)?$") + + +def _parse_json_number(s: str) -> float: + """Parse a JSON number string, validating format""" + if not _JSON_NUMBER_PATTERN.match(s): + raise ValueError("Invalid number") + return float(s) + + +class _Input: + """ + Input buffer for 
chunk-based JSON parsing + + Manages buffering of input chunks and provides methods for + consuming and inspecting the buffer. + """ + + def __init__(self) -> None: + self._buffer = "" + self._start_index = 0 + self.buffer_complete = False + + def feed(self, chunk: str) -> None: + """Add a chunk of data to the buffer""" + self._buffer += chunk + + def mark_complete(self) -> None: + """Signal that no more chunks will be fed""" + self.buffer_complete = True + + @property + def length(self) -> int: + """Number of characters remaining in buffer""" + return len(self._buffer) - self._start_index + + def advance(self, length: int) -> None: + """Advance the start position by length characters""" + self._start_index += length + + def peek(self, offset: int) -> str | None: + """Peek at character at offset, or None if not available""" + idx = self._start_index + offset + if idx < len(self._buffer): + return self._buffer[idx] + return None + + def peek_char_code(self, offset: int) -> int: + """Get character code at offset""" + return ord(self._buffer[self._start_index + offset]) + + def slice(self, start: int, end: int) -> str: + """Slice buffer from start to end (relative to current position)""" + return self._buffer[self._start_index + start : self._start_index + end] + + def commit(self) -> None: + """Commit consumed content, removing it from buffer""" + if self._start_index > 0: + self._buffer = self._buffer[self._start_index :] + self._start_index = 0 + + def remaining(self) -> str: + """Get all remaining content in buffer""" + return self._buffer[self._start_index :] + + def expect_end_of_content(self) -> None: + """Verify no non-whitespace content remains""" + self.commit() + self.skip_past_whitespace() + if self.length != 0: + raise ValueError(f"Unexpected trailing content {self.remaining()!r}") + + def skip_past_whitespace(self) -> None: + """Skip whitespace characters""" + i = self._start_index + while i < len(self._buffer): + c = ord(self._buffer[i]) + if c in (32, 
9, 10, 13): # space, tab, \n, \r + i += 1 + else: + break + self._start_index = i + + def try_to_take_prefix(self, prefix: str) -> bool: + """Try to consume prefix from buffer, return True if successful""" + if self._buffer.startswith(prefix, self._start_index): + self._start_index += len(prefix) + return True + return False + + def try_to_take(self, length: int) -> str | None: + """Try to take length characters, or None if not enough available""" + if self.length < length: + return None + result = self._buffer[self._start_index : self._start_index + length] + self._start_index += length + return result + + def try_to_take_char_code(self) -> int | None: + """Try to take a single character as char code, or None if buffer empty""" + if self.length == 0: + return None + code = ord(self._buffer[self._start_index]) + self._start_index += 1 + return code + + def take_until_quote_or_backslash(self) -> tuple[str, bool]: + """ + Consume input up to first quote or backslash + + Returns tuple of (consumed_content, pattern_found) + """ + buf = self._buffer + i = self._start_index + while i < len(buf): + c = ord(buf[i]) + if c <= 0x1F: + raise ValueError("Unescaped control character in string") + if c == 34 or c == 92: # " or \ + result = buf[self._start_index : i] + self._start_index = i + return (result, True) + i += 1 + + result = buf[self._start_index :] + self._start_index = len(buf) + return (result, False) + + +class Tokenizer: + """ + Tokenizer for chunk-based JSON parsing + + Processes chunks fed into its input buffer and calls handler methods + as JSON tokens are recognized. 
+ """ + + def __init__(self, input: _Input, handler: TokenHandler) -> None: + self.input = input + self._handler = handler + self._stack: list[_State] = [_State.ExpectingValue] + self._emitted_tokens = 0 + + def is_done(self) -> bool: + """Check if tokenization is complete""" + return len(self._stack) == 0 and self.input.length == 0 + + def pump(self) -> None: + """Process all available tokens in the buffer""" + while True: + before = self._emitted_tokens + self._tokenize_more() + if self._emitted_tokens == before: + self.input.commit() + return + + def _tokenize_more(self) -> None: + """Process one step of tokenization based on current state""" + if not self._stack: + return + + state = self._stack[-1] + + if state == _State.ExpectingValue: + self._tokenize_value() + elif state == _State.InString: + self._tokenize_string() + elif state == _State.StartArray: + self._tokenize_array_start() + elif state == _State.AfterArrayValue: + self._tokenize_after_array_value() + elif state == _State.StartObject: + self._tokenize_object_start() + elif state == _State.AfterObjectKey: + self._tokenize_after_object_key() + elif state == _State.AfterObjectValue: + self._tokenize_after_object_value() + elif state == _State.BeforeObjectKey: + self._tokenize_before_object_key() + + def _tokenize_value(self) -> None: + """Tokenize a JSON value""" + self.input.skip_past_whitespace() + + if self.input.try_to_take_prefix("null"): + self._handler.handle_null() + self._emitted_tokens += 1 + self._stack.pop() + return + + if self.input.try_to_take_prefix("true"): + self._handler.handle_boolean(True) + self._emitted_tokens += 1 + self._stack.pop() + return + + if self.input.try_to_take_prefix("false"): + self._handler.handle_boolean(False) + self._emitted_tokens += 1 + self._stack.pop() + return + + if self.input.length > 0: + ch = self.input.peek_char_code(0) + if (48 <= ch <= 57) or ch == 45: # 0-9 or - + # Scan for end of number + i = 0 + while i < self.input.length: + c = 
self.input.peek_char_code(i) + if (48 <= c <= 57) or c in (45, 43, 46, 101, 69): # 0-9 - + . e E + i += 1 + else: + break + + if i == self.input.length and not self.input.buffer_complete: + # Need more input (numbers have no terminator) + return + + number_chars = self.input.slice(0, i) + self.input.advance(i) + number = _parse_json_number(number_chars) + self._handler.handle_number(number) + self._emitted_tokens += 1 + self._stack.pop() + return + + if self.input.try_to_take_prefix('"'): + self._stack.pop() + self._stack.append(_State.InString) + self._handler.handle_string_start() + self._emitted_tokens += 1 + self._tokenize_string() + return + + if self.input.try_to_take_prefix("["): + self._stack.pop() + self._stack.append(_State.StartArray) + self._handler.handle_array_start() + self._emitted_tokens += 1 + self._tokenize_array_start() + return + + if self.input.try_to_take_prefix("{"): + self._stack.pop() + self._stack.append(_State.StartObject) + self._handler.handle_object_start() + self._emitted_tokens += 1 + self._tokenize_object_start() + return + + def _tokenize_string(self) -> None: + """Tokenize string content""" + while True: + chunk, interrupted = self.input.take_until_quote_or_backslash() + if chunk: + self._handler.handle_string_middle(chunk) + self._emitted_tokens += 1 + elif not interrupted: + return + + if interrupted: + if self.input.length == 0: + return + + next_char = self.input.peek(0) + if next_char == '"': + self.input.advance(1) + self._handler.handle_string_end() + self._emitted_tokens += 1 + self._stack.pop() + return + + # Handle escape sequences + next_char2 = self.input.peek(1) + if next_char2 is None: + return + + value: str + if next_char2 == "u": + # Unicode escape: need 4 hex digits + if self.input.length < 6: + return + + code = 0 + for j in range(2, 6): + c = self.input.peek_char_code(j) + if 48 <= c <= 57: # 0-9 + digit = c - 48 + elif 65 <= c <= 70: # A-F + digit = c - 55 + elif 97 <= c <= 102: # a-f + digit = c - 87 + else: 
+ raise ValueError("Bad Unicode escape in JSON") + code = (code << 4) | digit + + self.input.advance(6) + self._handler.handle_string_middle(chr(code)) + self._emitted_tokens += 1 + continue + + elif next_char2 == "n": + value = "\n" + elif next_char2 == "r": + value = "\r" + elif next_char2 == "t": + value = "\t" + elif next_char2 == "b": + value = "\b" + elif next_char2 == "f": + value = "\f" + elif next_char2 == "\\": + value = "\\" + elif next_char2 == "/": + value = "/" + elif next_char2 == '"': + value = '"' + else: + raise ValueError("Bad escape in string") + + self.input.advance(2) + self._handler.handle_string_middle(value) + self._emitted_tokens += 1 + + def _tokenize_array_start(self) -> None: + """Tokenize start of array (check for empty or first element)""" + self.input.skip_past_whitespace() + if self.input.length == 0: + return + + if self.input.try_to_take_prefix("]"): + self._handler.handle_array_end() + self._emitted_tokens += 1 + self._stack.pop() + return + + self._stack.pop() + self._stack.append(_State.AfterArrayValue) + self._stack.append(_State.ExpectingValue) + self._tokenize_value() + + def _tokenize_after_array_value(self) -> None: + """Tokenize after an array value (expect , or ])""" + self.input.skip_past_whitespace() + next_char = self.input.try_to_take_char_code() + + if next_char is None: + return + elif next_char == 0x5D: # ] + self._handler.handle_array_end() + self._emitted_tokens += 1 + self._stack.pop() + return + elif next_char == 0x2C: # , + self._stack.append(_State.ExpectingValue) + self._tokenize_value() + return + else: + raise ValueError(f"Expected , or ], got {chr(next_char)!r}") + + def _tokenize_object_start(self) -> None: + """Tokenize start of object (check for empty or first key)""" + self.input.skip_past_whitespace() + next_char = self.input.try_to_take_char_code() + + if next_char is None: + return + elif next_char == 0x7D: # } + self._handler.handle_object_end() + self._emitted_tokens += 1 + self._stack.pop() + 
return + elif next_char == 0x22: # " + self._stack.pop() + self._stack.append(_State.AfterObjectKey) + self._stack.append(_State.InString) + self._handler.handle_string_start() + self._emitted_tokens += 1 + self._tokenize_string() + return + else: + raise ValueError(f"Expected start of object key, got {chr(next_char)!r}") + + def _tokenize_after_object_key(self) -> None: + """Tokenize after object key (expect :)""" + self.input.skip_past_whitespace() + next_char = self.input.try_to_take_char_code() + + if next_char is None: + return + elif next_char == 0x3A: # : + self._stack.pop() + self._stack.append(_State.AfterObjectValue) + self._stack.append(_State.ExpectingValue) + self._tokenize_value() + return + else: + raise ValueError(f"Expected colon after object key, got {chr(next_char)!r}") + + def _tokenize_after_object_value(self) -> None: + """Tokenize after object value (expect , or })""" + self.input.skip_past_whitespace() + next_char = self.input.try_to_take_char_code() + + if next_char is None: + return + elif next_char == 0x7D: # } + self._handler.handle_object_end() + self._emitted_tokens += 1 + self._stack.pop() + return + elif next_char == 0x2C: # , + self._stack.pop() + self._stack.append(_State.BeforeObjectKey) + self._tokenize_before_object_key() + return + else: + raise ValueError( + f"Expected , or }} after object value, got {chr(next_char)!r}" + ) + + def _tokenize_before_object_key(self) -> None: + """Tokenize before object key (after comma)""" + self.input.skip_past_whitespace() + next_char = self.input.try_to_take_char_code() + + if next_char is None: + return + elif next_char == 0x22: # " + self._stack.pop() + self._stack.append(_State.AfterObjectKey) + self._stack.append(_State.InString) + self._handler.handle_string_start() + self._emitted_tokens += 1 + self._tokenize_string() + return + else: + raise ValueError(f"Expected start of object key, got {chr(next_char)!r}") diff --git a/backend/onyx/indexing/postgres_sanitization.py 
b/backend/onyx/utils/postgres_sanitization.py similarity index 58% rename from backend/onyx/indexing/postgres_sanitization.py rename to backend/onyx/utils/postgres_sanitization.py index 098e5aca05d..9ce6de11cad 100644 --- a/backend/onyx/indexing/postgres_sanitization.py +++ b/backend/onyx/utils/postgres_sanitization.py @@ -1,30 +1,49 @@ +import re from typing import Any from onyx.access.models import ExternalAccess from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode +from onyx.utils.logger import setup_logger +logger = setup_logger() -def _sanitize_string(value: str) -> str: - return value.replace("\x00", "") +_SURROGATE_RE = re.compile(r"[\ud800-\udfff]") -def _sanitize_json_like(value: Any) -> Any: +def sanitize_string(value: str) -> str: + """Strip characters that PostgreSQL text/JSONB columns cannot store. + + Removes: + - NUL bytes (\\x00) + - UTF-16 surrogates (\\ud800-\\udfff), which are invalid in UTF-8 + """ + sanitized = value.replace("\x00", "") + sanitized = _SURROGATE_RE.sub("", sanitized) + if value and not sanitized: + logger.warning( + "sanitize_string: all characters were removed from a non-empty string" + ) + return sanitized + + +def sanitize_json_like(value: Any) -> Any: + """Recursively sanitize all strings in a JSON-like structure (dict/list/tuple).""" if isinstance(value, str): - return _sanitize_string(value) + return sanitize_string(value) if isinstance(value, list): - return [_sanitize_json_like(item) for item in value] + return [sanitize_json_like(item) for item in value] if isinstance(value, tuple): - return tuple(_sanitize_json_like(item) for item in value) + return tuple(sanitize_json_like(item) for item in value) if isinstance(value, dict): sanitized: dict[Any, Any] = {} for key, nested_value in value.items(): - cleaned_key = _sanitize_string(key) if isinstance(key, str) else key - sanitized[cleaned_key] = _sanitize_json_like(nested_value) + 
cleaned_key = sanitize_string(key) if isinstance(key, str) else key + sanitized[cleaned_key] = sanitize_json_like(nested_value) return sanitized return value @@ -34,27 +53,27 @@ def _sanitize_expert_info(expert: BasicExpertInfo) -> BasicExpertInfo: return expert.model_copy( update={ "display_name": ( - _sanitize_string(expert.display_name) + sanitize_string(expert.display_name) if expert.display_name is not None else None ), "first_name": ( - _sanitize_string(expert.first_name) + sanitize_string(expert.first_name) if expert.first_name is not None else None ), "middle_initial": ( - _sanitize_string(expert.middle_initial) + sanitize_string(expert.middle_initial) if expert.middle_initial is not None else None ), "last_name": ( - _sanitize_string(expert.last_name) + sanitize_string(expert.last_name) if expert.last_name is not None else None ), "email": ( - _sanitize_string(expert.email) if expert.email is not None else None + sanitize_string(expert.email) if expert.email is not None else None ), } ) @@ -63,10 +82,10 @@ def _sanitize_expert_info(expert: BasicExpertInfo) -> BasicExpertInfo: def _sanitize_external_access(external_access: ExternalAccess) -> ExternalAccess: return ExternalAccess( external_user_emails={ - _sanitize_string(email) for email in external_access.external_user_emails + sanitize_string(email) for email in external_access.external_user_emails }, external_user_group_ids={ - _sanitize_string(group_id) + sanitize_string(group_id) for group_id in external_access.external_user_group_ids }, is_public=external_access.is_public, @@ -76,26 +95,26 @@ def _sanitize_external_access(external_access: ExternalAccess) -> ExternalAccess def sanitize_document_for_postgres(document: Document) -> Document: cleaned_doc = document.model_copy(deep=True) - cleaned_doc.id = _sanitize_string(cleaned_doc.id) - cleaned_doc.semantic_identifier = _sanitize_string(cleaned_doc.semantic_identifier) + cleaned_doc.id = sanitize_string(cleaned_doc.id) + cleaned_doc.semantic_identifier 
= sanitize_string(cleaned_doc.semantic_identifier) if cleaned_doc.title is not None: - cleaned_doc.title = _sanitize_string(cleaned_doc.title) + cleaned_doc.title = sanitize_string(cleaned_doc.title) if cleaned_doc.parent_hierarchy_raw_node_id is not None: - cleaned_doc.parent_hierarchy_raw_node_id = _sanitize_string( + cleaned_doc.parent_hierarchy_raw_node_id = sanitize_string( cleaned_doc.parent_hierarchy_raw_node_id ) cleaned_doc.metadata = { - _sanitize_string(key): ( - [_sanitize_string(item) for item in value] + sanitize_string(key): ( + [sanitize_string(item) for item in value] if isinstance(value, list) - else _sanitize_string(value) + else sanitize_string(value) ) for key, value in cleaned_doc.metadata.items() } if cleaned_doc.doc_metadata is not None: - cleaned_doc.doc_metadata = _sanitize_json_like(cleaned_doc.doc_metadata) + cleaned_doc.doc_metadata = sanitize_json_like(cleaned_doc.doc_metadata) if cleaned_doc.primary_owners is not None: cleaned_doc.primary_owners = [ @@ -113,11 +132,11 @@ def sanitize_document_for_postgres(document: Document) -> Document: for section in cleaned_doc.sections: if section.link is not None: - section.link = _sanitize_string(section.link) + section.link = sanitize_string(section.link) if section.text is not None: - section.text = _sanitize_string(section.text) + section.text = sanitize_string(section.text) if section.image_file_id is not None: - section.image_file_id = _sanitize_string(section.image_file_id) + section.image_file_id = sanitize_string(section.image_file_id) return cleaned_doc @@ -129,12 +148,12 @@ def sanitize_documents_for_postgres(documents: list[Document]) -> list[Document] def sanitize_hierarchy_node_for_postgres(node: HierarchyNode) -> HierarchyNode: cleaned_node = node.model_copy(deep=True) - cleaned_node.raw_node_id = _sanitize_string(cleaned_node.raw_node_id) - cleaned_node.display_name = _sanitize_string(cleaned_node.display_name) + cleaned_node.raw_node_id = sanitize_string(cleaned_node.raw_node_id) 
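The `sanitize_string` helper this module applies everywhere can be exercised in isolation. A minimal standalone sketch of the same two-step strip (NUL bytes first, then lone UTF-16 surrogates), using only the stdlib `re` module — not the module's own definition:

```python
import re

# Lone surrogates are valid in a Python str but cannot be encoded as
# UTF-8, so PostgreSQL text/JSONB columns reject them, as they do \x00.
_SURROGATE_RE = re.compile(r"[\ud800-\udfff]")


def sanitize_string(value: str) -> str:
    # Drop NUL bytes, then strip any surrogate code points.
    return _SURROGATE_RE.sub("", value.replace("\x00", ""))
```

For example, `sanitize_string("a\x00b\ud800c")` yields `"abc"`, while ordinary non-ASCII text passes through unchanged.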
+ cleaned_node.display_name = sanitize_string(cleaned_node.display_name) if cleaned_node.raw_parent_id is not None: - cleaned_node.raw_parent_id = _sanitize_string(cleaned_node.raw_parent_id) + cleaned_node.raw_parent_id = sanitize_string(cleaned_node.raw_parent_id) if cleaned_node.link is not None: - cleaned_node.link = _sanitize_string(cleaned_node.link) + cleaned_node.link = sanitize_string(cleaned_node.link) if cleaned_node.external_access is not None: cleaned_node.external_access = _sanitize_external_access( diff --git a/backend/onyx/utils/pydantic_util.py b/backend/onyx/utils/pydantic_util.py new file mode 100644 index 00000000000..43a0d558cf8 --- /dev/null +++ b/backend/onyx/utils/pydantic_util.py @@ -0,0 +1,13 @@ +from typing import Any + +from pydantic import BaseModel + + +def shallow_model_dump(model_instance: BaseModel) -> dict[str, Any]: + """Like model_dump(), but returns references to field values instead of + deep copies. Use with model_construct() to avoid unnecessary memory + duplication when building subclass instances.""" + return { + field_name: getattr(model_instance, field_name) + for field_name in model_instance.__class__.model_fields + } diff --git a/backend/onyx/utils/sensitive.py b/backend/onyx/utils/sensitive.py index 8d2fe1ec5a1..a6100bc872d 100644 --- a/backend/onyx/utils/sensitive.py +++ b/backend/onyx/utils/sensitive.py @@ -128,6 +128,8 @@ def get_value( value = self._decrypt() if not apply_mask: + # Callers must not mutate the returned dict — doing so would + # desync the cache from the encrypted bytes and the DB. 
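The reworked `SensitiveValue.__eq__` in this file leans on Python's rich-comparison fallback: when a method returns `NotImplemented` (rather than raising), Python tries the reflected comparison, and if both sides decline, the `==` expression evaluates to `False`. A minimal illustration with a hypothetical `Secret` class, not the real `SensitiveValue`:

```python
class Secret:
    def __init__(self, token: str) -> None:
        self._token = token

    def __eq__(self, other: object) -> bool:
        # Returning NotImplemented lets Python fall back to the reflected
        # comparison; if that also declines, == quietly evaluates False
        # instead of raising, which attribute-tracking layers rely on.
        if not isinstance(other, Secret):
            return NotImplemented
        return self._token == other._token

    def __hash__(self) -> int:
        return hash(self._token)
```

Here `Secret("a") == Secret("a")` is `True`, while `Secret("a") == "a"` and `"a" == Secret("a")` are both `False` — no exception is raised in the mixed-type case.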
return value # Apply masking @@ -174,18 +176,20 @@ def __getitem__(self, key: Any) -> NoReturn: ) def __eq__(self, other: Any) -> bool: - """Prevent direct comparison which might expose value.""" - if isinstance(other, SensitiveValue): - # Compare encrypted bytes for equality check - return self._encrypted_bytes == other._encrypted_bytes - raise SensitiveAccessError( - "Cannot compare SensitiveValue with non-SensitiveValue. " - "Use .get_value(apply_mask=True/False) to access the value for comparison." - ) + """Compare SensitiveValues by their decrypted content.""" + # NOTE: if you attempt to compare a string/dict to a SensitiveValue, + # this comparison will return NotImplemented, which then evaluates to False. + # This is the convention and required for SQLAlchemy's attribute tracking. + if not isinstance(other, SensitiveValue): + return NotImplemented + return self._decrypt() == other._decrypt() def __hash__(self) -> int: - """Allow hashing based on encrypted bytes.""" - return hash(self._encrypted_bytes) + """Hash based on decrypted content.""" + value = self._decrypt() + if isinstance(value, dict): + return hash(json.dumps(value, sort_keys=True)) + return hash(value) # Prevent JSON serialization def __json__(self) -> Any: diff --git a/backend/onyx/utils/telemetry.py b/backend/onyx/utils/telemetry.py index 3bb5d2ec876..d437095065e 100644 --- a/backend/onyx/utils/telemetry.py +++ b/backend/onyx/utils/telemetry.py @@ -2,7 +2,6 @@ import threading import uuid from enum import Enum -from typing import cast import requests @@ -15,6 +14,7 @@ from onyx.db.models import User from onyx.key_value_store.factory import get_kv_store from onyx.key_value_store.interface import KvKeyNotFoundError +from onyx.key_value_store.interface import unwrap_str from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import ( fetch_versioned_implementation_with_fallback, @@ -25,6 +25,7 @@ logger = setup_logger() + _DANSWER_TELEMETRY_ENDPOINT = 
"https://telemetry.onyx.app/anonymous_telemetry" _CACHED_UUID: str | None = None _CACHED_INSTANCE_DOMAIN: str | None = None @@ -62,10 +63,10 @@ def get_or_generate_uuid() -> str: kv_store = get_kv_store() try: - _CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY)) + _CACHED_UUID = unwrap_str(kv_store.load(KV_CUSTOMER_UUID_KEY)) except KvKeyNotFoundError: _CACHED_UUID = str(uuid.uuid4()) - kv_store.store(KV_CUSTOMER_UUID_KEY, _CACHED_UUID, encrypt=True) + kv_store.store(KV_CUSTOMER_UUID_KEY, {"value": _CACHED_UUID}, encrypt=True) return _CACHED_UUID @@ -79,14 +80,16 @@ def _get_or_generate_instance_domain() -> str | None: # kv_store = get_kv_store() try: - _CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY)) + _CACHED_INSTANCE_DOMAIN = unwrap_str(kv_store.load(KV_INSTANCE_DOMAIN_KEY)) except KvKeyNotFoundError: with get_session_with_current_tenant() as db_session: first_user = db_session.query(User).first() if first_user: _CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1] kv_store.store( - KV_INSTANCE_DOMAIN_KEY, _CACHED_INSTANCE_DOMAIN, encrypt=True + KV_INSTANCE_DOMAIN_KEY, + {"value": _CACHED_INSTANCE_DOMAIN}, + encrypt=True, ) return _CACHED_INSTANCE_DOMAIN diff --git a/backend/onyx/utils/variable_functionality.py b/backend/onyx/utils/variable_functionality.py index c1d2ff7497b..f67210b363d 100644 --- a/backend/onyx/utils/variable_functionality.py +++ b/backend/onyx/utils/variable_functionality.py @@ -24,6 +24,9 @@ def __init__(self) -> None: def set_ee(self) -> None: self._is_ee = True + def unset_ee(self) -> None: + self._is_ee = False + def is_ee_version(self) -> bool: return self._is_ee diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 56069c8f064..8a1b00b3e61 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -8,37 +8,3 @@ dependencies = [ [tool.uv.sources] onyx = { workspace = true } - -[tool.mypy] -plugins = "sqlalchemy.ext.mypy.plugin" -mypy_path = "backend" -explicit_package_bases = 
true -disallow_untyped_defs = true -warn_unused_ignores = true -enable_error_code = ["possibly-undefined"] -strict_equality = true -# Patterns match paths whether mypy is run from backend/ (CI) or repo root (e.g. VS Code extension with target ./backend) -exclude = [ - "(?:^|/)generated/", - "(?:^|/)\\.venv/", - "(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/skills/", - "(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/templates/", -] - -[[tool.mypy.overrides]] -module = "alembic.versions.*" -disable_error_code = ["var-annotated"] - -[[tool.mypy.overrides]] -module = "alembic_tenants.versions.*" -disable_error_code = ["var-annotated"] - -[[tool.mypy.overrides]] -module = "generated.*" -follow_imports = "silent" -ignore_errors = true - -[[tool.mypy.overrides]] -module = "transformers.*" -follow_imports = "skip" -ignore_errors = true diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 799d2ac3bae..629139991a1 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -65,7 +65,7 @@ attrs==25.4.0 # jsonschema # referencing # zeep -authlib==1.6.6 +authlib==1.6.7 # via fastmcp babel==2.17.0 # via courlan @@ -109,9 +109,7 @@ brotli==1.2.0 bytecode==0.17.0 # via ddtrace cachetools==6.2.2 - # via - # google-auth - # py-key-value-aio + # via py-key-value-aio caio==0.9.25 # via aiofile celery==5.5.1 @@ -190,6 +188,7 @@ courlan==1.3.2 cryptography==46.0.5 # via # authlib + # google-auth # msal # msoffcrypto-tool # pdfminer-six @@ -230,9 +229,7 @@ distro==1.9.0 dnspython==2.8.0 # via email-validator docstring-parser==0.17.0 - # via - # cyclopts - # google-cloud-aiplatform + # via cyclopts docutils==0.22.3 # via rich-rst dropbox==12.0.2 @@ -297,26 +294,15 @@ gitdb==4.0.12 gitpython==3.1.45 # via braintrust google-api-core==2.28.1 - # via - # google-api-python-client - # google-cloud-aiplatform - # google-cloud-bigquery - # google-cloud-core - # google-cloud-resource-manager - # 
google-cloud-storage + # via google-api-python-client google-api-python-client==2.86.0 # via onyx -google-auth==2.43.0 +google-auth==2.48.0 # via # google-api-core # google-api-python-client # google-auth-httplib2 # google-auth-oauthlib - # google-cloud-aiplatform - # google-cloud-bigquery - # google-cloud-core - # google-cloud-resource-manager - # google-cloud-storage # google-genai # kubernetes google-auth-httplib2==0.1.0 @@ -325,51 +311,16 @@ google-auth-httplib2==0.1.0 # onyx google-auth-oauthlib==1.0.0 # via onyx -google-cloud-aiplatform==1.121.0 - # via onyx -google-cloud-bigquery==3.38.0 - # via google-cloud-aiplatform -google-cloud-core==2.5.0 - # via - # google-cloud-bigquery - # google-cloud-storage -google-cloud-resource-manager==1.15.0 - # via google-cloud-aiplatform -google-cloud-storage==2.19.0 - # via google-cloud-aiplatform -google-crc32c==1.7.1 - # via - # google-cloud-storage - # google-resumable-media google-genai==1.52.0 - # via - # google-cloud-aiplatform - # onyx -google-resumable-media==2.7.2 - # via - # google-cloud-bigquery - # google-cloud-storage + # via onyx googleapis-common-protos==1.72.0 # via # google-api-core - # grpc-google-iam-v1 - # grpcio-status # opentelemetry-exporter-otlp-proto-http greenlet==3.2.4 # via # playwright # sqlalchemy -grpc-google-iam-v1==0.14.3 - # via google-cloud-resource-manager -grpcio==1.76.0 - # via - # google-api-core - # google-cloud-resource-manager - # googleapis-common-protos - # grpc-google-iam-v1 - # grpcio-status -grpcio-status==1.76.0 - # via google-api-core h11==0.16.0 # via # httpcore @@ -528,7 +479,7 @@ lxml==5.3.0 # unstructured # xmlsec # zeep -lxml-html-clean==0.4.3 +lxml-html-clean==0.4.4 # via lxml magika==0.6.3 # via markitdown @@ -596,7 +547,7 @@ mypy-extensions==1.0.0 # typing-inspect nest-asyncio==1.6.0 # via onyx -nltk==3.9.1 +nltk==3.9.3 # via unstructured numpy==2.4.1 # via @@ -663,15 +614,13 @@ opentelemetry-sdk==1.39.1 # opentelemetry-exporter-otlp-proto-http 
opentelemetry-semantic-conventions==0.60b1 # via opentelemetry-sdk -orjson==3.11.4 ; platform_python_implementation != 'PyPy' +orjson==3.11.6 ; platform_python_implementation != 'PyPy' # via langsmith packaging==24.2 # via # dask # distributed # fastmcp - # google-cloud-aiplatform - # google-cloud-bigquery # huggingface-hub # jira # kombu @@ -721,19 +670,12 @@ propcache==0.4.1 # aiohttp # yarl proto-plus==1.26.1 - # via - # google-api-core - # google-cloud-aiplatform - # google-cloud-resource-manager + # via google-api-core protobuf==6.33.5 # via # ddtrace # google-api-core - # google-cloud-aiplatform - # google-cloud-resource-manager # googleapis-common-protos - # grpc-google-iam-v1 - # grpcio-status # onnxruntime # opentelemetry-proto # proto-plus @@ -771,7 +713,6 @@ pydantic==2.11.7 # exa-py # fastapi # fastmcp - # google-cloud-aiplatform # google-genai # langchain-core # langfuse @@ -809,7 +750,7 @@ pypandoc-binary==1.16.2 # via onyx pyparsing==3.2.5 # via httplib2 -pypdf==6.7.3 +pypdf==6.8.0 # via # onyx # unstructured-client @@ -835,7 +776,6 @@ python-dateutil==2.8.2 # botocore # celery # dateparser - # google-cloud-bigquery # htmldate # hubspot-api-client # kubernetes @@ -927,8 +867,6 @@ requests==2.32.5 # dropbox # exa-py # google-api-core - # google-cloud-bigquery - # google-cloud-storage # google-genai # hubspot-api-client # huggingface-hub @@ -1002,9 +940,7 @@ sendgrid==6.12.5 sentry-sdk==2.14.0 # via onyx shapely==2.0.6 - # via - # google-cloud-aiplatform - # onyx + # via onyx shellingham==1.5.4 # via typer simple-salesforce==1.12.6 @@ -1084,7 +1020,7 @@ toolz==1.1.0 # dask # distributed # partd -tornado==6.5.2 +tornado==6.5.5 # via distributed tqdm==4.67.1 # via @@ -1118,9 +1054,7 @@ typing-extensions==4.15.0 # exa-py # exceptiongroup # fastapi - # google-cloud-aiplatform # google-genai - # grpcio # huggingface-hub # jira # langchain-core diff --git a/backend/requirements/dev.txt b/backend/requirements/dev.txt index 88c5f4065ce..b3320e46e47 100644 --- 
a/backend/requirements/dev.txt +++ b/backend/requirements/dev.txt @@ -59,8 +59,6 @@ botocore==1.39.11 # s3transfer brotli==1.2.0 # via onyx -cachetools==6.2.2 - # via google-auth celery-types==0.19.0 # via onyx certifi==2025.11.12 @@ -100,7 +98,9 @@ comm==0.2.3 contourpy==1.3.3 # via matplotlib cryptography==46.0.5 - # via pyjwt + # via + # google-auth + # pyjwt cycler==0.12.1 # via matplotlib debugpy==1.8.17 @@ -115,8 +115,6 @@ distlib==0.4.0 # via virtualenv distro==1.9.0 # via openai -docstring-parser==0.17.0 - # via google-cloud-aiplatform durationpy==0.10 # via kubernetes execnet==2.1.2 @@ -145,65 +143,14 @@ frozenlist==1.8.0 # aiosignal fsspec==2025.10.0 # via huggingface-hub -google-api-core==2.28.1 - # via - # google-cloud-aiplatform - # google-cloud-bigquery - # google-cloud-core - # google-cloud-resource-manager - # google-cloud-storage -google-auth==2.43.0 - # via - # google-api-core - # google-cloud-aiplatform - # google-cloud-bigquery - # google-cloud-core - # google-cloud-resource-manager - # google-cloud-storage +google-auth==2.48.0 + # via # google-genai # kubernetes -google-cloud-aiplatform==1.121.0 - # via onyx -google-cloud-bigquery==3.38.0 - # via google-cloud-aiplatform -google-cloud-core==2.5.0 - # via - # google-cloud-bigquery - # google-cloud-storage -google-cloud-resource-manager==1.15.0 - # via google-cloud-aiplatform -google-cloud-storage==2.19.0 - # via google-cloud-aiplatform -google-crc32c==1.7.1 - # via - # google-cloud-storage - # google-resumable-media google-genai==1.52.0 - # via - # google-cloud-aiplatform - # onyx -google-resumable-media==2.7.2 - # via - # google-cloud-bigquery - # google-cloud-storage -googleapis-common-protos==1.72.0 - # via - # google-api-core - # grpc-google-iam-v1 - # grpcio-status + # via onyx greenlet==3.2.4 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or 
platform_machine == 'x86_64' # via sqlalchemy -grpc-google-iam-v1==0.14.3 - # via google-cloud-resource-manager -grpcio==1.76.0 - # via - # google-api-core - # google-cloud-resource-manager - # googleapis-common-protos - # grpc-google-iam-v1 - # grpcio-status -grpcio-status==1.76.0 - # via google-api-core h11==0.16.0 # via # httpcore @@ -311,13 +258,12 @@ numpy==2.4.1 # contourpy # matplotlib # pandas-stubs - # shapely # voyageai oauthlib==3.2.2 # via # kubernetes # requests-oauthlib -onyx-devtools==0.6.2 +onyx-devtools==0.7.0 # via onyx openai==2.14.0 # via @@ -330,8 +276,6 @@ openapi-generator-cli==7.17.0 packaging==24.2 # via # black - # google-cloud-aiplatform - # google-cloud-bigquery # hatchling # huggingface-hub # ipykernel @@ -374,20 +318,6 @@ propcache==0.4.1 # via # aiohttp # yarl -proto-plus==1.26.1 - # via - # google-api-core - # google-cloud-aiplatform - # google-cloud-resource-manager -protobuf==6.33.5 - # via - # google-api-core - # google-cloud-aiplatform - # google-cloud-resource-manager - # googleapis-common-protos - # grpc-google-iam-v1 - # grpcio-status - # proto-plus psutil==7.1.3 # via ipykernel ptyprocess==0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32' @@ -409,7 +339,6 @@ pydantic==2.11.7 # agent-client-protocol # cohere # fastapi - # google-cloud-aiplatform # google-genai # litellm # mcp @@ -450,7 +379,6 @@ python-dateutil==2.8.2 # via # aiobotocore # botocore - # google-cloud-bigquery # jupyter-client # kubernetes # matplotlib @@ -478,16 +406,13 @@ referencing==0.36.2 # jsonschema-specifications regex==2025.11.3 # via tiktoken -release-tag==0.4.3 +release-tag==0.5.2 # via onyx reorder-python-imports-black==3.14.0 # via onyx requests==2.32.5 # via # cohere - # google-api-core - # google-cloud-bigquery - # google-cloud-storage # google-genai # huggingface-hub # kubernetes @@ -510,8 +435,6 @@ s3transfer==0.13.1 # via boto3 sentry-sdk==2.14.0 # via onyx -shapely==2.0.6 - # via google-cloud-aiplatform six==1.17.0 # via # 
kubernetes @@ -543,7 +466,7 @@ tokenizers==0.21.4 # via # cohere # litellm -tornado==6.5.2 +tornado==6.5.5 # via # ipykernel # jupyter-client @@ -602,9 +525,7 @@ typing-extensions==4.15.0 # celery-types # cohere # fastapi - # google-cloud-aiplatform # google-genai - # grpcio # huggingface-hub # ipython # mcp diff --git a/backend/requirements/ee.txt b/backend/requirements/ee.txt index 3b89c50164b..17a6508bf4f 100644 --- a/backend/requirements/ee.txt +++ b/backend/requirements/ee.txt @@ -53,8 +53,6 @@ botocore==1.39.11 # s3transfer brotli==1.2.0 # via onyx -cachetools==6.2.2 - # via google-auth certifi==2025.11.12 # via # httpcore @@ -79,15 +77,15 @@ colorama==0.4.6 ; sys_platform == 'win32' # click # tqdm cryptography==46.0.5 - # via pyjwt + # via + # google-auth + # pyjwt decorator==5.2.1 # via retry discord-py==2.4.0 # via onyx distro==1.9.0 # via openai -docstring-parser==0.17.0 - # via google-cloud-aiplatform durationpy==0.10 # via kubernetes fastapi==0.133.1 @@ -104,63 +102,12 @@ frozenlist==1.8.0 # aiosignal fsspec==2025.10.0 # via huggingface-hub -google-api-core==2.28.1 - # via - # google-cloud-aiplatform - # google-cloud-bigquery - # google-cloud-core - # google-cloud-resource-manager - # google-cloud-storage -google-auth==2.43.0 - # via - # google-api-core - # google-cloud-aiplatform - # google-cloud-bigquery - # google-cloud-core - # google-cloud-resource-manager - # google-cloud-storage +google-auth==2.48.0 + # via # google-genai # kubernetes -google-cloud-aiplatform==1.121.0 - # via onyx -google-cloud-bigquery==3.38.0 - # via google-cloud-aiplatform -google-cloud-core==2.5.0 - # via - # google-cloud-bigquery - # google-cloud-storage -google-cloud-resource-manager==1.15.0 - # via google-cloud-aiplatform -google-cloud-storage==2.19.0 - # via google-cloud-aiplatform -google-crc32c==1.7.1 - # via - # google-cloud-storage - # google-resumable-media google-genai==1.52.0 - # via - # google-cloud-aiplatform - # onyx -google-resumable-media==2.7.2 - # via - # 
google-cloud-bigquery - # google-cloud-storage -googleapis-common-protos==1.72.0 - # via - # google-api-core - # grpc-google-iam-v1 - # grpcio-status -grpc-google-iam-v1==0.14.3 - # via google-cloud-resource-manager -grpcio==1.76.0 - # via - # google-api-core - # google-cloud-resource-manager - # googleapis-common-protos - # grpc-google-iam-v1 - # grpcio-status -grpcio-status==1.76.0 - # via google-api-core + # via onyx h11==0.16.0 # via # httpcore @@ -221,9 +168,7 @@ multidict==6.7.0 # aiohttp # yarl numpy==2.4.1 - # via - # shapely - # voyageai + # via voyageai oauthlib==3.2.2 # via # kubernetes @@ -233,10 +178,7 @@ openai==2.14.0 # litellm # onyx packaging==24.2 - # via - # google-cloud-aiplatform - # google-cloud-bigquery - # huggingface-hub + # via huggingface-hub parameterized==0.9.0 # via cohere posthog==3.7.4 @@ -251,20 +193,6 @@ propcache==0.4.1 # via # aiohttp # yarl -proto-plus==1.26.1 - # via - # google-api-core - # google-cloud-aiplatform - # google-cloud-resource-manager -protobuf==6.33.5 - # via - # google-api-core - # google-cloud-aiplatform - # google-cloud-resource-manager - # googleapis-common-protos - # grpc-google-iam-v1 - # grpcio-status - # proto-plus py==1.11.0 # via retry pyasn1==0.6.2 @@ -280,7 +208,6 @@ pydantic==2.11.7 # agent-client-protocol # cohere # fastapi - # google-cloud-aiplatform # google-genai # litellm # mcp @@ -297,7 +224,6 @@ python-dateutil==2.8.2 # via # aiobotocore # botocore - # google-cloud-bigquery # kubernetes # posthog python-dotenv==1.1.1 @@ -321,9 +247,6 @@ regex==2025.11.3 requests==2.32.5 # via # cohere - # google-api-core - # google-cloud-bigquery - # google-cloud-storage # google-genai # huggingface-hub # kubernetes @@ -345,8 +268,6 @@ s3transfer==0.13.1 # via boto3 sentry-sdk==2.14.0 # via onyx -shapely==2.0.6 - # via google-cloud-aiplatform six==1.17.0 # via # kubernetes @@ -385,9 +306,7 @@ typing-extensions==4.15.0 # anyio # cohere # fastapi - # google-cloud-aiplatform # google-genai - # grpcio # 
huggingface-hub # mcp # openai diff --git a/backend/requirements/model_server.txt b/backend/requirements/model_server.txt index 1190245a2f5..4f0bc2b9f9e 100644 --- a/backend/requirements/model_server.txt +++ b/backend/requirements/model_server.txt @@ -57,8 +57,6 @@ botocore==1.39.11 # s3transfer brotli==1.2.0 # via onyx -cachetools==6.2.2 - # via google-auth celery==5.5.1 # via sentry-sdk certifi==2025.11.12 @@ -95,15 +93,15 @@ colorama==0.4.6 ; sys_platform == 'win32' # click # tqdm cryptography==46.0.5 - # via pyjwt + # via + # google-auth + # pyjwt decorator==5.2.1 # via retry discord-py==2.4.0 # via onyx distro==1.9.0 # via openai -docstring-parser==0.17.0 - # via google-cloud-aiplatform durationpy==0.10 # via kubernetes einops==0.8.1 @@ -129,63 +127,12 @@ fsspec==2025.10.0 # via # huggingface-hub # torch -google-api-core==2.28.1 - # via - # google-cloud-aiplatform - # google-cloud-bigquery - # google-cloud-core - # google-cloud-resource-manager - # google-cloud-storage -google-auth==2.43.0 - # via - # google-api-core - # google-cloud-aiplatform - # google-cloud-bigquery - # google-cloud-core - # google-cloud-resource-manager - # google-cloud-storage +google-auth==2.48.0 + # via # google-genai # kubernetes -google-cloud-aiplatform==1.121.0 - # via onyx -google-cloud-bigquery==3.38.0 - # via google-cloud-aiplatform -google-cloud-core==2.5.0 - # via - # google-cloud-bigquery - # google-cloud-storage -google-cloud-resource-manager==1.15.0 - # via google-cloud-aiplatform -google-cloud-storage==2.19.0 - # via google-cloud-aiplatform -google-crc32c==1.7.1 - # via - # google-cloud-storage - # google-resumable-media google-genai==1.52.0 - # via - # google-cloud-aiplatform - # onyx -google-resumable-media==2.7.2 - # via - # google-cloud-bigquery - # google-cloud-storage -googleapis-common-protos==1.72.0 - # via - # google-api-core - # grpc-google-iam-v1 - # grpcio-status -grpc-google-iam-v1==0.14.3 - # via google-cloud-resource-manager -grpcio==1.76.0 - # via - # 
google-api-core - # google-cloud-resource-manager - # googleapis-common-protos - # grpc-google-iam-v1 - # grpcio-status -grpcio-status==1.76.0 - # via google-api-core + # via onyx h11==0.16.0 # via # httpcore @@ -263,7 +210,6 @@ numpy==2.4.1 # onyx # scikit-learn # scipy - # shapely # transformers # voyageai nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux' @@ -316,8 +262,6 @@ openai==2.14.0 packaging==24.2 # via # accelerate - # google-cloud-aiplatform - # google-cloud-bigquery # huggingface-hub # kombu # transformers @@ -337,20 +281,6 @@ propcache==0.4.1 # via # aiohttp # yarl -proto-plus==1.26.1 - # via - # google-api-core - # google-cloud-aiplatform - # google-cloud-resource-manager -protobuf==6.33.5 - # via - # google-api-core - # google-cloud-aiplatform - # google-cloud-resource-manager - # googleapis-common-protos - # grpc-google-iam-v1 - # grpcio-status - # proto-plus psutil==7.1.3 # via accelerate py==1.11.0 @@ -368,7 +298,6 @@ pydantic==2.11.7 # agent-client-protocol # cohere # fastapi - # google-cloud-aiplatform # google-genai # litellm # mcp @@ -386,7 +315,6 @@ python-dateutil==2.8.2 # aiobotocore # botocore # celery - # google-cloud-bigquery # kubernetes python-dotenv==1.1.1 # via @@ -413,9 +341,6 @@ regex==2025.11.3 requests==2.32.5 # via # cohere - # google-api-core - # google-cloud-bigquery - # google-cloud-storage # google-genai # huggingface-hub # kubernetes @@ -452,8 +377,6 @@ sentry-sdk==2.14.0 # via onyx setuptools==80.9.0 ; python_full_version >= '3.12' # via torch -shapely==2.0.6 - # via google-cloud-aiplatform six==1.17.0 # via # kubernetes @@ -510,9 +433,7 @@ typing-extensions==4.15.0 # anyio # cohere # fastapi - # google-cloud-aiplatform # google-genai - # grpcio # huggingface-hub # mcp # openai diff --git a/backend/scripts/debugging/opensearch/opensearch_debug.py b/backend/scripts/debugging/opensearch/opensearch_debug.py new file mode 100644 index 00000000000..cc14ad4c735 --- /dev/null +++ 
b/backend/scripts/debugging/opensearch/opensearch_debug.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +"""A utility to interact with OpenSearch. + +Usage: + python3 opensearch_debug.py --help + python3 opensearch_debug.py list + python3 opensearch_debug.py delete + +Environment Variables: + OPENSEARCH_HOST: OpenSearch host + OPENSEARCH_REST_API_PORT: OpenSearch port + OPENSEARCH_ADMIN_USERNAME: Admin username + OPENSEARCH_ADMIN_PASSWORD: Admin password + +Dependencies: + backend/shared_configs/configs.py + backend/onyx/document_index/opensearch/client.py +""" + +import argparse +import os +import sys + +from onyx.document_index.opensearch.client import OpenSearchClient +from onyx.document_index.opensearch.client import OpenSearchIndexClient +from shared_configs.configs import MULTI_TENANT + + +def list_indices(client: OpenSearchClient) -> None: + indices = client.list_indices_with_info() + print(f"Found {len(indices)} indices.") + print("-" * 80) + for index in sorted(indices, key=lambda x: x.name): + print(f"Index: {index.name}") + print(f"Health: {index.health}") + print(f"Status: {index.status}") + print(f"Num Primary Shards: {index.num_primary_shards}") + print(f"Num Replica Shards: {index.num_replica_shards}") + print(f"Docs Count: {index.docs_count}") + print(f"Docs Deleted: {index.docs_deleted}") + print(f"Created At: {index.created_at}") + print(f"Total Size: {index.total_size}") + print(f"Primary Shards Size: {index.primary_shards_size}") + print("-" * 80) + + +def delete_index(client: OpenSearchIndexClient) -> None: + if not client.index_exists(): + print(f"Index '{client._index_name}' does not exist.") + return + + confirm = input(f"Delete index '{client._index_name}'? 
(yes/no): ") + if confirm.lower() != "yes": + print("Aborted.") + return + + if client.delete_index(): + print(f"Deleted index '{client._index_name}'.") + else: + print(f"Failed to delete index '{client._index_name}' for an unknown reason.") + + +def main() -> None: + def add_standard_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--host", + help="OpenSearch host. If not provided, will fall back to OPENSEARCH_HOST, then prompt " + "for input.", + type=str, + default=os.environ.get("OPENSEARCH_HOST", ""), + ) + parser.add_argument( + "--port", + help="OpenSearch port. If not provided, will fall back to OPENSEARCH_REST_API_PORT, " + "then prompt for input.", + type=int, + default=int(os.environ.get("OPENSEARCH_REST_API_PORT", 0)), + ) + parser.add_argument( + "--username", + help="OpenSearch username. If not provided, will fall back to OPENSEARCH_ADMIN_USERNAME, " + "then prompt for input.", + type=str, + default=os.environ.get("OPENSEARCH_ADMIN_USERNAME", ""), + ) + parser.add_argument( + "--password", + help="OpenSearch password. If not provided, will fall back to OPENSEARCH_ADMIN_PASSWORD, " + "then prompt for input.", + type=str, + default=os.environ.get("OPENSEARCH_ADMIN_PASSWORD", ""), + ) + parser.add_argument( + "--no-ssl", help="Disable SSL.", action="store_true", default=False + ) + parser.add_argument( + "--no-verify-certs", + help="Disable certificate verification (for self-signed certs).", + action="store_true", + default=False, + ) + parser.add_argument( + "--use-aws-managed-opensearch", + help="Whether to use AWS-managed OpenSearch. If not provided, will fall back to checking " + "USING_AWS_MANAGED_OPENSEARCH=='true', then default to False.", + action=argparse.BooleanOptionalAction, + default=os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() + == "true", + ) + + parser = argparse.ArgumentParser( + description="A utility to interact with OpenSearch." 
+ ) + subparsers = parser.add_subparsers( + dest="command", help="Command to execute.", required=True + ) + + list_parser = subparsers.add_parser("list", help="List all indices with info.") + add_standard_arguments(list_parser) + + delete_parser = subparsers.add_parser("delete", help="Delete an index.") + delete_parser.add_argument("index", help="Index name.", type=str) + add_standard_arguments(delete_parser) + + args = parser.parse_args() + + if not (host := args.host or input("Enter the OpenSearch host: ")): + print("Error: OpenSearch host is required.") + sys.exit(1) + if not (port := args.port or int(input("Enter the OpenSearch port: "))): + print("Error: OpenSearch port is required.") + sys.exit(1) + if not (username := args.username or input("Enter the OpenSearch username: ")): + print("Error: OpenSearch username is required.") + sys.exit(1) + if not (password := args.password or input("Enter the OpenSearch password: ")): + print("Error: OpenSearch password is required.") + sys.exit(1) + print("Using AWS-managed OpenSearch: ", args.use_aws_managed_opensearch) + print(f"MULTI_TENANT: {MULTI_TENANT}") + + with ( + OpenSearchIndexClient( + index_name=args.index, + host=host, + port=port, + auth=(username, password), + use_ssl=not args.no_ssl, + verify_certs=not args.no_verify_certs, + ) + if args.command == "delete" + else OpenSearchClient( + host=host, + port=port, + auth=(username, password), + use_ssl=not args.no_ssl, + verify_certs=not args.no_verify_certs, + ) + ) as client: + if not client.ping(): + print("Error: Could not connect to OpenSearch.") + sys.exit(1) + + if args.command == "list": + list_indices(client) + elif args.command == "delete": + delete_index(client) + + +if __name__ == "__main__": + main() diff --git a/backend/scripts/decrypt.py b/backend/scripts/decrypt.py index 40e6413564e..4aebbd4d885 100644 --- a/backend/scripts/decrypt.py +++ b/backend/scripts/decrypt.py @@ -1,48 +1,93 @@ +"""Decrypt a raw hex-encoded credential value. 
+ +Usage: + python -m scripts.decrypt + python -m scripts.decrypt --key "my-encryption-key" + python -m scripts.decrypt --key "" + +Pass --key "" to skip decryption and just decode the raw bytes as UTF-8. +Omit --key to use the current ENCRYPTION_KEY_SECRET from the environment. +""" + +import argparse import binascii import json +import os import sys -from onyx.utils.encryption import decrypt_bytes_to_string +parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(parent_dir) +from onyx.utils.encryption import decrypt_bytes_to_string # noqa: E402 +from onyx.utils.variable_functionality import global_version # noqa: E402 -def decrypt_raw_credential(encrypted_value: str) -> None: - """Decrypt and display a raw encrypted credential value + +def decrypt_raw_credential(encrypted_value: str, key: str | None = None) -> None: + """Decrypt and display a raw encrypted credential value. Args: - encrypted_value: The hex encoded encrypted credential value + encrypted_value: The hex-encoded encrypted credential value. + key: Encryption key to use. None means use ENCRYPTION_KEY_SECRET, + empty string means just decode as UTF-8. 
""" - try: - # If string starts with 'x', remove it as it's just a prefix indicating hex - if encrypted_value.startswith("x"): - encrypted_value = encrypted_value[1:] - elif encrypted_value.startswith("\\x"): - encrypted_value = encrypted_value[2:] + # Strip common hex prefixes + if encrypted_value.startswith("\\x"): + encrypted_value = encrypted_value[2:] + elif encrypted_value.startswith("x"): + encrypted_value = encrypted_value[1:] + print(encrypted_value) - # Convert hex string to bytes - encrypted_bytes = binascii.unhexlify(encrypted_value) + try: + raw_bytes = binascii.unhexlify(encrypted_value) + except binascii.Error: + print("Error: Invalid hex-encoded string") + sys.exit(1) - # Decrypt the bytes - decrypted_str = decrypt_bytes_to_string(encrypted_bytes) + if key == "": + # Empty key → just decode as UTF-8, no decryption + try: + decrypted_str = raw_bytes.decode("utf-8") + except UnicodeDecodeError as e: + print(f"Error decoding bytes as UTF-8: {e}") + sys.exit(1) + else: + print(key) + try: + decrypted_str = decrypt_bytes_to_string(raw_bytes, key=key) + except Exception as e: + print(f"Error decrypting value: {e}") + sys.exit(1) - # Parse and pretty print the decrypted JSON - decrypted_json = json.loads(decrypted_str) - print("Decrypted credential value:") - print(json.dumps(decrypted_json, indent=2)) + # Try to pretty-print as JSON, otherwise print raw + try: + parsed = json.loads(decrypted_str) + print(json.dumps(parsed, indent=2)) + except json.JSONDecodeError: + print(decrypted_str) - except binascii.Error: - print("Error: Invalid hex encoded string") - except json.JSONDecodeError as e: - print(f"Decrypted raw value (not JSON): {e}") +def main() -> None: + parser = argparse.ArgumentParser( + description="Decrypt a hex-encoded credential value." + ) + parser.add_argument( + "value", + help="Hex-encoded encrypted value to decrypt.", + ) + parser.add_argument( + "--key", + default=None, + help=( + "Encryption key. 
Omit to use ENCRYPTION_KEY_SECRET from env. " + 'Pass "" (empty) to just decode as UTF-8 without decryption.' + ), + ) + args = parser.parse_args() - except Exception as e: - print(f"Error decrypting value: {e}") + global_version.set_ee() + decrypt_raw_credential(args.value, key=args.key) + global_version.unset_ee() if __name__ == "__main__": - if len(sys.argv) != 2: - print("Usage: python decrypt.py ") - sys.exit(1) - - encrypted_value = sys.argv[1] - decrypt_raw_credential(encrypted_value) + main() diff --git a/backend/scripts/dev_run_background_jobs.py b/backend/scripts/dev_run_background_jobs.py index 1c26827043f..115391910c5 100644 --- a/backend/scripts/dev_run_background_jobs.py +++ b/backend/scripts/dev_run_background_jobs.py @@ -16,10 +16,6 @@ def monitor_process(process_name: str, process: subprocess.Popen) -> None: def run_jobs() -> None: - # Check if we should use lightweight mode, defaults to True, change to False to use separate background workers - use_lightweight = True - - # command setup cmd_worker_primary = [ "celery", "-A", @@ -74,6 +70,48 @@ def run_jobs() -> None: "--queues=connector_doc_fetching", ] + cmd_worker_heavy = [ + "celery", + "-A", + "onyx.background.celery.versioned_apps.heavy", + "worker", + "--pool=threads", + "--concurrency=4", + "--prefetch-multiplier=1", + "--loglevel=INFO", + "--hostname=heavy@%n", + "-Q", + "connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,sandbox", + ] + + cmd_worker_monitoring = [ + "celery", + "-A", + "onyx.background.celery.versioned_apps.monitoring", + "worker", + "--pool=threads", + "--concurrency=1", + "--prefetch-multiplier=1", + "--loglevel=INFO", + "--hostname=monitoring@%n", + "-Q", + "monitoring", + ] + + cmd_worker_user_file_processing = [ + "celery", + "-A", + "onyx.background.celery.versioned_apps.user_file_processing", + "worker", + "--pool=threads", + "--concurrency=2", + "--prefetch-multiplier=1", + "--loglevel=INFO", + 
"--hostname=user_file_processing@%n", + "-Q", + "user_file_processing,user_file_project_sync,user_file_delete", + ] + cmd_beat = [ "celery", "-A", @@ -82,144 +120,31 @@ def run_jobs() -> None: "--loglevel=INFO", ] - # Prepare background worker commands based on mode - if use_lightweight: - print("Starting workers in LIGHTWEIGHT mode (single background worker)") - cmd_worker_background = [ - "celery", - "-A", - "onyx.background.celery.versioned_apps.background", - "worker", - "--pool=threads", - "--concurrency=6", - "--prefetch-multiplier=1", - "--loglevel=INFO", - "--hostname=background@%n", - "-Q", - "connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration", - ] - background_workers = [("BACKGROUND", cmd_worker_background)] - else: - print("Starting workers in STANDARD mode (separate background workers)") - cmd_worker_heavy = [ - "celery", - "-A", - "onyx.background.celery.versioned_apps.heavy", - "worker", - "--pool=threads", - "--concurrency=4", - "--prefetch-multiplier=1", - "--loglevel=INFO", - "--hostname=heavy@%n", - "-Q", - "connector_pruning,sandbox", - ] - cmd_worker_monitoring = [ - "celery", - "-A", - "onyx.background.celery.versioned_apps.monitoring", - "worker", - "--pool=threads", - "--concurrency=1", - "--prefetch-multiplier=1", - "--loglevel=INFO", - "--hostname=monitoring@%n", - "-Q", - "monitoring", - ] - cmd_worker_user_file_processing = [ - "celery", - "-A", - "onyx.background.celery.versioned_apps.user_file_processing", - "worker", - "--pool=threads", - "--concurrency=2", - "--prefetch-multiplier=1", - "--loglevel=INFO", - "--hostname=user_file_processing@%n", - "-Q", - "user_file_processing,user_file_project_sync,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,user_file_delete", - ] - background_workers = [ - ("HEAVY", cmd_worker_heavy), - ("MONITORING", cmd_worker_monitoring), - 
("USER_FILE_PROCESSING", cmd_worker_user_file_processing), - ] - - # spawn processes - worker_primary_process = subprocess.Popen( - cmd_worker_primary, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True - ) - - worker_light_process = subprocess.Popen( - cmd_worker_light, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True - ) - - worker_docprocessing_process = subprocess.Popen( - cmd_worker_docprocessing, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - ) - - worker_docfetching_process = subprocess.Popen( - cmd_worker_docfetching, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - ) - - beat_process = subprocess.Popen( - cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True - ) - - # Spawn background worker processes based on mode - background_processes = [] - for name, cmd in background_workers: + all_workers = [ + ("PRIMARY", cmd_worker_primary), + ("LIGHT", cmd_worker_light), + ("DOCPROCESSING", cmd_worker_docprocessing), + ("DOCFETCHING", cmd_worker_docfetching), + ("HEAVY", cmd_worker_heavy), + ("MONITORING", cmd_worker_monitoring), + ("USER_FILE_PROCESSING", cmd_worker_user_file_processing), + ("BEAT", cmd_beat), + ] + + processes = [] + for name, cmd in all_workers: process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True ) - background_processes.append((name, process)) - - # monitor threads - worker_primary_thread = threading.Thread( - target=monitor_process, args=("PRIMARY", worker_primary_process) - ) - worker_light_thread = threading.Thread( - target=monitor_process, args=("LIGHT", worker_light_process) - ) - worker_docprocessing_thread = threading.Thread( - target=monitor_process, args=("DOCPROCESSING", worker_docprocessing_process) - ) - worker_docfetching_thread = threading.Thread( - target=monitor_process, args=("DOCFETCHING", worker_docfetching_process) - ) - beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process)) - 
- # Create monitor threads for background workers - background_threads = [] - for name, process in background_processes: - thread = threading.Thread(target=monitor_process, args=(name, process)) - background_threads.append(thread) - - # Start all threads - worker_primary_thread.start() - worker_light_thread.start() - worker_docprocessing_thread.start() - worker_docfetching_thread.start() - beat_thread.start() + processes.append((name, process)) - for thread in background_threads: + threads = [] + for name, process in processes: + thread = threading.Thread(target=monitor_process, args=(name, process)) + threads.append(thread) thread.start() - # Wait for all threads - worker_primary_thread.join() - worker_light_thread.join() - worker_docprocessing_thread.join() - worker_docfetching_thread.join() - beat_thread.join() - - for thread in background_threads: + for thread in threads: thread.join() diff --git a/backend/scripts/reencrypt_secrets.py b/backend/scripts/reencrypt_secrets.py new file mode 100644 index 00000000000..0458cbe1f6e --- /dev/null +++ b/backend/scripts/reencrypt_secrets.py @@ -0,0 +1,107 @@ +"""Re-encrypt secrets under the current ENCRYPTION_KEY_SECRET. + +Decrypts all encrypted columns using the old key (or raw decode if the old key +is empty), then re-encrypts them with the current ENCRYPTION_KEY_SECRET. + +Usage (docker): + docker exec -it onyx-api_server-1 \ + python -m scripts.reencrypt_secrets --old-key "previous-key" + +Usage (kubernetes): + kubectl exec -it -- \ + python -m scripts.reencrypt_secrets --old-key "previous-key" + +Omit --old-key (or pass "") if secrets were not previously encrypted. + +For multi-tenant deployments, pass --tenant-id to target a specific tenant, +or --all-tenants to iterate every tenant. 
+""" + +import argparse +import os +import sys + +parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(parent_dir) + +from onyx.db.rotate_encryption_key import rotate_encryption_key # noqa: E402 +from onyx.db.engine.sql_engine import get_session_with_tenant # noqa: E402 +from onyx.db.engine.sql_engine import SqlEngine # noqa: E402 +from onyx.db.engine.tenant_utils import get_all_tenant_ids # noqa: E402 +from onyx.utils.variable_functionality import global_version # noqa: E402 +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA # noqa: E402 + + +def _run_for_tenant(tenant_id: str, old_key: str | None, dry_run: bool = False) -> None: + print(f"Re-encrypting secrets for tenant: {tenant_id}") + with get_session_with_tenant(tenant_id=tenant_id) as db_session: + results = rotate_encryption_key(db_session, old_key=old_key, dry_run=dry_run) + + if results: + for col, count in results.items(): + print( + f" {col}: {count} row(s) {'would be ' if dry_run else ''}re-encrypted" + ) + else: + print("No rows needed re-encryption.") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Re-encrypt secrets under the current encryption key." + ) + parser.add_argument( + "--old-key", + default=None, + help="Previous encryption key. 
Omit or pass empty string if not applicable.", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Show what would be re-encrypted without making changes.", + ) + + tenant_group = parser.add_mutually_exclusive_group() + tenant_group.add_argument( + "--tenant-id", + default=None, + help="Target a specific tenant schema.", + ) + tenant_group.add_argument( + "--all-tenants", + action="store_true", + help="Iterate all tenants.", + ) + + args = parser.parse_args() + + old_key = args.old_key if args.old_key else None + + global_version.set_ee() + SqlEngine.init_engine(pool_size=5, max_overflow=2) + + if args.dry_run: + print("DRY RUN — no changes will be made") + + if args.all_tenants: + tenant_ids = get_all_tenant_ids() + print(f"Found {len(tenant_ids)} tenant(s)") + failed_tenants: list[str] = [] + for tid in tenant_ids: + try: + _run_for_tenant(tid, old_key, dry_run=args.dry_run) + except Exception as e: + print(f" ERROR for tenant {tid}: {e}") + failed_tenants.append(tid) + if failed_tenants: + print(f"FAILED tenants ({len(failed_tenants)}): {failed_tenants}") + sys.exit(1) + else: + tenant_id = args.tenant_id or POSTGRES_DEFAULT_SCHEMA + _run_for_tenant(tenant_id, old_key, dry_run=args.dry_run) + + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/backend/scripts/restart_containers.sh b/backend/scripts/restart_containers.sh index 708e1921cd6..759ba6933b9 100755 --- a/backend/scripts/restart_containers.sh +++ b/backend/scripts/restart_containers.sh @@ -1,10 +1,20 @@ #!/bin/bash set -e -cleanup() { - echo "Error occurred. Cleaning up..." 
+SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +COMPOSE_FILE="$SCRIPT_DIR/../../deployment/docker_compose/docker-compose.yml" +COMPOSE_DEV_FILE="$SCRIPT_DIR/../../deployment/docker_compose/docker-compose.dev.yml" + +stop_and_remove_containers() { docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true + docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled stop opensearch 2>/dev/null || true + docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled rm -f opensearch 2>/dev/null || true +} + +cleanup() { + echo "Error occurred. Cleaning up..." + stop_and_remove_containers } # Trap errors and output a message, then cleanup @@ -12,16 +22,26 @@ trap 'echo "Error occurred on line $LINENO. Exiting script." >&2; cleanup' ERR # Usage of the script with optional volume arguments # ./restart_containers.sh [vespa_volume] [postgres_volume] [redis_volume] - -VESPA_VOLUME=${1:-""} # Default is empty if not provided -POSTGRES_VOLUME=${2:-""} # Default is empty if not provided -REDIS_VOLUME=${3:-""} # Default is empty if not provided -MINIO_VOLUME=${4:-""} # Default is empty if not provided +# [minio_volume] [--keep-opensearch-data] + +KEEP_OPENSEARCH_DATA=false +POSITIONAL_ARGS=() +for arg in "$@"; do + if [[ "$arg" == "--keep-opensearch-data" ]]; then + KEEP_OPENSEARCH_DATA=true + else + POSITIONAL_ARGS+=("$arg") + fi +done + +VESPA_VOLUME=${POSITIONAL_ARGS[0]:-""} +POSTGRES_VOLUME=${POSITIONAL_ARGS[1]:-""} +REDIS_VOLUME=${POSITIONAL_ARGS[2]:-""} +MINIO_VOLUME=${POSITIONAL_ARGS[3]:-""} # Stop and remove the existing containers echo "Stopping and removing existing containers..." 
-docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true -docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true +stop_and_remove_containers # Start the PostgreSQL container with optional volume echo "Starting PostgreSQL container..." @@ -39,6 +59,29 @@ else docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 vespaengine/vespa:8 fi +# If OPENSEARCH_ADMIN_PASSWORD is not already set, try loading it from +# .vscode/.env so existing dev setups that stored it there aren't silently +# broken. +VSCODE_ENV="$SCRIPT_DIR/../../.vscode/.env" +if [[ -z "${OPENSEARCH_ADMIN_PASSWORD:-}" && -f "$VSCODE_ENV" ]]; then + set -a + # shellcheck source=/dev/null + source "$VSCODE_ENV" + set +a +fi + +# Start the OpenSearch container using the same service from docker-compose that +# our users use, setting OPENSEARCH_INITIAL_ADMIN_PASSWORD from the env's +# OPENSEARCH_ADMIN_PASSWORD if it exists, else defaulting to StrongPassword123!. +# Pass --keep-opensearch-data to preserve the opensearch-data volume across +# restarts, else the volume is deleted so the container starts fresh. +if [[ "$KEEP_OPENSEARCH_DATA" == "false" ]]; then + echo "Deleting opensearch-data volume..." + docker volume rm onyx_opensearch-data 2>/dev/null || true +fi +echo "Starting OpenSearch container..." +docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled up --force-recreate -d opensearch + # Start the Redis container with optional volume echo "Starting Redis container..." if [[ -n "$REDIS_VOLUME" ]]; then @@ -60,7 +103,6 @@ echo "Starting Code Interpreter container..." 
docker run --detach --name onyx_code_interpreter --publish 8000:8000 --user root -v /var/run/docker.sock:/var/run/docker.sock onyxdotapp/code-interpreter:latest bash ./entrypoint.sh code-interpreter-api # Ensure alembic runs in the correct directory (backend/) -SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" PARENT_DIR="$(dirname "$SCRIPT_DIR")" cd "$PARENT_DIR" diff --git a/backend/scripts/restart_opensearch_container.sh b/backend/scripts/restart_opensearch_container.sh deleted file mode 100644 index ad3f2903091..00000000000 --- a/backend/scripts/restart_opensearch_container.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash - -# We get OPENSEARCH_ADMIN_PASSWORD from the repo .env file. -source "$(dirname "$0")/../../.vscode/.env" - -cd "$(dirname "$0")/../../deployment/docker_compose" - -# Start OpenSearch. -echo "Forcefully starting fresh OpenSearch container..." -docker compose -f docker-compose.opensearch.yml up --force-recreate -d opensearch diff --git a/backend/scripts/supervisord_entrypoint.sh b/backend/scripts/supervisord_entrypoint.sh index 463b6dae2b2..21cdf5bc81d 100755 --- a/backend/scripts/supervisord_entrypoint.sh +++ b/backend/scripts/supervisord_entrypoint.sh @@ -1,23 +1,5 @@ #!/bin/sh -# Entrypoint script for supervisord that sets environment variables -# for controlling which celery workers to start - -# Default to lightweight mode if not set -if [ -z "$USE_LIGHTWEIGHT_BACKGROUND_WORKER" ]; then - export USE_LIGHTWEIGHT_BACKGROUND_WORKER="true" -fi - -# Set the complementary variable for supervisord -# because it doesn't support %(not ENV_USE_LIGHTWEIGHT_BACKGROUND_WORKER) syntax -if [ "$USE_LIGHTWEIGHT_BACKGROUND_WORKER" = "true" ]; then - export USE_SEPARATE_BACKGROUND_WORKERS="false" -else - export USE_SEPARATE_BACKGROUND_WORKERS="true" -fi - -echo "Worker mode configuration:" -echo " USE_LIGHTWEIGHT_BACKGROUND_WORKER=$USE_LIGHTWEIGHT_BACKGROUND_WORKER" -echo " 
USE_SEPARATE_BACKGROUND_WORKERS=$USE_SEPARATE_BACKGROUND_WORKERS" +# Entrypoint script for supervisord # Launch supervisord with environment variables available exec /usr/bin/supervisord -c /etc/supervisor/conf.d/supervisord.conf diff --git a/backend/supervisord.conf b/backend/supervisord.conf index 446dfddd946..4847689dc57 100644 --- a/backend/supervisord.conf +++ b/backend/supervisord.conf @@ -39,7 +39,6 @@ autorestart=true startsecs=10 stopasgroup=true -# Standard mode: Light worker for fast operations # NOTE: only allowing configuration here and not in the other celery workers, # since this is often the bottleneck for "sync" jobs (e.g. document set syncing, # user group syncing, deletion, etc.) @@ -54,26 +53,7 @@ redirect_stderr=true autorestart=true startsecs=10 stopasgroup=true -autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s -# Lightweight mode: single consolidated background worker -# Used when USE_LIGHTWEIGHT_BACKGROUND_WORKER=true (default) -# Consolidates: light, docprocessing, docfetching, heavy, monitoring, user_file_processing -[program:celery_worker_background] -command=celery -A onyx.background.celery.versioned_apps.background worker - --loglevel=INFO - --hostname=background@%%n - -Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,sandbox,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,monitoring,user_file_processing,user_file_project_sync,opensearch_migration -stdout_logfile=/var/log/celery_worker_background.log -stdout_logfile_maxbytes=16MB -redirect_stderr=true -autorestart=true -startsecs=10 -stopasgroup=true -autostart=%(ENV_USE_LIGHTWEIGHT_BACKGROUND_WORKER)s - -# Standard mode: separate workers for different background tasks -# Used when USE_LIGHTWEIGHT_BACKGROUND_WORKER=false [program:celery_worker_heavy] command=celery -A onyx.background.celery.versioned_apps.heavy worker --loglevel=INFO @@ -85,9 +65,7 
@@ redirect_stderr=true autorestart=true startsecs=10 stopasgroup=true -autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s -# Standard mode: Document processing worker [program:celery_worker_docprocessing] command=celery -A onyx.background.celery.versioned_apps.docprocessing worker --loglevel=INFO @@ -99,7 +77,6 @@ redirect_stderr=true autorestart=true startsecs=10 stopasgroup=true -autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s [program:celery_worker_user_file_processing] command=celery -A onyx.background.celery.versioned_apps.user_file_processing worker @@ -112,9 +89,7 @@ redirect_stderr=true autorestart=true startsecs=10 stopasgroup=true -autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s -# Standard mode: Document fetching worker [program:celery_worker_docfetching] command=celery -A onyx.background.celery.versioned_apps.docfetching worker --loglevel=INFO @@ -126,7 +101,6 @@ redirect_stderr=true autorestart=true startsecs=10 stopasgroup=true -autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s [program:celery_worker_monitoring] command=celery -A onyx.background.celery.versioned_apps.monitoring worker @@ -139,7 +113,6 @@ redirect_stderr=true autorestart=true startsecs=10 stopasgroup=true -autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s # Job scheduler for periodic tasks @@ -197,7 +170,6 @@ command=tail -qF /var/log/celery_beat.log /var/log/celery_worker_primary.log /var/log/celery_worker_light.log - /var/log/celery_worker_background.log /var/log/celery_worker_heavy.log /var/log/celery_worker_docprocessing.log /var/log/celery_worker_monitoring.log diff --git a/backend/tests/README.md b/backend/tests/README.md new file mode 100644 index 00000000000..80fbe8f2a37 --- /dev/null +++ b/backend/tests/README.md @@ -0,0 +1,71 @@ +# Backend Tests + +## Test Types + +There are four test categories, ordered by increasing scope: + +### Unit Tests (`tests/unit/`) + +No external services. Mock all I/O with `unittest.mock`. Use for complex, isolated +logic (e.g. 
citation processing, encryption). + +```bash +pytest -xv backend/tests/unit +``` + +### External Dependency Unit Tests (`tests/external_dependency_unit/`) + +External services (Postgres, Redis, Vespa, OpenAI, etc.) are running, but Onyx +application containers are not. Tests call functions directly and can mock selectively. + +Use when you need a real database or real API calls but want control over setup. + +```bash +python -m dotenv -f .vscode/.env run -- pytest backend/tests/external_dependency_unit +``` + +### Integration Tests (`tests/integration/`) + +Full Onyx deployment running. No mocking. Prefer this over other test types when possible. + +```bash +python -m dotenv -f .vscode/.env run -- pytest backend/tests/integration +``` + +### Playwright / E2E Tests (`web/tests/e2e/`) + +Full stack including web server. Use for frontend-backend coordination. + +```bash +npx playwright test +``` + +## Shared Fixtures + +Shared fixtures live in `backend/tests/conftest.py`. Test subdirectories can define +their own `conftest.py` for directory-scoped fixtures. + +## Best Practices + +### Use `enable_ee` fixture instead of inlining + +Enables EE mode for a test, with proper teardown and cache clearing. + +```python +# Whole file (in a test module, NOT in conftest.py) +pytestmark = pytest.mark.usefixtures("enable_ee") + +# Whole directory — add an autouse wrapper to the directory's conftest.py +@pytest.fixture(autouse=True) +def _enable_ee_for_directory(enable_ee: None) -> None: # noqa: ARG001 + """Wraps the shared enable_ee fixture with autouse for this directory.""" + +# Single test +def test_something(enable_ee: None) -> None: ... +``` + +**Note:** `pytestmark` in a `conftest.py` does NOT apply markers to tests in that +directory — it only affects tests defined in the conftest itself (which is none). +Use the autouse fixture wrapper pattern shown above instead. + +Do NOT inline `global_version.set_ee()` — always use the fixture. 
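
The save/restore-and-clear-cache shape that the `enable_ee` fixture relies on can be sketched independently of pytest. This is a hedged stand-in, not Onyx's real code: `FeatureFlags` and `fetch_impl` below play the roles of `global_version` and `fetch_versioned_implementation`, and the real fixture lives in `backend/tests/conftest.py`.

```python
import functools
from collections.abc import Generator


class FeatureFlags:
    """Stand-in for Onyx's global_version object (illustrative only)."""

    def __init__(self) -> None:
        self._is_ee = False

    def is_ee_version(self) -> bool:
        return self._is_ee

    def set_ee(self) -> None:
        self._is_ee = True

    def unset_ee(self) -> None:
        self._is_ee = False


global_flags = FeatureFlags()


@functools.lru_cache(maxsize=None)
def fetch_impl(name: str) -> str:
    """Stand-in for a cached, flag-dependent versioned-implementation lookup."""
    return f"ee.{name}" if global_flags.is_ee_version() else f"ce.{name}"


def enable_ee() -> Generator[None, None, None]:
    """Flip the flag for the duration of a test, restoring prior state after.

    Clearing the lookup cache on both edges is the important part: results
    cached under the old flag value must not leak into or out of the test.
    """
    was_ee = global_flags.is_ee_version()
    global_flags.set_ee()
    fetch_impl.cache_clear()
    yield
    if not was_ee:
        global_flags.unset_ee()
    fetch_impl.cache_clear()
```

In pytest the same generator body is simply decorated with `@pytest.fixture()`; the conditional teardown is what distinguishes it from the removed per-directory fixtures, which unconditionally forced the flag off and so could downgrade a run that started in EE mode.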
diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 00000000000..1bcb468eb45 --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,24 @@ +"""Root conftest — shared fixtures available to all test directories.""" + +from collections.abc import Generator + +import pytest + +from onyx.utils.variable_functionality import fetch_versioned_implementation +from onyx.utils.variable_functionality import global_version + + +@pytest.fixture() +def enable_ee() -> Generator[None, None, None]: + """Temporarily enable EE mode for a single test. + + Restores the previous EE state and clears the versioned-implementation + cache on teardown so state doesn't leak between tests. + """ + was_ee = global_version.is_ee_version() + global_version.set_ee() + fetch_versioned_implementation.cache_clear() + yield + if not was_ee: + global_version.unset_ee() + fetch_versioned_implementation.cache_clear() diff --git a/backend/tests/daily/conftest.py b/backend/tests/daily/conftest.py index 3a7da882d59..8fe64edd061 100644 --- a/backend/tests/daily/conftest.py +++ b/backend/tests/daily/conftest.py @@ -19,7 +19,7 @@ from onyx.auth.users import current_admin_user from onyx.db.engine.sql_engine import get_session from onyx.db.models import UserRole -from onyx.main import fetch_versioned_implementation +from onyx.main import get_application from onyx.utils.logger import setup_logger logger = setup_logger() @@ -51,11 +51,8 @@ def client() -> Generator[TestClient, None, None]: # Patch out prometheus metrics setup to avoid "Duplicated timeseries in # CollectorRegistry" errors when multiple tests each create a new app # (prometheus registers metrics globally and rejects duplicate names). 
- get_app = fetch_versioned_implementation( - module="onyx.main", attribute="get_application" - ) with patch("onyx.main.setup_prometheus_metrics"): - app: FastAPI = get_app(lifespan_override=test_lifespan) + app: FastAPI = get_application(lifespan_override=test_lifespan) # Override the database session dependency with a mock # (these tests don't actually need DB access) diff --git a/backend/tests/daily/connectors/confluence/test_confluence_permissions_basic.py b/backend/tests/daily/connectors/confluence/test_confluence_permissions_basic.py index d47a98efd90..40edb036b24 100644 --- a/backend/tests/daily/connectors/confluence/test_confluence_permissions_basic.py +++ b/backend/tests/daily/connectors/confluence/test_confluence_permissions_basic.py @@ -45,7 +45,7 @@ def confluence_connector() -> ConfluenceConnector: def test_confluence_connector_permissions( mock_get_api_key: MagicMock, # noqa: ARG001 confluence_connector: ConfluenceConnector, - set_ee_on: None, # noqa: ARG001 + enable_ee: None, # noqa: ARG001 ) -> None: # Get all doc IDs from the full connector all_full_doc_ids = set() @@ -93,7 +93,7 @@ def test_confluence_connector_permissions( def test_confluence_connector_restriction_handling( mock_get_api_key: MagicMock, # noqa: ARG001 mock_db_provider_class: MagicMock, - set_ee_on: None, # noqa: ARG001 + enable_ee: None, # noqa: ARG001 ) -> None: # Test space key test_space_key = "DailyPermS" diff --git a/backend/tests/daily/connectors/conftest.py b/backend/tests/daily/connectors/conftest.py index ad702e4cfcd..88a00b57af1 100644 --- a/backend/tests/daily/connectors/conftest.py +++ b/backend/tests/daily/connectors/conftest.py @@ -4,8 +4,6 @@ import pytest -from onyx.utils.variable_functionality import global_version - @pytest.fixture def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]: @@ -14,14 +12,3 @@ def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]: return_value=None, ) as mock: yield mock - - -@pytest.fixture -def 
set_ee_on() -> Generator[None, None, None]: - """Need EE to be enabled for these tests to work since - perm syncing is a an EE-only feature.""" - global_version.set_ee() - - yield - - global_version._is_ee = False diff --git a/backend/tests/daily/connectors/gitlab/test_gitlab_basic.py b/backend/tests/daily/connectors/gitlab/test_gitlab_basic.py index 6b1e515b169..781c396f4cb 100644 --- a/backend/tests/daily/connectors/gitlab/test_gitlab_basic.py +++ b/backend/tests/daily/connectors/gitlab/test_gitlab_basic.py @@ -48,7 +48,7 @@ def test_gitlab_connector_basic(gitlab_connector: GitlabConnector) -> None: # --- Specific Document Details to Validate --- target_mr_id = f"https://{gitlab_base_url}/{project_path}/-/merge_requests/1" - target_issue_id = f"https://{gitlab_base_url}/{project_path}/-/issues/2" + target_issue_id = f"https://{gitlab_base_url}/{project_path}/-/work_items/2" target_code_file_semantic_id = "README.md" # --- diff --git a/backend/tests/daily/connectors/google_drive/test_drive_perm_sync.py b/backend/tests/daily/connectors/google_drive/test_drive_perm_sync.py index 5a8e9d6b2fe..0f5777e372c 100644 --- a/backend/tests/daily/connectors/google_drive/test_drive_perm_sync.py +++ b/backend/tests/daily/connectors/google_drive/test_drive_perm_sync.py @@ -98,7 +98,7 @@ def _build_connector( def test_gdrive_perm_sync_with_real_data( google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], - set_ee_on: None, # noqa: ARG001 + enable_ee: None, # noqa: ARG001 ) -> None: """ Test gdrive_doc_sync and gdrive_group_sync with real data from the test drive. 
diff --git a/backend/tests/daily/connectors/slack/test_slack_perm_sync.py b/backend/tests/daily/connectors/slack/test_slack_perm_sync.py index 619482ba964..d53404414ad 100644 --- a/backend/tests/daily/connectors/slack/test_slack_perm_sync.py +++ b/backend/tests/daily/connectors/slack/test_slack_perm_sync.py @@ -1,12 +1,10 @@ import time -from collections.abc import Generator import pytest from onyx.connectors.models import HierarchyNode from onyx.connectors.models import SlimDocument from onyx.connectors.slack.connector import SlackConnector -from onyx.utils.variable_functionality import global_version from tests.daily.connectors.utils import load_all_from_connector @@ -19,16 +17,7 @@ "test_user_2@onyx-test.com", ] - -@pytest.fixture(autouse=True) -def set_ee_on() -> Generator[None, None, None]: - """Need EE to be enabled for these tests to work since - perm syncing is a an EE-only feature.""" - global_version.set_ee() - - yield - - global_version._is_ee = False +pytestmark = pytest.mark.usefixtures("enable_ee") @pytest.mark.parametrize( diff --git a/backend/tests/daily/connectors/teams/test_teams_connector.py b/backend/tests/daily/connectors/teams/test_teams_connector.py index 2b5c1e62cbd..39d8f52fd84 100644 --- a/backend/tests/daily/connectors/teams/test_teams_connector.py +++ b/backend/tests/daily/connectors/teams/test_teams_connector.py @@ -1,13 +1,11 @@ import os import time -from collections.abc import Generator import pytest from onyx.access.models import ExternalAccess from onyx.connectors.models import HierarchyNode from onyx.connectors.teams.connector import TeamsConnector -from onyx.utils.variable_functionality import global_version from tests.daily.connectors.teams.models import TeamsThread from tests.daily.connectors.utils import load_all_from_connector @@ -168,18 +166,9 @@ def test_slim_docs_retrieval_from_teams_connector( _assert_is_valid_external_access(external_access=slim_doc.external_access) -@pytest.fixture(autouse=False) -def set_ee_on() -> 
Generator[None, None, None]: - """Need EE to be enabled for perm sync tests to work since - perm syncing is an EE-only feature.""" - global_version.set_ee() - yield - global_version._is_ee = False - - def test_load_from_checkpoint_with_perm_sync( teams_connector: TeamsConnector, - set_ee_on: None, # noqa: ARG001 + enable_ee: None, # noqa: ARG001 ) -> None: """Test that load_from_checkpoint_with_perm_sync returns documents with external_access. diff --git a/backend/tests/external_dependency_unit/background/test_periodic_task_claim.py b/backend/tests/external_dependency_unit/background/test_periodic_task_claim.py new file mode 100644 index 00000000000..de5fc44edf4 --- /dev/null +++ b/backend/tests/external_dependency_unit/background/test_periodic_task_claim.py @@ -0,0 +1,257 @@ +"""External dependency unit tests for periodic task claiming. + +Tests ``_try_claim_task`` and ``_try_run_periodic_task`` against real +PostgreSQL, verifying happy-path behavior and concurrent-access safety. + +The claim mechanism uses a transaction-scoped advisory lock + a KVStore +timestamp for cross-instance dedup. The DB session is released before +the task runs, so long-running tasks don't hold connections. 
+""" + +import time +from collections.abc import Generator +from concurrent.futures import as_completed +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from datetime import timedelta +from datetime import timezone +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest + +from onyx.background.periodic_poller import _PeriodicTaskDef +from onyx.background.periodic_poller import _try_claim_task +from onyx.background.periodic_poller import _try_run_periodic_task +from onyx.background.periodic_poller import PERIODIC_TASK_KV_PREFIX +from onyx.db.engine.sql_engine import get_session_with_current_tenant +from onyx.db.engine.sql_engine import SqlEngine +from onyx.db.models import KVStore +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR +from tests.external_dependency_unit.constants import TEST_TENANT_ID + +_TEST_LOCK_BASE = 90_000 + + +@pytest.fixture(scope="module", autouse=True) +def _init_engine() -> None: + SqlEngine.init_engine(pool_size=10, max_overflow=5) + + +def _make_task( + *, + name: str | None = None, + interval: float = 3600, + lock_id: int | None = None, + run_fn: MagicMock | None = None, +) -> _PeriodicTaskDef: + return _PeriodicTaskDef( + name=name if name is not None else f"test-{uuid4().hex[:8]}", + interval_seconds=interval, + lock_id=lock_id if lock_id is not None else _TEST_LOCK_BASE, + run_fn=run_fn if run_fn is not None else MagicMock(), + ) + + +@pytest.fixture(autouse=True) +def _cleanup_kv( + tenant_context: None, # noqa: ARG001 +) -> Generator[None, None, None]: + yield + with get_session_with_current_tenant() as db_session: + db_session.query(KVStore).filter( + KVStore.key.like(f"{PERIODIC_TASK_KV_PREFIX}test-%") + ).delete(synchronize_session=False) + db_session.commit() + + +# ------------------------------------------------------------------ +# Happy-path: _try_claim_task +# ------------------------------------------------------------------ + + +class 
TestClaimHappyPath: + def test_first_claim_succeeds(self) -> None: + assert _try_claim_task(_make_task()) is True + + def test_first_claim_creates_kv_row(self) -> None: + task = _make_task() + _try_claim_task(task) + + with get_session_with_current_tenant() as db_session: + row = ( + db_session.query(KVStore) + .filter_by(key=PERIODIC_TASK_KV_PREFIX + task.name) + .first() + ) + assert row is not None + assert row.value is not None + + def test_second_claim_within_interval_fails(self) -> None: + task = _make_task(interval=3600) + assert _try_claim_task(task) is True + assert _try_claim_task(task) is False + + def test_claim_after_interval_succeeds(self) -> None: + task = _make_task(interval=1) + assert _try_claim_task(task) is True + + kv_key = PERIODIC_TASK_KV_PREFIX + task.name + with get_session_with_current_tenant() as db_session: + row = db_session.query(KVStore).filter_by(key=kv_key).first() + assert row is not None + row.value = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat() + db_session.commit() + + assert _try_claim_task(task) is True + + +# ------------------------------------------------------------------ +# Happy-path: _try_run_periodic_task +# ------------------------------------------------------------------ + + +class TestRunHappyPath: + def test_runs_task_and_updates_last_run_at(self) -> None: + mock_fn = MagicMock() + task = _make_task(run_fn=mock_fn) + + _try_run_periodic_task(task) + + mock_fn.assert_called_once() + assert task.last_run_at > 0 + + def test_skips_when_in_memory_interval_not_elapsed(self) -> None: + mock_fn = MagicMock() + task = _make_task(run_fn=mock_fn, interval=3600) + task.last_run_at = time.monotonic() + + _try_run_periodic_task(task) + + mock_fn.assert_not_called() + + def test_skips_when_db_claim_blocked(self) -> None: + name = f"test-{uuid4().hex[:8]}" + lock_id = _TEST_LOCK_BASE + 10 + + _try_claim_task(_make_task(name=name, lock_id=lock_id, interval=3600)) + + mock_fn = MagicMock() + task = 
_make_task(name=name, lock_id=lock_id, interval=3600, run_fn=mock_fn) + _try_run_periodic_task(task) + + mock_fn.assert_not_called() + + def test_task_exception_does_not_propagate(self) -> None: + task = _make_task(run_fn=MagicMock(side_effect=RuntimeError("boom"))) + _try_run_periodic_task(task) + + def test_claim_committed_before_task_runs(self) -> None: + """The KV claim must be visible in the DB when run_fn executes.""" + task_name = f"test-order-{uuid4().hex[:8]}" + kv_key = PERIODIC_TASK_KV_PREFIX + task_name + claim_visible: list[bool] = [] + + def check_claim() -> None: + with get_session_with_current_tenant() as db_session: + row = db_session.query(KVStore).filter_by(key=kv_key).first() + claim_visible.append(row is not None and row.value is not None) + + task = _PeriodicTaskDef( + name=task_name, + interval_seconds=3600, + lock_id=_TEST_LOCK_BASE + 11, + run_fn=check_claim, + ) + + _try_run_periodic_task(task) + + assert claim_visible == [True] + + +# ------------------------------------------------------------------ +# Concurrency: only one claimer should win +# ------------------------------------------------------------------ + + +class TestClaimConcurrency: + def test_concurrent_claims_single_winner(self) -> None: + """Many threads claim the same task — exactly one should succeed.""" + num_threads = 20 + task_name = f"test-race-{uuid4().hex[:8]}" + lock_id = _TEST_LOCK_BASE + 20 + + def claim() -> bool: + CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID) + return _try_claim_task( + _PeriodicTaskDef( + name=task_name, + interval_seconds=3600, + lock_id=lock_id, + run_fn=lambda: None, + ) + ) + + results: list[bool] = [] + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(claim) for _ in range(num_threads)] + for future in as_completed(futures): + results.append(future.result()) + + winners = sum(1 for r in results if r) + assert winners == 1, f"Expected 1 winner, got {winners}" + + def 
test_concurrent_run_single_execution(self) -> None: + """Many threads run the same task — run_fn fires exactly once.""" + num_threads = 20 + task_name = f"test-run-race-{uuid4().hex[:8]}" + lock_id = _TEST_LOCK_BASE + 21 + counter = MagicMock() + + def run() -> None: + CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID) + _try_run_periodic_task( + _PeriodicTaskDef( + name=task_name, + interval_seconds=3600, + lock_id=lock_id, + run_fn=counter, + ) + ) + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(run) for _ in range(num_threads)] + for future in as_completed(futures): + future.result() + + assert ( + counter.call_count == 1 + ), f"Expected run_fn called once, got {counter.call_count}" + + def test_no_errors_under_contention(self) -> None: + """All threads complete without exceptions under high contention.""" + num_threads = 30 + task_name = f"test-err-{uuid4().hex[:8]}" + lock_id = _TEST_LOCK_BASE + 22 + errors: list[Exception] = [] + + def claim() -> bool: + CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID) + return _try_claim_task( + _PeriodicTaskDef( + name=task_name, + interval_seconds=3600, + lock_id=lock_id, + run_fn=lambda: None, + ) + ) + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(claim) for _ in range(num_threads)] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + errors.append(e) + + assert errors == [], f"Got {len(errors)} errors: {errors}" diff --git a/backend/tests/external_dependency_unit/background/test_startup_recovery.py b/backend/tests/external_dependency_unit/background/test_startup_recovery.py new file mode 100644 index 00000000000..1769b92f3e7 --- /dev/null +++ b/backend/tests/external_dependency_unit/background/test_startup_recovery.py @@ -0,0 +1,352 @@ +"""External dependency unit tests for startup recovery (Step 10g). 
+ +Seeds ``UserFile`` records in stuck states (PROCESSING, DELETING, +needs_project_sync) then calls ``recover_stuck_user_files`` and verifies +the drain loops pick them up via ``FOR UPDATE SKIP LOCKED``. + +Uses real PostgreSQL (via ``db_session`` / ``tenant_context`` fixtures). +The per-file ``*_impl`` functions are mocked so no real file store or +connector is needed — we only verify that recovery finds and dispatches +the correct files. +""" + +from collections.abc import Generator +from unittest.mock import MagicMock +from unittest.mock import patch +from uuid import UUID +from uuid import uuid4 + +import pytest +import sqlalchemy as sa +from sqlalchemy.orm import Session + +from onyx.background.periodic_poller import recover_stuck_user_files +from onyx.db.enums import UserFileStatus +from onyx.db.models import UserFile +from tests.external_dependency_unit.conftest import create_test_user +from tests.external_dependency_unit.constants import TEST_TENANT_ID + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_IMPL_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks" + + +def _create_user_file( + db_session: Session, + user_id: object, + *, + status: UserFileStatus = UserFileStatus.PROCESSING, + needs_project_sync: bool = False, + needs_persona_sync: bool = False, +) -> UserFile: + uf = UserFile( + id=uuid4(), + user_id=user_id, + file_id=f"test_file_{uuid4().hex[:8]}", + name=f"test_{uuid4().hex[:8]}.txt", + file_type="text/plain", + status=status, + needs_project_sync=needs_project_sync, + needs_persona_sync=needs_persona_sync, + ) + db_session.add(uf) + db_session.commit() + db_session.refresh(uf) + return uf + + +def _fake_delete_impl( + user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001 +) -> None: + """Mock side-effect: delete the row so the drain loop terminates.""" + from onyx.db.engine.sql_engine 
import get_session_with_current_tenant + + with get_session_with_current_tenant() as session: + session.execute(sa.delete(UserFile).where(UserFile.id == UUID(user_file_id))) + session.commit() + + +def _fake_sync_impl( + user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001 +) -> None: + """Mock side-effect: clear sync flags so the drain loop terminates.""" + from onyx.db.engine.sql_engine import get_session_with_current_tenant + + with get_session_with_current_tenant() as session: + session.execute( + sa.update(UserFile) + .where(UserFile.id == UUID(user_file_id)) + .values(needs_project_sync=False, needs_persona_sync=False) + ) + session.commit() + + +@pytest.fixture() +def _cleanup_user_files(db_session: Session) -> Generator[list[UserFile], None, None]: + """Track created UserFile rows and delete them after each test.""" + created: list[UserFile] = [] + yield created + for uf in created: + existing = db_session.get(UserFile, uf.id) + if existing: + db_session.delete(existing) + db_session.commit() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestRecoverProcessingFiles: + """Files in PROCESSING status are re-processed via the processing drain loop.""" + + def test_processing_files_recovered( + self, + db_session: Session, + tenant_context: None, # noqa: ARG002 + _cleanup_user_files: list[UserFile], + ) -> None: + user = create_test_user(db_session, "recovery_proc") + uf = _create_user_file(db_session, user.id, status=UserFileStatus.PROCESSING) + _cleanup_user_files.append(uf) + + mock_impl = MagicMock() + with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl): + recover_stuck_user_files(TEST_TENANT_ID) + + called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list] + assert ( + str(uf.id) in called_ids + ), f"Expected file {uf.id} to be recovered but got: {called_ids}" + + def 
test_completed_files_not_recovered( + self, + db_session: Session, + tenant_context: None, # noqa: ARG002 + _cleanup_user_files: list[UserFile], + ) -> None: + user = create_test_user(db_session, "recovery_comp") + uf = _create_user_file(db_session, user.id, status=UserFileStatus.COMPLETED) + _cleanup_user_files.append(uf) + + mock_impl = MagicMock() + with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl): + recover_stuck_user_files(TEST_TENANT_ID) + + called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list] + assert ( + str(uf.id) not in called_ids + ), f"COMPLETED file {uf.id} should not have been recovered" + + +class TestRecoverDeletingFiles: + """Files in DELETING status are recovered via the delete drain loop.""" + + def test_deleting_files_recovered( + self, + db_session: Session, + tenant_context: None, # noqa: ARG002 + _cleanup_user_files: list[UserFile], + ) -> None: + user = create_test_user(db_session, "recovery_del") + uf = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING) + # Row is deleted by _fake_delete_impl, so no cleanup needed. 
+ + mock_impl = MagicMock(side_effect=_fake_delete_impl) + with patch(f"{_IMPL_MODULE}.delete_user_file_impl", mock_impl): + recover_stuck_user_files(TEST_TENANT_ID) + + called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list] + assert ( + str(uf.id) in called_ids + ), f"Expected file {uf.id} to be recovered for deletion but got: {called_ids}" + + +class TestRecoverSyncFiles: + """Files needing project/persona sync are recovered via the sync drain loop.""" + + def test_needs_project_sync_recovered( + self, + db_session: Session, + tenant_context: None, # noqa: ARG002 + _cleanup_user_files: list[UserFile], + ) -> None: + user = create_test_user(db_session, "recovery_sync") + uf = _create_user_file( + db_session, + user.id, + status=UserFileStatus.COMPLETED, + needs_project_sync=True, + ) + _cleanup_user_files.append(uf) + + mock_impl = MagicMock(side_effect=_fake_sync_impl) + with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl): + recover_stuck_user_files(TEST_TENANT_ID) + + called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list] + assert ( + str(uf.id) in called_ids + ), f"Expected file {uf.id} to be recovered for sync but got: {called_ids}" + + def test_needs_persona_sync_recovered( + self, + db_session: Session, + tenant_context: None, # noqa: ARG002 + _cleanup_user_files: list[UserFile], + ) -> None: + user = create_test_user(db_session, "recovery_psync") + uf = _create_user_file( + db_session, + user.id, + status=UserFileStatus.COMPLETED, + needs_persona_sync=True, + ) + _cleanup_user_files.append(uf) + + mock_impl = MagicMock(side_effect=_fake_sync_impl) + with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl): + recover_stuck_user_files(TEST_TENANT_ID) + + called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list] + assert ( + str(uf.id) in called_ids + ), f"Expected file {uf.id} to be recovered for persona sync but got: {called_ids}" + + +class 
TestRecoveryMultipleFiles: + """Recovery processes all stuck files in one pass, not just the first.""" + + def test_multiple_processing_files( + self, + db_session: Session, + tenant_context: None, # noqa: ARG002 + _cleanup_user_files: list[UserFile], + ) -> None: + user = create_test_user(db_session, "recovery_multi") + files = [] + for _ in range(3): + uf = _create_user_file( + db_session, user.id, status=UserFileStatus.PROCESSING + ) + _cleanup_user_files.append(uf) + files.append(uf) + + mock_impl = MagicMock() + with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl): + recover_stuck_user_files(TEST_TENANT_ID) + + called_ids = {call.kwargs["user_file_id"] for call in mock_impl.call_args_list} + expected_ids = {str(uf.id) for uf in files} + assert expected_ids.issubset(called_ids), ( + f"Expected all {len(files)} files to be recovered. " + f"Missing: {expected_ids - called_ids}" + ) + + +class TestTransientFailures: + """Drain loops skip failed files, process the rest, and terminate.""" + + def test_processing_failure_skips_and_continues( + self, + db_session: Session, + tenant_context: None, # noqa: ARG002 + _cleanup_user_files: list[UserFile], + ) -> None: + user = create_test_user(db_session, "fail_proc") + uf_fail = _create_user_file( + db_session, user.id, status=UserFileStatus.PROCESSING + ) + uf_ok = _create_user_file(db_session, user.id, status=UserFileStatus.PROCESSING) + _cleanup_user_files.extend([uf_fail, uf_ok]) + + fail_id = str(uf_fail.id) + + def side_effect( + *, user_file_id: str, tenant_id: str, redis_locking: bool # noqa: ARG001 + ) -> None: + if user_file_id == fail_id: + raise RuntimeError("transient failure") + + mock_impl = MagicMock(side_effect=side_effect) + with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl): + recover_stuck_user_files(TEST_TENANT_ID) + + called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list] + assert fail_id in called_ids, "Failed file should have been attempted" + assert 
str(uf_ok.id) in called_ids, "Healthy file should have been processed" + assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop" + assert called_ids.count(str(uf_ok.id)) == 1 + + def test_delete_failure_skips_and_continues( + self, + db_session: Session, + tenant_context: None, # noqa: ARG002 + _cleanup_user_files: list[UserFile], + ) -> None: + user = create_test_user(db_session, "fail_del") + uf_fail = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING) + uf_ok = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING) + _cleanup_user_files.append(uf_fail) + + fail_id = str(uf_fail.id) + + def side_effect( + *, user_file_id: str, tenant_id: str, redis_locking: bool + ) -> None: + if user_file_id == fail_id: + raise RuntimeError("transient failure") + _fake_delete_impl(user_file_id, tenant_id, redis_locking) + + mock_impl = MagicMock(side_effect=side_effect) + with patch(f"{_IMPL_MODULE}.delete_user_file_impl", mock_impl): + recover_stuck_user_files(TEST_TENANT_ID) + + called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list] + assert fail_id in called_ids, "Failed file should have been attempted" + assert str(uf_ok.id) in called_ids, "Healthy file should have been deleted" + assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop" + assert called_ids.count(str(uf_ok.id)) == 1 + + def test_sync_failure_skips_and_continues( + self, + db_session: Session, + tenant_context: None, # noqa: ARG002 + _cleanup_user_files: list[UserFile], + ) -> None: + user = create_test_user(db_session, "fail_sync") + uf_fail = _create_user_file( + db_session, + user.id, + status=UserFileStatus.COMPLETED, + needs_project_sync=True, + ) + uf_ok = _create_user_file( + db_session, + user.id, + status=UserFileStatus.COMPLETED, + needs_persona_sync=True, + ) + _cleanup_user_files.extend([uf_fail, uf_ok]) + + fail_id = str(uf_fail.id) + + def side_effect( + *, user_file_id: str, tenant_id: str, 
redis_locking: bool + ) -> None: + if user_file_id == fail_id: + raise RuntimeError("transient failure") + _fake_sync_impl(user_file_id, tenant_id, redis_locking) + + mock_impl = MagicMock(side_effect=side_effect) + with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl): + recover_stuck_user_files(TEST_TENANT_ID) + + called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list] + assert fail_id in called_ids, "Failed file should have been attempted" + assert str(uf_ok.id) in called_ids, "Healthy file should have been synced" + assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop" + assert called_ids.count(str(uf_ok.id)) == 1 diff --git a/backend/tests/external_dependency_unit/cache/conftest.py b/backend/tests/external_dependency_unit/cache/conftest.py new file mode 100644 index 00000000000..0bbdfd5139d --- /dev/null +++ b/backend/tests/external_dependency_unit/cache/conftest.py @@ -0,0 +1,57 @@ +"""Fixtures for cache backend tests. + +Requires a running PostgreSQL instance (and Redis for parity tests). +Run with:: + + python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/cache/ +""" + +from collections.abc import Generator + +import pytest + +from onyx.cache.interface import CacheBackend +from onyx.cache.postgres_backend import PostgresCacheBackend +from onyx.cache.redis_backend import RedisCacheBackend +from onyx.db.engine.sql_engine import SqlEngine +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR +from tests.external_dependency_unit.constants import TEST_TENANT_ID + + +@pytest.fixture(scope="session", autouse=True) +def _init_db() -> Generator[None, None, None]: + """Initialize DB engine. Assumes Postgres has migrations applied (e.g. 
via docker compose).""" + SqlEngine.init_engine(pool_size=5, max_overflow=2) + yield + + +@pytest.fixture(autouse=True) +def _tenant_context() -> Generator[None, None, None]: + token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID) + try: + yield + finally: + CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + + +@pytest.fixture +def pg_cache() -> PostgresCacheBackend: + return PostgresCacheBackend(TEST_TENANT_ID) + + +@pytest.fixture +def redis_cache() -> RedisCacheBackend: + from onyx.redis.redis_pool import redis_pool + + return RedisCacheBackend(redis_pool.get_client(TEST_TENANT_ID)) + + +@pytest.fixture(params=["postgres", "redis"], ids=["postgres", "redis"]) +def cache( + request: pytest.FixtureRequest, + pg_cache: PostgresCacheBackend, + redis_cache: RedisCacheBackend, +) -> CacheBackend: + if request.param == "postgres": + return pg_cache + return redis_cache diff --git a/backend/tests/external_dependency_unit/cache/test_cache_backend_parity.py b/backend/tests/external_dependency_unit/cache/test_cache_backend_parity.py new file mode 100644 index 00000000000..7f1f0164d2c --- /dev/null +++ b/backend/tests/external_dependency_unit/cache/test_cache_backend_parity.py @@ -0,0 +1,100 @@ +"""Parameterized tests that run the same CacheBackend operations against +both Redis and PostgreSQL, asserting identical return values. + +Each test runs twice (once per backend) via the ``cache`` fixture defined +in conftest.py. 
+""" + +import time +from uuid import uuid4 + +from onyx.cache.interface import CacheBackend +from onyx.cache.interface import TTL_KEY_NOT_FOUND +from onyx.cache.interface import TTL_NO_EXPIRY + + +def _key() -> str: + return f"parity_{uuid4().hex[:12]}" + + +class TestKVParity: + def test_get_missing(self, cache: CacheBackend) -> None: + assert cache.get(_key()) is None + + def test_get_set(self, cache: CacheBackend) -> None: + k = _key() + cache.set(k, b"value") + assert cache.get(k) == b"value" + + def test_overwrite(self, cache: CacheBackend) -> None: + k = _key() + cache.set(k, b"a") + cache.set(k, b"b") + assert cache.get(k) == b"b" + + def test_set_string(self, cache: CacheBackend) -> None: + k = _key() + cache.set(k, "hello") + assert cache.get(k) == b"hello" + + def test_set_int(self, cache: CacheBackend) -> None: + k = _key() + cache.set(k, 42) + assert cache.get(k) == b"42" + + def test_delete(self, cache: CacheBackend) -> None: + k = _key() + cache.set(k, b"x") + cache.delete(k) + assert cache.get(k) is None + + def test_exists(self, cache: CacheBackend) -> None: + k = _key() + assert not cache.exists(k) + cache.set(k, b"x") + assert cache.exists(k) + + +class TestTTLParity: + def test_ttl_missing(self, cache: CacheBackend) -> None: + assert cache.ttl(_key()) == TTL_KEY_NOT_FOUND + + def test_ttl_no_expiry(self, cache: CacheBackend) -> None: + k = _key() + cache.set(k, b"x") + assert cache.ttl(k) == TTL_NO_EXPIRY + + def test_ttl_remaining(self, cache: CacheBackend) -> None: + k = _key() + cache.set(k, b"x", ex=10) + remaining = cache.ttl(k) + assert 8 <= remaining <= 10 + + def test_set_with_ttl_expires(self, cache: CacheBackend) -> None: + k = _key() + cache.set(k, b"x", ex=1) + assert cache.get(k) == b"x" + time.sleep(1.5) + assert cache.get(k) is None + + +class TestLockParity: + def test_acquire_release(self, cache: CacheBackend) -> None: + lock = cache.lock(f"parity_lock_{uuid4().hex[:8]}") + assert lock.acquire(blocking=False) + assert 
lock.owned() + lock.release() + assert not lock.owned() + + +class TestListParity: + def test_rpush_blpop(self, cache: CacheBackend) -> None: + k = f"parity_list_{uuid4().hex[:8]}" + cache.rpush(k, b"item") + result = cache.blpop([k], timeout=1) + assert result is not None + assert result[1] == b"item" + + def test_blpop_timeout(self, cache: CacheBackend) -> None: + result = cache.blpop([f"parity_empty_{uuid4().hex[:8]}"], timeout=1) + assert result is None diff --git a/backend/tests/external_dependency_unit/cache/test_kv_store_cache_layer.py b/backend/tests/external_dependency_unit/cache/test_kv_store_cache_layer.py new file mode 100644 index 00000000000..d74efcc1157 --- /dev/null +++ b/backend/tests/external_dependency_unit/cache/test_kv_store_cache_layer.py @@ -0,0 +1,129 @@ +"""Tests for PgRedisKVStore's cache layer integration with CacheBackend. + +Verifies that the KV store correctly uses the CacheBackend for caching +in front of PostgreSQL: cache hits, cache misses falling through to PG, +cache population after PG reads, cache invalidation on delete, and +graceful degradation when the cache backend raises. + +Requires running PostgreSQL. 
+""" + +import json +from collections.abc import Generator +from unittest.mock import MagicMock + +import pytest +from sqlalchemy import delete + +from onyx.cache.interface import CacheBackend +from onyx.cache.postgres_backend import PostgresCacheBackend +from onyx.db.engine.sql_engine import get_session_with_tenant +from onyx.db.models import CacheStore +from onyx.db.models import KVStore +from onyx.key_value_store.interface import KvKeyNotFoundError +from onyx.key_value_store.store import PgRedisKVStore +from onyx.key_value_store.store import REDIS_KEY_PREFIX +from tests.external_dependency_unit.constants import TEST_TENANT_ID + + +@pytest.fixture(autouse=True) +def _clean_kv() -> Generator[None, None, None]: + yield + with get_session_with_tenant(tenant_id=TEST_TENANT_ID) as session: + session.execute(delete(KVStore)) + session.execute(delete(CacheStore)) + session.commit() + + +@pytest.fixture +def kv_store(pg_cache: PostgresCacheBackend) -> PgRedisKVStore: + return PgRedisKVStore(cache=pg_cache) + + +class TestStoreAndLoad: + def test_store_populates_cache_and_pg( + self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend + ) -> None: + kv_store.store("k1", {"hello": "world"}) + + cached = pg_cache.get(REDIS_KEY_PREFIX + "k1") + assert cached is not None + assert json.loads(cached) == {"hello": "world"} + + loaded = kv_store.load("k1") + assert loaded == {"hello": "world"} + + def test_load_returns_cached_value_without_pg_hit( + self, pg_cache: PostgresCacheBackend + ) -> None: + """If the cache already has the value, PG should not be queried.""" + pg_cache.set(REDIS_KEY_PREFIX + "cached_only", json.dumps({"from": "cache"})) + kv = PgRedisKVStore(cache=pg_cache) + assert kv.load("cached_only") == {"from": "cache"} + + def test_load_falls_through_to_pg_on_cache_miss( + self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend + ) -> None: + kv_store.store("k2", [1, 2, 3]) + + pg_cache.delete(REDIS_KEY_PREFIX + "k2") + assert 
pg_cache.get(REDIS_KEY_PREFIX + "k2") is None + + loaded = kv_store.load("k2") + assert loaded == [1, 2, 3] + + repopulated = pg_cache.get(REDIS_KEY_PREFIX + "k2") + assert repopulated is not None + assert json.loads(repopulated) == [1, 2, 3] + + def test_load_with_refresh_cache_skips_cache( + self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend + ) -> None: + kv_store.store("k3", "original") + + pg_cache.set(REDIS_KEY_PREFIX + "k3", json.dumps("stale")) + + loaded = kv_store.load("k3", refresh_cache=True) + assert loaded == "original" + + +class TestDelete: + def test_delete_removes_from_cache_and_pg( + self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend + ) -> None: + kv_store.store("del_me", "bye") + kv_store.delete("del_me") + + assert pg_cache.get(REDIS_KEY_PREFIX + "del_me") is None + + with pytest.raises(KvKeyNotFoundError): + kv_store.load("del_me") + + def test_delete_missing_key_raises(self, kv_store: PgRedisKVStore) -> None: + with pytest.raises(KvKeyNotFoundError): + kv_store.delete("nonexistent") + + +class TestCacheFailureGracefulDegradation: + def test_store_succeeds_when_cache_set_raises(self) -> None: + failing_cache = MagicMock(spec=CacheBackend) + failing_cache.set.side_effect = ConnectionError("cache down") + + kv = PgRedisKVStore(cache=failing_cache) + kv.store("resilient", {"data": True}) + + working_cache = MagicMock(spec=CacheBackend) + working_cache.get.return_value = None + kv_reader = PgRedisKVStore(cache=working_cache) + loaded = kv_reader.load("resilient") + assert loaded == {"data": True} + + def test_load_falls_through_when_cache_get_raises(self) -> None: + failing_cache = MagicMock(spec=CacheBackend) + failing_cache.get.side_effect = ConnectionError("cache down") + failing_cache.set.side_effect = ConnectionError("cache down") + + kv = PgRedisKVStore(cache=failing_cache) + kv.store("survive", 42) + loaded = kv.load("survive") + assert loaded == 42 diff --git 
a/backend/tests/external_dependency_unit/cache/test_postgres_cache_backend.py b/backend/tests/external_dependency_unit/cache/test_postgres_cache_backend.py new file mode 100644 index 00000000000..54975c54a5b --- /dev/null +++ b/backend/tests/external_dependency_unit/cache/test_postgres_cache_backend.py @@ -0,0 +1,229 @@ +"""Tests for PostgresCacheBackend against real PostgreSQL. + +Covers every method on the backend: KV CRUD, TTL behaviour, advisory +locks (acquire / release / contention), list operations (rpush / blpop), +and the periodic cleanup function. +""" + +import time +from uuid import uuid4 + +from sqlalchemy import select + +from onyx.cache.interface import TTL_KEY_NOT_FOUND +from onyx.cache.interface import TTL_NO_EXPIRY +from onyx.cache.postgres_backend import cleanup_expired_cache_entries +from onyx.cache.postgres_backend import PostgresCacheBackend +from onyx.db.models import CacheStore + + +def _key() -> str: + return f"test_{uuid4().hex[:12]}" + + +# ------------------------------------------------------------------ +# Basic KV +# ------------------------------------------------------------------ + + +class TestKV: + def test_get_set(self, pg_cache: PostgresCacheBackend) -> None: + k = _key() + pg_cache.set(k, b"hello") + assert pg_cache.get(k) == b"hello" + + def test_get_missing(self, pg_cache: PostgresCacheBackend) -> None: + assert pg_cache.get(_key()) is None + + def test_set_overwrite(self, pg_cache: PostgresCacheBackend) -> None: + k = _key() + pg_cache.set(k, b"first") + pg_cache.set(k, b"second") + assert pg_cache.get(k) == b"second" + + def test_set_string_value(self, pg_cache: PostgresCacheBackend) -> None: + k = _key() + pg_cache.set(k, "string_val") + assert pg_cache.get(k) == b"string_val" + + def test_set_int_value(self, pg_cache: PostgresCacheBackend) -> None: + k = _key() + pg_cache.set(k, 42) + assert pg_cache.get(k) == b"42" + + def test_delete(self, pg_cache: PostgresCacheBackend) -> None: + k = _key() + pg_cache.set(k, 
b"to_delete") + pg_cache.delete(k) + assert pg_cache.get(k) is None + + def test_delete_missing_is_noop(self, pg_cache: PostgresCacheBackend) -> None: + pg_cache.delete(_key()) + + def test_exists(self, pg_cache: PostgresCacheBackend) -> None: + k = _key() + assert not pg_cache.exists(k) + pg_cache.set(k, b"x") + assert pg_cache.exists(k) + + +# ------------------------------------------------------------------ +# TTL +# ------------------------------------------------------------------ + + +class TestTTL: + def test_set_with_ttl_expires(self, pg_cache: PostgresCacheBackend) -> None: + k = _key() + pg_cache.set(k, b"ephemeral", ex=1) + assert pg_cache.get(k) == b"ephemeral" + time.sleep(1.5) + assert pg_cache.get(k) is None + + def test_ttl_no_expiry(self, pg_cache: PostgresCacheBackend) -> None: + k = _key() + pg_cache.set(k, b"forever") + assert pg_cache.ttl(k) == TTL_NO_EXPIRY + + def test_ttl_missing_key(self, pg_cache: PostgresCacheBackend) -> None: + assert pg_cache.ttl(_key()) == TTL_KEY_NOT_FOUND + + def test_ttl_remaining(self, pg_cache: PostgresCacheBackend) -> None: + k = _key() + pg_cache.set(k, b"x", ex=10) + remaining = pg_cache.ttl(k) + assert 8 <= remaining <= 10 + + def test_ttl_expired_key(self, pg_cache: PostgresCacheBackend) -> None: + k = _key() + pg_cache.set(k, b"x", ex=1) + time.sleep(1.5) + assert pg_cache.ttl(k) == TTL_KEY_NOT_FOUND + + def test_expire_adds_ttl(self, pg_cache: PostgresCacheBackend) -> None: + k = _key() + pg_cache.set(k, b"x") + assert pg_cache.ttl(k) == TTL_NO_EXPIRY + pg_cache.expire(k, 10) + assert 8 <= pg_cache.ttl(k) <= 10 + + def test_exists_respects_ttl(self, pg_cache: PostgresCacheBackend) -> None: + k = _key() + pg_cache.set(k, b"x", ex=1) + assert pg_cache.exists(k) + time.sleep(1.5) + assert not pg_cache.exists(k) + + +# ------------------------------------------------------------------ +# Locks +# ------------------------------------------------------------------ + + +class TestLock: + def 
test_acquire_release(self, pg_cache: PostgresCacheBackend) -> None: + lock = pg_cache.lock(f"lock_{uuid4().hex[:8]}") + assert lock.acquire(blocking=False) + assert lock.owned() + lock.release() + assert not lock.owned() + + def test_contention(self, pg_cache: PostgresCacheBackend) -> None: + name = f"contention_{uuid4().hex[:8]}" + lock1 = pg_cache.lock(name) + lock2 = pg_cache.lock(name) + + assert lock1.acquire(blocking=False) + assert not lock2.acquire(blocking=False) + + lock1.release() + assert lock2.acquire(blocking=False) + lock2.release() + + def test_context_manager(self, pg_cache: PostgresCacheBackend) -> None: + with pg_cache.lock(f"ctx_{uuid4().hex[:8]}") as lock: + assert lock.owned() + assert not lock.owned() + + def test_blocking_timeout(self, pg_cache: PostgresCacheBackend) -> None: + name = f"timeout_{uuid4().hex[:8]}" + holder = pg_cache.lock(name) + holder.acquire(blocking=False) + + waiter = pg_cache.lock(name, timeout=0.3) + start = time.monotonic() + assert not waiter.acquire(blocking=True, blocking_timeout=0.3) + elapsed = time.monotonic() - start + assert elapsed >= 0.25 + + holder.release() + + +# ------------------------------------------------------------------ +# List (rpush / blpop) +# ------------------------------------------------------------------ + + +class TestList: + def test_rpush_blpop(self, pg_cache: PostgresCacheBackend) -> None: + k = f"list_{uuid4().hex[:8]}" + pg_cache.rpush(k, b"item1") + result = pg_cache.blpop([k], timeout=1) + assert result is not None + assert result == (k.encode(), b"item1") + + def test_blpop_timeout(self, pg_cache: PostgresCacheBackend) -> None: + result = pg_cache.blpop([f"empty_{uuid4().hex[:8]}"], timeout=1) + assert result is None + + def test_fifo_order(self, pg_cache: PostgresCacheBackend) -> None: + k = f"fifo_{uuid4().hex[:8]}" + pg_cache.rpush(k, b"first") + time.sleep(0.01) + pg_cache.rpush(k, b"second") + + r1 = pg_cache.blpop([k], timeout=1) + r2 = pg_cache.blpop([k], timeout=1) + 
assert r1 is not None and r1[1] == b"first" + assert r2 is not None and r2[1] == b"second" + + def test_multiple_keys(self, pg_cache: PostgresCacheBackend) -> None: + k1 = f"mk1_{uuid4().hex[:8]}" + k2 = f"mk2_{uuid4().hex[:8]}" + pg_cache.rpush(k2, b"from_k2") + + result = pg_cache.blpop([k1, k2], timeout=1) + assert result is not None + assert result == (k2.encode(), b"from_k2") + + +# ------------------------------------------------------------------ +# Cleanup +# ------------------------------------------------------------------ + + +class TestCleanup: + def test_removes_expired_rows(self, pg_cache: PostgresCacheBackend) -> None: + from onyx.db.engine.sql_engine import get_session_with_current_tenant + + k = _key() + pg_cache.set(k, b"stale", ex=1) + time.sleep(1.5) + cleanup_expired_cache_entries() + + stmt = select(CacheStore.key).where(CacheStore.key == k) + with get_session_with_current_tenant() as session: + row = session.execute(stmt).first() + assert row is None, "expired row should be physically deleted" + + def test_preserves_unexpired_rows(self, pg_cache: PostgresCacheBackend) -> None: + k = _key() + pg_cache.set(k, b"fresh", ex=300) + cleanup_expired_cache_entries() + assert pg_cache.get(k) == b"fresh" + + def test_preserves_no_ttl_rows(self, pg_cache: PostgresCacheBackend) -> None: + k = _key() + pg_cache.set(k, b"permanent") + cleanup_expired_cache_entries() + assert pg_cache.get(k) == b"permanent" diff --git a/backend/tests/external_dependency_unit/celery/test_docprocessing_priority.py b/backend/tests/external_dependency_unit/celery/test_docprocessing_priority.py index 3fb0a64d00e..c8db850546a 100644 --- a/backend/tests/external_dependency_unit/celery/test_docprocessing_priority.py +++ b/backend/tests/external_dependency_unit/celery/test_docprocessing_priority.py @@ -145,6 +145,10 @@ class TestDocprocessingPriorityInDocumentExtraction: @patch("onyx.background.indexing.run_docfetching.get_document_batch_storage") 
@patch("onyx.background.indexing.run_docfetching.MemoryTracer") @patch("onyx.background.indexing.run_docfetching._get_connector_runner") + @patch( + "onyx.background.indexing.run_docfetching.strip_null_characters", + side_effect=lambda batch: batch, + ) @patch( "onyx.background.indexing.run_docfetching.get_recent_completed_attempts_for_cc_pair" ) @@ -169,6 +173,7 @@ def test_docprocessing_priority_based_on_last_successful_index_time( mock_save_checkpoint: MagicMock, # noqa: ARG002 mock_get_last_successful_attempt_poll_range_end: MagicMock, mock_get_recent_completed_attempts: MagicMock, + mock_strip_null_characters: MagicMock, # noqa: ARG002 mock_get_connector_runner: MagicMock, mock_memory_tracer_class: MagicMock, mock_get_batch_storage: MagicMock, diff --git a/backend/tests/external_dependency_unit/celery/test_pruning_hierarchy_nodes.py b/backend/tests/external_dependency_unit/celery/test_pruning_hierarchy_nodes.py index 4e40b8adbc3..4bba85cb23f 100644 --- a/backend/tests/external_dependency_unit/celery/test_pruning_hierarchy_nodes.py +++ b/backend/tests/external_dependency_unit/celery/test_pruning_hierarchy_nodes.py @@ -5,6 +5,10 @@ 1. extract_ids_from_runnable_connector correctly separates hierarchy nodes from doc IDs 2. Extracted hierarchy nodes are correctly upserted to Postgres via upsert_hierarchy_nodes_batch 3. Upserting is idempotent (running twice doesn't duplicate nodes) +4. Document-to-hierarchy-node linkage is updated during pruning +5. link_hierarchy_nodes_to_documents links nodes that are also documents +6. HierarchyNodeByConnectorCredentialPair join table population and pruning +7. Orphaned hierarchy node deletion and re-parenting Uses a mock SlimConnectorWithPermSync that yields known hierarchy nodes and slim documents, combined with a real PostgreSQL database for verifying persistence. 
@@ -22,14 +26,29 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnectorWithPermSync from onyx.connectors.models import HierarchyNode as PydanticHierarchyNode +from onyx.connectors.models import InputType from onyx.connectors.models import SlimDocument +from onyx.db.enums import AccessType +from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.enums import HierarchyNodeType +from onyx.db.hierarchy import delete_orphaned_hierarchy_nodes from onyx.db.hierarchy import ensure_source_node_exists from onyx.db.hierarchy import get_all_hierarchy_nodes_for_source from onyx.db.hierarchy import get_hierarchy_node_by_raw_id +from onyx.db.hierarchy import link_hierarchy_nodes_to_documents +from onyx.db.hierarchy import remove_stale_hierarchy_node_cc_pair_entries +from onyx.db.hierarchy import reparent_orphaned_hierarchy_nodes +from onyx.db.hierarchy import update_document_parent_hierarchy_nodes +from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries from onyx.db.hierarchy import upsert_hierarchy_nodes_batch +from onyx.db.models import Connector +from onyx.db.models import ConnectorCredentialPair +from onyx.db.models import Credential +from onyx.db.models import Document as DbDocument from onyx.db.models import HierarchyNode as DBHierarchyNode +from onyx.db.models import HierarchyNodeByConnectorCredentialPair from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface +from onyx.kg.models import KGStage # --------------------------------------------------------------------------- # Constants @@ -89,8 +108,18 @@ def _make_hierarchy_nodes() -> list[PydanticHierarchyNode]: ] +DOC_PARENT_MAP = { + "msg-001": CHANNEL_A_ID, + "msg-002": CHANNEL_A_ID, + "msg-003": CHANNEL_B_ID, +} + + def _make_slim_docs() -> list[SlimDocument | PydanticHierarchyNode]: - return [SlimDocument(id=doc_id) for doc_id in SLIM_DOC_IDS] + return [ + SlimDocument(id=doc_id, 
parent_hierarchy_raw_node_id=DOC_PARENT_MAP.get(doc_id)) + for doc_id in SLIM_DOC_IDS + ] class MockSlimConnectorWithPermSync(SlimConnectorWithPermSync): @@ -126,14 +155,98 @@ def _generate(self) -> Iterator[list[SlimDocument | PydanticHierarchyNode]]: # --------------------------------------------------------------------------- -def _cleanup_test_hierarchy_nodes(db_session: Session) -> None: - """Remove all hierarchy nodes for TEST_SOURCE to isolate tests.""" +def _create_cc_pair( + db_session: Session, + source: DocumentSource = TEST_SOURCE, +) -> ConnectorCredentialPair: + """Create a real Connector + Credential + ConnectorCredentialPair for testing.""" + connector = Connector( + name=f"Test {source.value} Connector", + source=source, + input_type=InputType.LOAD_STATE, + connector_specific_config={}, + ) + db_session.add(connector) + db_session.flush() + + credential = Credential( + source=source, + credential_json={}, + admin_public=True, + ) + db_session.add(credential) + db_session.flush() + db_session.expire(credential) + + cc_pair = ConnectorCredentialPair( + connector_id=connector.id, + credential_id=credential.id, + name=f"Test {source.value} CC Pair", + status=ConnectorCredentialPairStatus.ACTIVE, + access_type=AccessType.PUBLIC, + ) + db_session.add(cc_pair) + db_session.commit() + db_session.refresh(cc_pair) + return cc_pair + + +def _cleanup_test_data(db_session: Session) -> None: + """Remove all test hierarchy nodes and documents to isolate tests.""" + for doc_id in SLIM_DOC_IDS: + db_session.query(DbDocument).filter(DbDocument.id == doc_id).delete() + + test_connector_ids_q = db_session.query(Connector.id).filter( + Connector.source == TEST_SOURCE, + Connector.name.like("Test %"), + ) + + db_session.query(HierarchyNodeByConnectorCredentialPair).filter( + HierarchyNodeByConnectorCredentialPair.connector_id.in_(test_connector_ids_q) + ).delete(synchronize_session="fetch") db_session.query(DBHierarchyNode).filter( DBHierarchyNode.source == TEST_SOURCE 
).delete() + db_session.flush() + + # Collect credential IDs before deleting cc_pairs (bulk query.delete() + # bypasses ORM-level cascade, so credentials won't be auto-removed). + credential_ids = [ + row[0] + for row in db_session.query(ConnectorCredentialPair.credential_id) + .filter(ConnectorCredentialPair.connector_id.in_(test_connector_ids_q)) + .all() + ] + + db_session.query(ConnectorCredentialPair).filter( + ConnectorCredentialPair.connector_id.in_(test_connector_ids_q) + ).delete(synchronize_session="fetch") + db_session.query(Connector).filter( + Connector.source == TEST_SOURCE, + Connector.name.like("Test %"), + ).delete(synchronize_session="fetch") + if credential_ids: + db_session.query(Credential).filter(Credential.id.in_(credential_ids)).delete( + synchronize_session="fetch" + ) db_session.commit() +def _create_test_documents(db_session: Session) -> list[DbDocument]: + """Insert minimal Document rows for our test doc IDs.""" + docs = [] + for doc_id in SLIM_DOC_IDS: + doc = DbDocument( + id=doc_id, + semantic_id=doc_id, + kg_stage=KGStage.NOT_STARTED, + ) + db_session.add(doc) + docs.append(doc) + db_session.commit() + return docs + + # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- @@ -146,15 +259,8 @@ def test_pruning_extracts_hierarchy_nodes(db_session: Session) -> None: # noqa: result = extract_ids_from_runnable_connector(connector, callback=None) - # Doc IDs should include both slim doc IDs and hierarchy node raw_node_ids - # (hierarchy node IDs are added to doc_ids so they aren't pruned) - expected_ids = { - CHANNEL_A_ID, - CHANNEL_B_ID, - CHANNEL_C_ID, - *SLIM_DOC_IDS, - } - assert result.doc_ids == expected_ids + # raw_id_to_parent should contain ONLY document IDs, not hierarchy node IDs + assert result.raw_id_to_parent.keys() == set(SLIM_DOC_IDS) # Hierarchy nodes should be the 3 channels assert len(result.hierarchy_nodes) == 3 
@@ -165,7 +271,7 @@ def test_pruning_extracts_hierarchy_nodes(db_session: Session) -> None: # noqa: def test_pruning_upserts_hierarchy_nodes_to_db(db_session: Session) -> None: """Full flow: extract hierarchy nodes from mock connector, upsert to Postgres, then verify the DB state (node count, parent relationships, permissions).""" - _cleanup_test_hierarchy_nodes(db_session) + _cleanup_test_data(db_session) # Step 1: ensure the SOURCE node exists (mirrors what the pruning task does) source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True) @@ -230,7 +336,7 @@ def test_pruning_upserts_hierarchy_nodes_public_connector( ) -> None: """When the connector's access type is PUBLIC, all hierarchy nodes must be marked is_public=True regardless of their external_access settings.""" - _cleanup_test_hierarchy_nodes(db_session) + _cleanup_test_data(db_session) ensure_source_node_exists(db_session, TEST_SOURCE, commit=True) @@ -257,7 +363,7 @@ def test_pruning_upserts_hierarchy_nodes_public_connector( def test_pruning_hierarchy_node_upsert_idempotency(db_session: Session) -> None: """Upserting the same hierarchy nodes twice must not create duplicates. 
The second call should update existing rows in place.""" - _cleanup_test_hierarchy_nodes(db_session) + _cleanup_test_data(db_session) ensure_source_node_exists(db_session, TEST_SOURCE, commit=True) @@ -295,7 +401,7 @@ def test_pruning_hierarchy_node_upsert_idempotency(db_session: Session) -> None: def test_pruning_hierarchy_node_upsert_updates_fields(db_session: Session) -> None: """Upserting a hierarchy node with changed fields should update the existing row.""" - _cleanup_test_hierarchy_nodes(db_session) + _cleanup_test_data(db_session) ensure_source_node_exists(db_session, TEST_SOURCE, commit=True) @@ -342,3 +448,431 @@ def test_pruning_hierarchy_node_upsert_updates_fields(db_session: Session) -> No assert db_node.is_public is True assert db_node.external_user_emails is not None assert set(db_node.external_user_emails) == {"new_user@example.com"} + + +# --------------------------------------------------------------------------- +# Document-to-hierarchy-node linkage tests +# --------------------------------------------------------------------------- + + +def test_extraction_preserves_parent_hierarchy_raw_node_id( + db_session: Session, # noqa: ARG001 +) -> None: + """extract_ids_from_runnable_connector should carry the + parent_hierarchy_raw_node_id from SlimDocument into the raw_id_to_parent dict.""" + connector = MockSlimConnectorWithPermSync() + result = extract_ids_from_runnable_connector(connector, callback=None) + + for doc_id, expected_parent in DOC_PARENT_MAP.items(): + assert ( + result.raw_id_to_parent[doc_id] == expected_parent + ), f"raw_id_to_parent[{doc_id}] should be {expected_parent}" + + # Hierarchy node IDs should NOT be in raw_id_to_parent + for channel_id in [CHANNEL_A_ID, CHANNEL_B_ID, CHANNEL_C_ID]: + assert channel_id not in result.raw_id_to_parent + + +def test_update_document_parent_hierarchy_nodes(db_session: Session) -> None: + """update_document_parent_hierarchy_nodes should set + Document.parent_hierarchy_node_id for each document in 
the mapping.""" + _cleanup_test_data(db_session) + + source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True) + upserted = upsert_hierarchy_nodes_batch( + db_session=db_session, + nodes=_make_hierarchy_nodes(), + source=TEST_SOURCE, + commit=True, + is_connector_public=False, + ) + node_id_by_raw = {n.raw_node_id: n.id for n in upserted} + + # Create documents with no parent set + docs = _create_test_documents(db_session) + for doc in docs: + assert doc.parent_hierarchy_node_id is None + + # Build resolved map (same logic as _resolve_and_update_document_parents) + resolved: dict[str, int | None] = {} + for doc_id, raw_parent in DOC_PARENT_MAP.items(): + resolved[doc_id] = node_id_by_raw.get(raw_parent, source_node.id) + + updated = update_document_parent_hierarchy_nodes( + db_session=db_session, + doc_parent_map=resolved, + commit=True, + ) + assert updated == len(SLIM_DOC_IDS) + + # Verify each document now points to the correct hierarchy node + db_session.expire_all() + for doc_id, raw_parent in DOC_PARENT_MAP.items(): + tmp_doc = db_session.get(DbDocument, doc_id) + assert tmp_doc is not None + doc = tmp_doc + expected_node_id = node_id_by_raw[raw_parent] + assert ( + doc.parent_hierarchy_node_id == expected_node_id + ), f"Document {doc_id} should point to node for {raw_parent}" + + +def test_update_document_parent_is_idempotent(db_session: Session) -> None: + """Running update_document_parent_hierarchy_nodes a second time with the + same mapping should update zero rows.""" + _cleanup_test_data(db_session) + + ensure_source_node_exists(db_session, TEST_SOURCE, commit=True) + upserted = upsert_hierarchy_nodes_batch( + db_session=db_session, + nodes=_make_hierarchy_nodes(), + source=TEST_SOURCE, + commit=True, + is_connector_public=False, + ) + node_id_by_raw = {n.raw_node_id: n.id for n in upserted} + _create_test_documents(db_session) + + resolved: dict[str, int | None] = { + doc_id: node_id_by_raw[raw_parent] + for doc_id, raw_parent in 
DOC_PARENT_MAP.items() + } + + first_updated = update_document_parent_hierarchy_nodes( + db_session=db_session, + doc_parent_map=resolved, + commit=True, + ) + assert first_updated == len(SLIM_DOC_IDS) + + second_updated = update_document_parent_hierarchy_nodes( + db_session=db_session, + doc_parent_map=resolved, + commit=True, + ) + assert second_updated == 0 + + +def test_link_hierarchy_nodes_to_documents_for_confluence( + db_session: Session, +) -> None: + """For sources in SOURCES_WITH_HIERARCHY_NODE_DOCUMENTS (e.g. Confluence), + link_hierarchy_nodes_to_documents should set HierarchyNode.document_id + when a hierarchy node's raw_node_id matches a document ID.""" + _cleanup_test_data(db_session) + confluence_source = DocumentSource.CONFLUENCE + + # Clean up any existing Confluence hierarchy nodes + db_session.query(DBHierarchyNode).filter( + DBHierarchyNode.source == confluence_source + ).delete() + db_session.commit() + + ensure_source_node_exists(db_session, confluence_source, commit=True) + + # Create a hierarchy node whose raw_node_id matches a document ID + page_node_id = "confluence-page-123" + nodes = [ + PydanticHierarchyNode( + raw_node_id=page_node_id, + raw_parent_id=None, + display_name="Test Page", + link="https://wiki.example.com/page/123", + node_type=HierarchyNodeType.PAGE, + ), + ] + upsert_hierarchy_nodes_batch( + db_session=db_session, + nodes=nodes, + source=confluence_source, + commit=True, + is_connector_public=False, + ) + + # Verify the node exists but has no document_id yet + db_node = get_hierarchy_node_by_raw_id(db_session, page_node_id, confluence_source) + assert db_node is not None + assert db_node.document_id is None + + # Create a document with the same ID as the hierarchy node + doc = DbDocument( + id=page_node_id, + semantic_id="Test Page", + kg_stage=KGStage.NOT_STARTED, + ) + db_session.add(doc) + db_session.commit() + + # Link nodes to documents + linked = link_hierarchy_nodes_to_documents( + db_session=db_session, + 
document_ids=[page_node_id], + source=confluence_source, + commit=True, + ) + assert linked == 1 + + # Verify the hierarchy node now has document_id set + db_session.expire_all() + db_node = get_hierarchy_node_by_raw_id(db_session, page_node_id, confluence_source) + assert db_node is not None + assert db_node.document_id == page_node_id + + # Cleanup + db_session.query(DbDocument).filter(DbDocument.id == page_node_id).delete() + db_session.query(DBHierarchyNode).filter( + DBHierarchyNode.source == confluence_source + ).delete() + db_session.commit() + + +def test_link_hierarchy_nodes_skips_non_hierarchy_sources( + db_session: Session, +) -> None: + """link_hierarchy_nodes_to_documents should return 0 for sources that + don't support hierarchy-node-as-document (e.g. Slack, Google Drive).""" + linked = link_hierarchy_nodes_to_documents( + db_session=db_session, + document_ids=SLIM_DOC_IDS, + source=TEST_SOURCE, # Slack — not in SOURCES_WITH_HIERARCHY_NODE_DOCUMENTS + commit=False, + ) + assert linked == 0 + + +# --------------------------------------------------------------------------- +# Join table + pruning tests +# --------------------------------------------------------------------------- + + +def test_upsert_hierarchy_node_cc_pair_entries(db_session: Session) -> None: + """upsert_hierarchy_node_cc_pair_entries should insert rows and be idempotent.""" + _cleanup_test_data(db_session) + ensure_source_node_exists(db_session, TEST_SOURCE, commit=True) + cc_pair = _create_cc_pair(db_session) + + upserted = upsert_hierarchy_nodes_batch( + db_session=db_session, + nodes=_make_hierarchy_nodes(), + source=TEST_SOURCE, + commit=True, + is_connector_public=False, + ) + node_ids = [n.id for n in upserted] + + # First call — should insert rows + upsert_hierarchy_node_cc_pair_entries( + db_session=db_session, + hierarchy_node_ids=node_ids, + connector_id=cc_pair.connector_id, + credential_id=cc_pair.credential_id, + commit=True, + ) + + rows = ( + 
db_session.query(HierarchyNodeByConnectorCredentialPair) + .filter( + HierarchyNodeByConnectorCredentialPair.connector_id == cc_pair.connector_id, + HierarchyNodeByConnectorCredentialPair.credential_id + == cc_pair.credential_id, + ) + .all() + ) + assert len(rows) == 3 + + # Second call — idempotent, same count + upsert_hierarchy_node_cc_pair_entries( + db_session=db_session, + hierarchy_node_ids=node_ids, + connector_id=cc_pair.connector_id, + credential_id=cc_pair.credential_id, + commit=True, + ) + rows_after = ( + db_session.query(HierarchyNodeByConnectorCredentialPair) + .filter( + HierarchyNodeByConnectorCredentialPair.connector_id == cc_pair.connector_id, + HierarchyNodeByConnectorCredentialPair.credential_id + == cc_pair.credential_id, + ) + .all() + ) + assert len(rows_after) == 3 + + +def test_remove_stale_entries_and_delete_orphans(db_session: Session) -> None: + """After removing stale join-table entries, orphaned hierarchy nodes should + be deleted and the SOURCE node should survive.""" + _cleanup_test_data(db_session) + source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True) + cc_pair = _create_cc_pair(db_session) + + upserted = upsert_hierarchy_nodes_batch( + db_session=db_session, + nodes=_make_hierarchy_nodes(), + source=TEST_SOURCE, + commit=True, + is_connector_public=False, + ) + all_ids = [n.id for n in upserted] + upsert_hierarchy_node_cc_pair_entries( + db_session=db_session, + hierarchy_node_ids=all_ids, + connector_id=cc_pair.connector_id, + credential_id=cc_pair.credential_id, + commit=True, + ) + + # Now simulate a pruning run where only channel A survived + channel_a = get_hierarchy_node_by_raw_id(db_session, CHANNEL_A_ID, TEST_SOURCE) + assert channel_a is not None + live_ids = {channel_a.id} + + stale_removed = remove_stale_hierarchy_node_cc_pair_entries( + db_session=db_session, + connector_id=cc_pair.connector_id, + credential_id=cc_pair.credential_id, + live_hierarchy_node_ids=live_ids, + commit=True, + ) + 
assert stale_removed == 2 + + # Delete orphaned nodes + deleted_raw_ids = delete_orphaned_hierarchy_nodes( + db_session=db_session, + source=TEST_SOURCE, + commit=True, + ) + assert set(deleted_raw_ids) == {CHANNEL_B_ID, CHANNEL_C_ID} + + # Verify only channel A + SOURCE remain + remaining = get_all_hierarchy_nodes_for_source(db_session, TEST_SOURCE) + remaining_raw = {n.raw_node_id for n in remaining} + assert remaining_raw == {CHANNEL_A_ID, source_node.raw_node_id} + + +def test_multi_cc_pair_prevents_premature_deletion(db_session: Session) -> None: + """A hierarchy node shared by two cc_pairs should NOT be deleted when only + one cc_pair removes its association.""" + _cleanup_test_data(db_session) + ensure_source_node_exists(db_session, TEST_SOURCE, commit=True) + cc_pair_1 = _create_cc_pair(db_session) + cc_pair_2 = _create_cc_pair(db_session) + + upserted = upsert_hierarchy_nodes_batch( + db_session=db_session, + nodes=_make_hierarchy_nodes(), + source=TEST_SOURCE, + commit=True, + is_connector_public=False, + ) + all_ids = [n.id for n in upserted] + + # cc_pair 1 owns all 3 + upsert_hierarchy_node_cc_pair_entries( + db_session=db_session, + hierarchy_node_ids=all_ids, + connector_id=cc_pair_1.connector_id, + credential_id=cc_pair_1.credential_id, + commit=True, + ) + # cc_pair 2 also owns all 3 + upsert_hierarchy_node_cc_pair_entries( + db_session=db_session, + hierarchy_node_ids=all_ids, + connector_id=cc_pair_2.connector_id, + credential_id=cc_pair_2.credential_id, + commit=True, + ) + + # cc_pair 1 prunes — keeps none + remove_stale_hierarchy_node_cc_pair_entries( + db_session=db_session, + connector_id=cc_pair_1.connector_id, + credential_id=cc_pair_1.credential_id, + live_hierarchy_node_ids=set(), + commit=True, + ) + + # Orphan deletion should find nothing because cc_pair 2 still references them + deleted = delete_orphaned_hierarchy_nodes( + db_session=db_session, + source=TEST_SOURCE, + commit=True, + ) + assert deleted == [] + + # All 3 nodes + SOURCE 
should still exist + remaining = get_all_hierarchy_nodes_for_source(db_session, TEST_SOURCE) + assert len(remaining) == 4 + + +def test_reparent_orphaned_children(db_session: Session) -> None: + """After deleting a parent hierarchy node, its children should be + re-parented to the SOURCE node.""" + _cleanup_test_data(db_session) + source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True) + cc_pair = _create_cc_pair(db_session) + + # Create a parent node and a child node + parent_node = PydanticHierarchyNode( + raw_node_id="PARENT", + raw_parent_id=None, + display_name="Parent", + node_type=HierarchyNodeType.CHANNEL, + ) + child_node = PydanticHierarchyNode( + raw_node_id="CHILD", + raw_parent_id="PARENT", + display_name="Child", + node_type=HierarchyNodeType.CHANNEL, + ) + upserted = upsert_hierarchy_nodes_batch( + db_session=db_session, + nodes=[parent_node, child_node], + source=TEST_SOURCE, + commit=True, + is_connector_public=False, + ) + assert len(upserted) == 2 + + parent_db = get_hierarchy_node_by_raw_id(db_session, "PARENT", TEST_SOURCE) + child_db = get_hierarchy_node_by_raw_id(db_session, "CHILD", TEST_SOURCE) + assert parent_db is not None and child_db is not None + assert child_db.parent_id == parent_db.id + + # Associate only the child with a cc_pair (parent is orphaned) + upsert_hierarchy_node_cc_pair_entries( + db_session=db_session, + hierarchy_node_ids=[child_db.id], + connector_id=cc_pair.connector_id, + credential_id=cc_pair.credential_id, + commit=True, + ) + + # Delete orphaned nodes (parent has no cc_pair entry) + deleted = delete_orphaned_hierarchy_nodes( + db_session=db_session, + source=TEST_SOURCE, + commit=True, + ) + assert "PARENT" in deleted + + # Child should now have parent_id=NULL (SET NULL cascade) + db_session.expire_all() + child_db = get_hierarchy_node_by_raw_id(db_session, "CHILD", TEST_SOURCE) + assert child_db is not None + assert child_db.parent_id is None + + # Re-parent orphans to SOURCE + reparented = 
reparent_orphaned_hierarchy_nodes( + db_session=db_session, + source=TEST_SOURCE, + commit=True, + ) + assert len(reparented) == 1 + + db_session.expire_all() + child_db = get_hierarchy_node_by_raw_id(db_session, "CHILD", TEST_SOURCE) + assert child_db is not None + assert child_db.parent_id == source_node.id diff --git a/backend/tests/external_dependency_unit/connectors/jira/test_jira_doc_sync.py b/backend/tests/external_dependency_unit/connectors/jira/test_jira_doc_sync.py index a089ec439de..d6dfa9c9bf6 100644 --- a/backend/tests/external_dependency_unit/connectors/jira/test_jira_doc_sync.py +++ b/backend/tests/external_dependency_unit/connectors/jira/test_jira_doc_sync.py @@ -1,5 +1,6 @@ from typing import Any +import pytest from pydantic import BaseModel from sqlalchemy.orm import Session @@ -14,13 +15,14 @@ from onyx.db.models import Credential from onyx.db.utils import DocumentRow from onyx.db.utils import SortOrder -from onyx.utils.variable_functionality import global_version # In order to get these tests to run, use the credentials from Bitwarden. # Search up "ENV vars for local and Github tests", and find the Jira relevant key-value pairs. # Required env vars: JIRA_USER_EMAIL, JIRA_API_TOKEN +pytestmark = pytest.mark.usefixtures("enable_ee") + class DocExternalAccessSet(BaseModel): """A version of DocExternalAccess that uses sets for comparison.""" @@ -52,9 +54,6 @@ def test_jira_doc_sync( This test uses the AS project which has applicationRole permission, meaning all documents should be marked as public. """ - # NOTE: must set EE on or else the connector will skip the perm syncing - global_version.set_ee() - try: # Use AS project specifically for this test connector_config = { @@ -150,9 +149,6 @@ def test_jira_doc_sync_with_specific_permissions( This test uses a project that has specific user permissions to verify that specific users are correctly extracted. 
""" - # NOTE: must set EE on or else the connector will skip the perm syncing - global_version.set_ee() - try: # Use SUP project which has specific user permissions connector_config = { diff --git a/backend/tests/external_dependency_unit/connectors/jira/test_jira_group_sync.py b/backend/tests/external_dependency_unit/connectors/jira/test_jira_group_sync.py index 052fb53ec1a..b82e95bddd1 100644 --- a/backend/tests/external_dependency_unit/connectors/jira/test_jira_group_sync.py +++ b/backend/tests/external_dependency_unit/connectors/jira/test_jira_group_sync.py @@ -1,5 +1,6 @@ from typing import Any +import pytest from sqlalchemy.orm import Session from ee.onyx.external_permissions.jira.group_sync import jira_group_sync @@ -18,6 +19,8 @@ # Search up "ENV vars for local and Github tests", and find the Jira relevant key-value pairs. # Required env vars: JIRA_USER_EMAIL, JIRA_API_TOKEN +pytestmark = pytest.mark.usefixtures("enable_ee") + # Expected groups from the danswerai.atlassian.net Jira instance # Note: These groups are shared with Confluence since they're both Atlassian products # App accounts (bots, integrations) are filtered out diff --git a/backend/tests/external_dependency_unit/db/test_credential_sensitive_value.py b/backend/tests/external_dependency_unit/db/test_credential_sensitive_value.py new file mode 100644 index 00000000000..6337faa2249 --- /dev/null +++ b/backend/tests/external_dependency_unit/db/test_credential_sensitive_value.py @@ -0,0 +1,90 @@ +"""Test that Credential with nested JSON round-trips through SensitiveValue correctly. + +Exercises the full encrypt → store → read → decrypt → SensitiveValue path +with realistic nested OAuth credential data, and verifies SQLAlchemy dirty +tracking works with nested dict comparison. + +Requires a running Postgres instance. 
+""" + +from sqlalchemy.orm import Session + +from onyx.configs.constants import DocumentSource +from onyx.db.models import Credential +from onyx.utils.sensitive import SensitiveValue + +# NOTE: this is not the real shape of a Drive credential, +# but it is intended to test nested JSON credential handling + +_NESTED_CRED_JSON = { + "oauth_tokens": { + "access_token": "ya29.abc123", + "refresh_token": "1//xEg-def456", + }, + "scopes": ["read", "write", "admin"], + "client_config": { + "client_id": "123.apps.googleusercontent.com", + "client_secret": "GOCSPX-secret", + }, +} + + +def test_nested_credential_json_round_trip(db_session: Session) -> None: + """Nested OAuth credential survives encrypt → store → read → decrypt.""" + credential = Credential( + source=DocumentSource.GOOGLE_DRIVE, + credential_json=_NESTED_CRED_JSON, + ) + db_session.add(credential) + db_session.flush() + + # Immediate read (no DB round-trip) — tests the set event wrapping + assert isinstance(credential.credential_json, SensitiveValue) + assert credential.credential_json.get_value(apply_mask=False) == _NESTED_CRED_JSON + + # DB round-trip — tests process_result_value + db_session.expire(credential) + reloaded = credential.credential_json + assert isinstance(reloaded, SensitiveValue) + assert reloaded.get_value(apply_mask=False) == _NESTED_CRED_JSON + + db_session.rollback() + + +def test_reassign_same_nested_json_not_dirty(db_session: Session) -> None: + """Re-assigning the same nested dict should not mark the session dirty.""" + credential = Credential( + source=DocumentSource.GOOGLE_DRIVE, + credential_json=_NESTED_CRED_JSON, + ) + db_session.add(credential) + db_session.flush() + + # Clear dirty state from the insert + db_session.expire(credential) + _ = credential.credential_json # force reload + + # Re-assign identical value + credential.credential_json = _NESTED_CRED_JSON # type: ignore[assignment] + assert not db_session.is_modified(credential) + + db_session.rollback() + + +def 
test_assign_different_nested_json_is_dirty(db_session: Session) -> None: + """Assigning a different nested dict should mark the session dirty.""" + credential = Credential( + source=DocumentSource.GOOGLE_DRIVE, + credential_json=_NESTED_CRED_JSON, + ) + db_session.add(credential) + db_session.flush() + + db_session.expire(credential) + _ = credential.credential_json # force reload + + modified_cred = {**_NESTED_CRED_JSON, "scopes": ["read"]} + credential.credential_json = modified_cred # type: ignore[assignment] + assert db_session.is_modified(credential) + + db_session.rollback() diff --git a/backend/tests/external_dependency_unit/db/test_rotate_encryption_key.py b/backend/tests/external_dependency_unit/db/test_rotate_encryption_key.py new file mode 100644 index 00000000000..cded9f8b54a --- /dev/null +++ b/backend/tests/external_dependency_unit/db/test_rotate_encryption_key.py @@ -0,0 +1,305 @@ +"""Tests for rotate_encryption_key against real Postgres. + +Uses real ORM models (Credential, InternetSearchProvider) and the actual +Postgres database. Discovery is mocked in rotation tests to scope mutations +to only the test rows — the real _discover_encrypted_columns walk is tested +separately in TestDiscoverEncryptedColumns. + +Requires a running Postgres instance. 
Run with:: + + python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/db/test_rotate_encryption_key.py +""" + +import json +from collections.abc import Generator +from unittest.mock import patch + +import pytest +from sqlalchemy import LargeBinary +from sqlalchemy import select +from sqlalchemy import text +from sqlalchemy.orm import Session + +from ee.onyx.utils.encryption import _decrypt_bytes +from ee.onyx.utils.encryption import _encrypt_string +from ee.onyx.utils.encryption import _get_trimmed_key +from onyx.configs.constants import DocumentSource +from onyx.db.models import Credential +from onyx.db.models import EncryptedJson +from onyx.db.models import EncryptedString +from onyx.db.models import InternetSearchProvider +from onyx.db.rotate_encryption_key import _discover_encrypted_columns +from onyx.db.rotate_encryption_key import rotate_encryption_key +from onyx.utils.variable_functionality import fetch_versioned_implementation +from onyx.utils.variable_functionality import global_version + +EE_MODULE = "ee.onyx.utils.encryption" +ROTATE_MODULE = "onyx.db.rotate_encryption_key" + +OLD_KEY = "o" * 16 +NEW_KEY = "n" * 16 + + +@pytest.fixture(autouse=True) +def _enable_ee() -> Generator[None, None, None]: + prev = global_version._is_ee + global_version.set_ee() + fetch_versioned_implementation.cache_clear() + yield + global_version._is_ee = prev + fetch_versioned_implementation.cache_clear() + + +@pytest.fixture(autouse=True) +def _clear_key_cache() -> None: + _get_trimmed_key.cache_clear() + + +def _raw_credential_bytes(db_session: Session, credential_id: int) -> bytes | None: + """Read raw bytes from credential_json, bypassing the TypeDecorator.""" + col = Credential.__table__.c.credential_json + stmt = select(col.cast(LargeBinary)).where( + Credential.__table__.c.id == credential_id + ) + return db_session.execute(stmt).scalar() + + +def _raw_isp_bytes(db_session: Session, isp_id: int) -> bytes | None: + """Read raw bytes from 
InternetSearchProvider.api_key.""" + col = InternetSearchProvider.__table__.c.api_key + stmt = select(col.cast(LargeBinary)).where( + InternetSearchProvider.__table__.c.id == isp_id + ) + return db_session.execute(stmt).scalar() + + +class TestDiscoverEncryptedColumns: + """Verify _discover_encrypted_columns finds real production models.""" + + def test_discovers_credential_json(self) -> None: + results = _discover_encrypted_columns() + found = { + (model_cls.__tablename__, col_name, is_json) # type: ignore[attr-defined] + for model_cls, col_name, _, is_json in results + } + assert ("credential", "credential_json", True) in found + + def test_discovers_internet_search_provider_api_key(self) -> None: + results = _discover_encrypted_columns() + found = { + (model_cls.__tablename__, col_name, is_json) # type: ignore[attr-defined] + for model_cls, col_name, _, is_json in results + } + assert ("internet_search_provider", "api_key", False) in found + + def test_all_encrypted_string_columns_are_not_json(self) -> None: + results = _discover_encrypted_columns() + for model_cls, col_name, _, is_json in results: + col = getattr(model_cls, col_name).property.columns[0] + if isinstance(col.type, EncryptedString): + assert not is_json, ( + f"{model_cls.__tablename__}.{col_name} is EncryptedString " # type: ignore[attr-defined] + f"but is_json={is_json}" + ) + + def test_all_encrypted_json_columns_are_json(self) -> None: + results = _discover_encrypted_columns() + for model_cls, col_name, _, is_json in results: + col = getattr(model_cls, col_name).property.columns[0] + if isinstance(col.type, EncryptedJson): + assert is_json, ( + f"{model_cls.__tablename__}.{col_name} is EncryptedJson " # type: ignore[attr-defined] + f"but is_json={is_json}" + ) + + +class TestRotateCredential: + """Test rotation against the real Credential table (EncryptedJson). + + Discovery is scoped to only the Credential model to avoid mutating + other tables in the test database. 
+ """ + + @pytest.fixture(autouse=True) + def _limit_discovery(self) -> Generator[None, None, None]: + with patch( + f"{ROTATE_MODULE}._discover_encrypted_columns", + return_value=[(Credential, "credential_json", ["id"], True)], + ): + yield + + @pytest.fixture() + def credential_id( + self, db_session: Session, tenant_context: None # noqa: ARG002 + ) -> Generator[int, None, None]: + """Insert a Credential row with raw encrypted bytes, clean up after.""" + config = {"api_key": "sk-test-1234", "endpoint": "https://example.com"} + encrypted = _encrypt_string(json.dumps(config), key=OLD_KEY) + + result = db_session.execute( + text( + "INSERT INTO credential " + "(source, credential_json, admin_public, curator_public) " + "VALUES (:source, :cred_json, true, false) " + "RETURNING id" + ), + {"source": DocumentSource.INGESTION_API.value, "cred_json": encrypted}, + ) + cred_id = result.scalar_one() + db_session.commit() + + yield cred_id + + db_session.execute( + text("DELETE FROM credential WHERE id = :id"), {"id": cred_id} + ) + db_session.commit() + + def test_rotates_credential_json( + self, db_session: Session, credential_id: int + ) -> None: + with ( + patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY), + patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY), + ): + totals = rotate_encryption_key(db_session, old_key=OLD_KEY) + + assert totals.get("credential.credential_json", 0) >= 1 + + raw = _raw_credential_bytes(db_session, credential_id) + assert raw is not None + decrypted = json.loads(_decrypt_bytes(raw, key=NEW_KEY)) + assert decrypted["api_key"] == "sk-test-1234" + assert decrypted["endpoint"] == "https://example.com" + + def test_skips_already_rotated( + self, db_session: Session, credential_id: int + ) -> None: + with ( + patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY), + patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY), + ): + rotate_encryption_key(db_session, old_key=OLD_KEY) + _ = rotate_encryption_key(db_session, old_key=OLD_KEY) + + 
raw = _raw_credential_bytes(db_session, credential_id) + assert raw is not None + decrypted = json.loads(_decrypt_bytes(raw, key=NEW_KEY)) + assert decrypted["api_key"] == "sk-test-1234" + + def test_dry_run_does_not_modify( + self, db_session: Session, credential_id: int + ) -> None: + original = _raw_credential_bytes(db_session, credential_id) + + with ( + patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY), + patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY), + ): + totals = rotate_encryption_key(db_session, old_key=OLD_KEY, dry_run=True) + + assert totals.get("credential.credential_json", 0) >= 1 + + raw_after = _raw_credential_bytes(db_session, credential_id) + assert raw_after == original + + +class TestRotateInternetSearchProvider: + """Test rotation against the real InternetSearchProvider table (EncryptedString). + + Discovery is scoped to only the InternetSearchProvider model to avoid + mutating other tables in the test database. + """ + + @pytest.fixture(autouse=True) + def _limit_discovery(self) -> Generator[None, None, None]: + with patch( + f"{ROTATE_MODULE}._discover_encrypted_columns", + return_value=[ + (InternetSearchProvider, "api_key", ["id"], False), + ], + ): + yield + + @pytest.fixture() + def isp_id( + self, db_session: Session, tenant_context: None # noqa: ARG002 + ) -> Generator[int, None, None]: + """Insert an InternetSearchProvider row with raw encrypted bytes.""" + encrypted = _encrypt_string("sk-secret-api-key", key=OLD_KEY) + + result = db_session.execute( + text( + "INSERT INTO internet_search_provider " + "(name, provider_type, api_key, is_active) " + "VALUES (:name, :ptype, :api_key, false) " + "RETURNING id" + ), + { + "name": f"test-rotation-{id(self)}", + "ptype": "test", + "api_key": encrypted, + }, + ) + isp_id = result.scalar_one() + db_session.commit() + + yield isp_id + + db_session.execute( + text("DELETE FROM internet_search_provider WHERE id = :id"), + {"id": isp_id}, + ) + db_session.commit() + + def 
test_rotates_api_key(self, db_session: Session, isp_id: int) -> None: + with ( + patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY), + patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY), + ): + totals = rotate_encryption_key(db_session, old_key=OLD_KEY) + + assert totals.get("internet_search_provider.api_key", 0) >= 1 + + raw = _raw_isp_bytes(db_session, isp_id) + assert raw is not None + assert _decrypt_bytes(raw, key=NEW_KEY) == "sk-secret-api-key" + + def test_rotates_from_unencrypted( + self, db_session: Session, tenant_context: None # noqa: ARG002 + ) -> None: + """Test rotating data that was stored without any encryption key.""" + result = db_session.execute( + text( + "INSERT INTO internet_search_provider " + "(name, provider_type, api_key, is_active) " + "VALUES (:name, :ptype, :api_key, false) " + "RETURNING id" + ), + { + "name": f"test-raw-{id(self)}", + "ptype": "test", + "api_key": b"raw-api-key", + }, + ) + isp_id = result.scalar_one() + db_session.commit() + + try: + with ( + patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY), + patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY), + ): + totals = rotate_encryption_key(db_session, old_key=None) + + assert totals.get("internet_search_provider.api_key", 0) >= 1 + + raw = _raw_isp_bytes(db_session, isp_id) + assert raw is not None + assert _decrypt_bytes(raw, key=NEW_KEY) == "raw-api-key" + finally: + db_session.execute( + text("DELETE FROM internet_search_provider WHERE id = :id"), + {"id": isp_id}, + ) + db_session.commit() diff --git a/backend/tests/external_dependency_unit/document_index/test_document_index_old.py b/backend/tests/external_dependency_unit/document_index/test_document_index_old.py new file mode 100644 index 00000000000..bd76a35f120 --- /dev/null +++ b/backend/tests/external_dependency_unit/document_index/test_document_index_old.py @@ -0,0 +1,398 @@ +"""External dependency tests for the old DocumentIndex interface. + +These tests assume Vespa and OpenSearch are running. 
+ +TODO(ENG-3764)(andrei): Consolidate some of these test fixtures. +""" + +import os +import time +import uuid +from collections.abc import Generator +from unittest.mock import patch + +import httpx +import pytest + +from onyx.access.models import DocumentAccess +from onyx.configs.constants import DocumentSource +from onyx.connectors.models import Document +from onyx.context.search.models import IndexFilters +from onyx.db.enums import EmbeddingPrecision +from onyx.document_index.interfaces import DocumentIndex +from onyx.document_index.interfaces import IndexBatchParams +from onyx.document_index.interfaces import VespaChunkRequest +from onyx.document_index.interfaces import VespaDocumentUserFields +from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout +from onyx.document_index.opensearch.opensearch_document_index import ( + OpenSearchOldDocumentIndex, +) +from onyx.document_index.vespa.index import VespaIndex +from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client +from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout +from onyx.indexing.models import ChunkEmbedding +from onyx.indexing.models import DocMetadataAwareIndexChunk +from shared_configs.configs import MULTI_TENANT +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR +from shared_configs.contextvars import get_current_tenant_id +from tests.external_dependency_unit.constants import TEST_TENANT_ID + + +@pytest.fixture(scope="module") +def opensearch_available() -> Generator[None, None, None]: + """Verifies OpenSearch is running, fails the test if not.""" + if not wait_for_opensearch_with_timeout(): + pytest.fail("OpenSearch is not available.") + yield # Test runs here. + + +@pytest.fixture(scope="module") +def test_index_name() -> Generator[str, None, None]: + yield f"test_index_{uuid.uuid4().hex[:8]}" # Test runs here. 
+ + +@pytest.fixture(scope="module") +def tenant_context() -> Generator[None, None, None]: + """Sets up tenant context for testing.""" + token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID) + try: + yield # Test runs here. + finally: + # Reset the tenant context after the test + CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + + +@pytest.fixture(scope="module") +def httpx_client() -> Generator[httpx.Client, None, None]: + client = get_vespa_http_client() + try: + yield client + finally: + client.close() + + +@pytest.fixture(scope="module") +def vespa_document_index( + httpx_client: httpx.Client, + tenant_context: None, # noqa: ARG001 + test_index_name: str, +) -> Generator[VespaIndex, None, None]: + vespa_index = VespaIndex( + index_name=test_index_name, + secondary_index_name=None, + large_chunks_enabled=False, + secondary_large_chunks_enabled=None, + multitenant=MULTI_TENANT, + httpx_client=httpx_client, + ) + backend_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "..") + ) + with patch("os.getcwd", return_value=backend_dir): + vespa_index.ensure_indices_exist( + primary_embedding_dim=128, + primary_embedding_precision=EmbeddingPrecision.FLOAT, + secondary_index_embedding_dim=None, + secondary_index_embedding_precision=None, + ) + # Verify Vespa is running, fails the test if not. Try 90 seconds for testing + # in CI. We have to do this here because this endpoint only becomes live + # once we create an index. + if not wait_for_vespa_with_timeout(wait_limit=90): + pytest.fail("Vespa is not available.") + + # Wait until the schema is actually ready for writes on content nodes. We + # probe by attempting a PUT; 200 means the schema is live, 400 means not + # yet. This is so scuffed but running the test is really flakey otherwise; + # this is only temporary until we entirely move off of Vespa. 
+ probe_doc = { + "fields": { + "document_id": "__probe__", + "chunk_id": 0, + "blurb": "", + "title": "", + "skip_title": True, + "content": "", + "content_summary": "", + "source_type": "file", + "source_links": "null", + "semantic_identifier": "", + "section_continuation": False, + "large_chunk_reference_ids": [], + "metadata": "{}", + "metadata_list": [], + "metadata_suffix": "", + "chunk_context": "", + "doc_summary": "", + "embeddings": {"full_chunk": [1.0] + [0.0] * 127}, + "access_control_list": {}, + "document_sets": {}, + "image_file_name": None, + "user_project": [], + "personas": [], + "boost": 0.0, + "aggregated_chunk_boost_factor": 0.0, + "primary_owners": [], + "secondary_owners": [], + } + } + schema_ready = False + probe_url = ( + f"http://localhost:8081/document/v1/default/{test_index_name}/docid/__probe__" + ) + for _ in range(60): + resp = httpx_client.post(probe_url, json=probe_doc) + if resp.status_code == 200: + schema_ready = True + # Clean up the probe document. + httpx_client.delete(probe_url) + break + time.sleep(1) + if not schema_ready: + pytest.fail(f"Vespa schema '{test_index_name}' did not become ready in time.") + + yield vespa_index # Test runs here. + + # TODO(ENG-3765)(andrei): Explicitly cleanup index. Not immediately + # pressing; in CI we should be using fresh instances of dependencies each + # time anyway. 
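The probe loop above (attempt a write, check the status, sleep, retry) is a general wait-until-ready pattern. It can be factored into a small timeout-based helper; this is an illustrative sketch, not part of the Onyx codebase — `wait_until` and its parameters are hypothetical:

```python
import time
from collections.abc import Callable


def wait_until(
    predicate: Callable[[], bool],
    timeout_s: float = 60.0,
    interval_s: float = 1.0,
) -> bool:
    """Poll `predicate` until it returns True or `timeout_s` elapses.

    The predicate is always checked at least once; returns True as soon
    as it succeeds, False on timeout.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        if predicate():
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)
```

With a helper like this, the schema probe above would reduce to something like `wait_until(lambda: httpx_client.post(probe_url, json=probe_doc).status_code == 200, timeout_s=60)`, keeping the retry bookkeeping out of the fixture body.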
+ + +@pytest.fixture(scope="module") +def opensearch_document_index( + opensearch_available: None, # noqa: ARG001 + tenant_context: None, # noqa: ARG001 + test_index_name: str, +) -> Generator[OpenSearchOldDocumentIndex, None, None]: + opensearch_index = OpenSearchOldDocumentIndex( + index_name=test_index_name, + embedding_dim=128, + embedding_precision=EmbeddingPrecision.FLOAT, + secondary_index_name=None, + secondary_embedding_dim=None, + secondary_embedding_precision=None, + large_chunks_enabled=False, + secondary_large_chunks_enabled=None, + multitenant=MULTI_TENANT, + ) + opensearch_index.ensure_indices_exist( + primary_embedding_dim=128, + primary_embedding_precision=EmbeddingPrecision.FLOAT, + secondary_index_embedding_dim=None, + secondary_index_embedding_precision=None, + ) + + yield opensearch_index # Test runs here. + + # TODO(ENG-3765)(andrei): Explicitly cleanup index. Not immediately + # pressing; in CI we should be using fresh instances of dependencies each + # time anyway. + + +@pytest.fixture(scope="module") +def document_indices( + vespa_document_index: VespaIndex, + opensearch_document_index: OpenSearchOldDocumentIndex, +) -> Generator[list[DocumentIndex], None, None]: + # Ideally these are parametrized; doing so with pytest fixtures is tricky. + yield [opensearch_document_index, vespa_document_index] # Test runs here. 
+ + +@pytest.fixture(scope="function") +def chunks( + tenant_context: None, # noqa: ARG001 +) -> Generator[list[DocMetadataAwareIndexChunk], None, None]: + result = [] + chunk_count = 5 + doc_id = "test_doc" + tenant_id = get_current_tenant_id() + access = DocumentAccess.build( + user_emails=[], + user_groups=[], + external_user_emails=[], + external_user_group_ids=[], + is_public=True, + ) + document_sets: set[str] = set() + user_project: list[int] = list() + personas: list[int] = list() + boost = 0 + blurb = "blurb" + content = "content" + title_prefix = "" + doc_summary = "" + chunk_context = "" + title_embedding = [1.0] + [0] * 127 + # Full 0 vectors are not supported for cos similarity. + embeddings = ChunkEmbedding( + full_embedding=[1.0] + [0] * 127, mini_chunk_embeddings=[] + ) + source_document = Document( + id=doc_id, + semantic_identifier="semantic identifier", + source=DocumentSource.FILE, + sections=[], + metadata={}, + title="title", + ) + metadata_suffix_keyword = "" + image_file_id = None + source_links: dict[int, str] = {0: ""} + ancestor_hierarchy_node_ids: list[int] = [] + for i in range(chunk_count): + result.append( + DocMetadataAwareIndexChunk( + tenant_id=tenant_id, + access=access, + document_sets=document_sets, + user_project=user_project, + personas=personas, + boost=boost, + aggregated_chunk_boost_factor=0, + ancestor_hierarchy_node_ids=ancestor_hierarchy_node_ids, + embeddings=embeddings, + title_embedding=title_embedding, + source_document=source_document, + title_prefix=title_prefix, + metadata_suffix_keyword=metadata_suffix_keyword, + metadata_suffix_semantic="", + contextual_rag_reserved_tokens=0, + doc_summary=doc_summary, + chunk_context=chunk_context, + mini_chunk_texts=None, + large_chunk_id=None, + chunk_id=i, + blurb=blurb, + content=content, + source_links=source_links, + image_file_id=image_file_id, + section_continuation=False, + ) + ) + yield result # Test runs here. 
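The chunks fixture above notes that all-zero vectors are not supported for cosine similarity. The reason is that the metric divides the dot product by both vector norms, which is undefined when either norm is zero — a minimal pure-Python illustration (not Onyx code):

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two vectors.

    Raises ValueError for a zero vector, since dividing by a zero norm
    is undefined — the same degeneracy the fixture avoids by putting a
    1.0 in the first embedding dimension.
    """
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return dot / (norm_a * norm_b)
```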
+ + +@pytest.fixture(scope="function") +def index_batch_params( + tenant_context: None, # noqa: ARG001 +) -> Generator[IndexBatchParams, None, None]: + # WARNING: doc_id_to_previous_chunk_cnt={"test_doc": 0} is hardcoded to 0, + # which is only correct on the very first index call. The document_indices + # fixture is scope="module", meaning the same OpenSearch and Vespa backends + # persist across all test functions in this module. When a second test + # function uses this fixture and calls document_index.index(...), the + # backend already has 5 chunks for "test_doc" from the previous test run, + # but the batch params still claim 0 prior chunks exist. This can lead to + # orphaned/duplicate chunks that make subsequent assertions incorrect. + # TODO: Whenever adding a second test, either change this or cleanup the + # index between test cases. + yield IndexBatchParams( + doc_id_to_previous_chunk_cnt={"test_doc": 0}, + doc_id_to_new_chunk_cnt={"test_doc": 5}, + tenant_id=get_current_tenant_id(), + large_chunks_enabled=False, + ) + + +class TestDocumentIndexOld: + """Tests the old DocumentIndex interface.""" + + def test_update_single_can_clear_user_projects_and_personas( + self, + document_indices: list[DocumentIndex], + # This test case assumes all these chunks correspond to one document. + chunks: list[DocMetadataAwareIndexChunk], + index_batch_params: IndexBatchParams, + ) -> None: + """ + Tests that update_single can clear user_projects and personas. + """ + for document_index in document_indices: + # Precondition. + # Ensure there is some non-empty value for user project and + # personas. + for chunk in chunks: + chunk.user_project = [1] + chunk.personas = [2] + document_index.index(chunks, index_batch_params) + + # Ensure that we can get chunks as expected with filters. 
+ doc_id = chunks[0].source_document.id + chunk_count = len(chunks) + tenant_id = get_current_tenant_id() + # We need to specify the chunk index range and specify + # batch_retrieval=True below to trigger the codepath for Vespa's + # search API, which uses the expected additive filtering for + # project_id and persona_id. Otherwise we would use the codepath for + # the visit API, which does not have this kind of filtering + # implemented. + chunk_request = VespaChunkRequest( + document_id=doc_id, min_chunk_ind=0, max_chunk_ind=chunk_count - 1 + ) + project_persona_filters = IndexFilters( + access_control_list=None, + tenant_id=tenant_id, + project_id=1, + persona_id=2, + # We need this even though none of the chunks belong to a + # document set because project_id and persona_id are only + # additive filters in the event the agent has knowledge scope; + # if the agent does not, it is implied that it can see + # everything it is allowed to. + document_set=["1"], + ) + # Not best practice here but the API for refreshing the index to + # ensure that the latest data is present is not exposed in this + # class and is not the same for Vespa and OpenSearch, so we just + # tolerate a sleep for now. As a consequence the number of tests in + # this suite should be small. We only need to tolerate this for as + # long as we continue to use Vespa, we can consider exposing + # something for OpenSearch later. + time.sleep(1) + inference_chunks = document_index.id_based_retrieval( + chunk_requests=[chunk_request], + filters=project_persona_filters, + batch_retrieval=True, + ) + assert len(inference_chunks) == chunk_count + # Sort by chunk id to easily test if we have all chunks. + for i, inference_chunk in enumerate( + sorted(inference_chunks, key=lambda x: x.chunk_id) + ): + assert inference_chunk.chunk_id == i + assert inference_chunk.document_id == doc_id + + # Under test. + # Explicitly set empty fields here. 
+ user_fields = VespaDocumentUserFields(user_projects=[], personas=[]) + document_index.update_single( + doc_id=doc_id, + chunk_count=chunk_count, + tenant_id=tenant_id, + fields=None, + user_fields=user_fields, + ) + + # Postcondition. + filters = IndexFilters(access_control_list=None, tenant_id=tenant_id) + # We should expect to get back all expected chunks with no filters. + # Again, not best practice here. + time.sleep(1) + inference_chunks = document_index.id_based_retrieval( + chunk_requests=[chunk_request], filters=filters, batch_retrieval=True + ) + assert len(inference_chunks) == chunk_count + # Sort by chunk id to easily test if we have all chunks. + for i, inference_chunk in enumerate( + sorted(inference_chunks, key=lambda x: x.chunk_id) + ): + assert inference_chunk.chunk_id == i + assert inference_chunk.document_id == doc_id + # Now, we should expect to not get any chunks if we specify the user + # project and personas filters. + inference_chunks = document_index.id_based_retrieval( + chunk_requests=[chunk_request], + filters=project_persona_filters, + batch_retrieval=True, + ) + assert len(inference_chunks) == 0 diff --git a/backend/tests/external_dependency_unit/hierarchy/test_hierarchy_access_filter.py b/backend/tests/external_dependency_unit/hierarchy/test_hierarchy_access_filter.py index d256fca689b..364cec642a5 100644 --- a/backend/tests/external_dependency_unit/hierarchy/test_hierarchy_access_filter.py +++ b/backend/tests/external_dependency_unit/hierarchy/test_hierarchy_access_filter.py @@ -85,7 +85,7 @@ def test_group_overlap_filter( results = _get_accessible_hierarchy_nodes_for_source( db_session, source=DocumentSource.GOOGLE_DRIVE, - user_email=None, + user_email="", external_group_ids=["group_engineering"], ) result_ids = {n.raw_node_id for n in results} @@ -124,7 +124,7 @@ def test_no_credentials_returns_only_public( results = _get_accessible_hierarchy_nodes_for_source( db_session, source=DocumentSource.GOOGLE_DRIVE, - user_email=None, + 
user_email="", external_group_ids=[], ) result_ids = {n.raw_node_id for n in results} diff --git a/backend/tests/external_dependency_unit/llm/test_llm_provider.py b/backend/tests/external_dependency_unit/llm/test_llm_provider.py index 6c1a01f2bac..3eb8ccddb4d 100644 --- a/backend/tests/external_dependency_unit/llm/test_llm_provider.py +++ b/backend/tests/external_dependency_unit/llm/test_llm_provider.py @@ -11,7 +11,6 @@ from uuid import uuid4 import pytest -from fastapi import HTTPException from sqlalchemy.orm import Session from onyx.db.enums import LLMModelFlowType @@ -20,6 +19,8 @@ from onyx.db.llm import update_default_provider from onyx.db.llm import upsert_llm_provider from onyx.db.models import UserRole +from onyx.error_handling.error_codes import OnyxErrorCode +from onyx.error_handling.exceptions import OnyxError from onyx.llm.constants import LlmProviderNames from onyx.llm.interfaces import LLM from onyx.server.manage.llm.api import ( @@ -122,16 +123,16 @@ def mock_test_llm_success(llm: LLM) -> str | None: finally: db_session.rollback() - def test_failed_llm_test_raises_http_exception( + def test_failed_llm_test_raises_onyx_error( self, db_session: Session, provider_name: str, # noqa: ARG002 ) -> None: """ - Test that a failed LLM test raises an HTTPException with status 400. + Test that a failed LLM test raises an OnyxError with VALIDATION_ERROR. When test_llm returns an error message, the endpoint should raise - an HTTPException with the error details. + an OnyxError with the error details. 
""" error_message = "Invalid API key: Authentication failed" @@ -143,7 +144,7 @@ def mock_test_llm_failure(llm: LLM) -> str | None: # noqa: ARG001 with patch( "onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_failure ): - with pytest.raises(HTTPException) as exc_info: + with pytest.raises(OnyxError) as exc_info: run_test_llm_configuration( test_llm_request=LLMTestRequest( provider=LlmProviderNames.OPENAI, @@ -156,8 +157,7 @@ def mock_test_llm_failure(llm: LLM) -> str | None: # noqa: ARG001 db_session=db_session, ) - # Verify the exception details - assert exc_info.value.status_code == 400 + assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR assert exc_info.value.detail == error_message finally: @@ -536,10 +536,10 @@ def test_no_default_provider_raises_exception( remove_llm_provider(db_session, provider.id) # Now run_test_default_provider should fail - with pytest.raises(HTTPException) as exc_info: + with pytest.raises(OnyxError) as exc_info: run_test_default_provider(_=_create_mock_admin()) - assert exc_info.value.status_code == 400 + assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR assert "No LLM Provider setup" in exc_info.value.detail finally: @@ -581,10 +581,10 @@ def mock_test_llm_failure(llm: LLM) -> str | None: # noqa: ARG001 with patch( "onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_failure ): - with pytest.raises(HTTPException) as exc_info: + with pytest.raises(OnyxError) as exc_info: run_test_default_provider(_=_create_mock_admin()) - assert exc_info.value.status_code == 400 + assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR assert exc_info.value.detail == error_message finally: diff --git a/backend/tests/external_dependency_unit/llm/test_llm_provider_api_base.py b/backend/tests/external_dependency_unit/llm/test_llm_provider_api_base.py index 40468bd2105..3b53d965bd8 100644 --- a/backend/tests/external_dependency_unit/llm/test_llm_provider_api_base.py +++ 
b/backend/tests/external_dependency_unit/llm/test_llm_provider_api_base.py @@ -16,13 +16,14 @@ from uuid import uuid4 import pytest -from fastapi import HTTPException from sqlalchemy.orm import Session from onyx.db.llm import fetch_existing_llm_provider from onyx.db.llm import remove_llm_provider from onyx.db.llm import upsert_llm_provider from onyx.db.models import UserRole +from onyx.error_handling.error_codes import OnyxErrorCode +from onyx.error_handling.exceptions import OnyxError from onyx.llm.constants import LlmProviderNames from onyx.server.manage.llm.api import _mask_string from onyx.server.manage.llm.api import put_llm_provider @@ -100,7 +101,7 @@ def test_blocks_api_base_change_without_key_change__multi_tenant( api_base="https://attacker.example.com", ) - with pytest.raises(HTTPException) as exc_info: + with pytest.raises(OnyxError) as exc_info: put_llm_provider( llm_provider_upsert_request=update_request, is_creation=False, @@ -108,7 +109,7 @@ def test_blocks_api_base_change_without_key_change__multi_tenant( db_session=db_session, ) - assert exc_info.value.status_code == 400 + assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR assert "cannot be changed without changing the API key" in str( exc_info.value.detail ) @@ -236,7 +237,7 @@ def test_blocks_clearing_api_base__multi_tenant( api_base=None, ) - with pytest.raises(HTTPException) as exc_info: + with pytest.raises(OnyxError) as exc_info: put_llm_provider( llm_provider_upsert_request=update_request, is_creation=False, @@ -244,7 +245,7 @@ def test_blocks_clearing_api_base__multi_tenant( db_session=db_session, ) - assert exc_info.value.status_code == 400 + assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR assert "cannot be changed without changing the API key" in str( exc_info.value.detail ) @@ -339,7 +340,7 @@ def test_blocks_custom_config_change_without_key_change__multi_tenant( custom_config_changed=True, ) - with pytest.raises(HTTPException) as exc_info: + with 
pytest.raises(OnyxError) as exc_info: put_llm_provider( llm_provider_upsert_request=update_request, is_creation=False, @@ -347,7 +348,7 @@ def test_blocks_custom_config_change_without_key_change__multi_tenant( db_session=db_session, ) - assert exc_info.value.status_code == 400 + assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR assert "cannot be changed without changing the API key" in str( exc_info.value.detail ) @@ -375,7 +376,7 @@ def test_blocks_adding_custom_config_without_key_change__multi_tenant( custom_config_changed=True, ) - with pytest.raises(HTTPException) as exc_info: + with pytest.raises(OnyxError) as exc_info: put_llm_provider( llm_provider_upsert_request=update_request, is_creation=False, @@ -383,7 +384,7 @@ def test_blocks_adding_custom_config_without_key_change__multi_tenant( db_session=db_session, ) - assert exc_info.value.status_code == 400 + assert exc_info.value.error_code == OnyxErrorCode.VALIDATION_ERROR assert "cannot be changed without changing the API key" in str( exc_info.value.detail ) diff --git a/backend/tests/external_dependency_unit/llm/test_llm_provider_auto_mode.py b/backend/tests/external_dependency_unit/llm/test_llm_provider_auto_mode.py index d0510f3a73f..c954108c422 100644 --- a/backend/tests/external_dependency_unit/llm/test_llm_provider_auto_mode.py +++ b/backend/tests/external_dependency_unit/llm/test_llm_provider_auto_mode.py @@ -698,6 +698,99 @@ def test_sync_auto_mode_creates_flow_rows( class TestAutoModeTransitionsAndResync: """Tests for auto/manual transitions, config evolution, and sync idempotency.""" + def test_transition_to_auto_mode_preserves_default( + self, + db_session: Session, + provider_name: str, + ) -> None: + """When the default provider transitions from manual to auto mode, + the global default should be preserved (set to the recommended model). + + Steps: + 1. Create a manual-mode provider with models, set it as global default. + 2. 
Transition to auto mode (model_configurations=[] triggers cascade + delete of old ModelConfigurations and their LLMModelFlow rows). + 3. Verify the provider is still the global default, now using the + recommended default model from the GitHub config. + """ + initial_models = [ + ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True), + ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True), + ] + + auto_config = _create_mock_llm_recommendations( + provider=LlmProviderNames.OPENAI, + default_model_name="gpt-4o-mini", + additional_models=["gpt-4o"], + ) + + try: + # Step 1: Create manual-mode provider and set as default + put_llm_provider( + llm_provider_upsert_request=LLMProviderUpsertRequest( + name=provider_name, + provider=LlmProviderNames.OPENAI, + api_key="sk-test-key-00000000000000000000000000000000000", + api_key_changed=True, + is_auto_mode=False, + model_configurations=initial_models, + ), + is_creation=True, + _=_create_mock_admin(), + db_session=db_session, + ) + + db_session.expire_all() + provider = fetch_existing_llm_provider( + name=provider_name, db_session=db_session + ) + assert provider is not None + update_default_provider(provider.id, "gpt-4o", db_session) + + default_before = fetch_default_llm_model(db_session) + assert default_before is not None + assert default_before.name == "gpt-4o" + assert default_before.llm_provider_id == provider.id + + # Step 2: Transition to auto mode + with patch( + "onyx.server.manage.llm.api.fetch_llm_recommendations_from_github", + return_value=auto_config, + ): + put_llm_provider( + llm_provider_upsert_request=LLMProviderUpsertRequest( + id=provider.id, + name=provider_name, + provider=LlmProviderNames.OPENAI, + api_key=None, + api_key_changed=False, + is_auto_mode=True, + model_configurations=[], + ), + is_creation=False, + _=_create_mock_admin(), + db_session=db_session, + ) + + # Step 3: Default should be preserved on this provider + db_session.expire_all() + default_after = 
fetch_default_llm_model(db_session) + assert default_after is not None, ( + "Default model should not be None after transitioning to auto mode — " + "the provider was the default before and should remain so" + ) + assert ( + default_after.llm_provider_id == provider.id + ), "Default should still belong to the same provider after transition" + assert default_after.name == "gpt-4o-mini", ( + f"Default should be updated to the recommended model 'gpt-4o-mini', " + f"got '{default_after.name}'" + ) + + finally: + db_session.rollback() + _cleanup_provider(db_session, provider_name) + def test_auto_to_manual_mode_preserves_models_and_stops_syncing( self, db_session: Session, @@ -1042,14 +1135,195 @@ def test_default_model_hidden_when_removed_from_config( assert visibility["gpt-4o"] is False, "Removed default should be hidden" assert visibility["gpt-4o-mini"] is True, "New default should be visible" - # The LLMModelFlow row for gpt-4o still exists (is_default=True), - # but the model is hidden. fetch_default_llm_model filters on - # is_visible=True, so it should NOT return gpt-4o. + # The old default (gpt-4o) is now hidden. sync_auto_mode_models + # should update the global default to the new recommended default + # (gpt-4o-mini) so that it is not silently lost. 
+ db_session.expire_all() + default_after = fetch_default_llm_model(db_session) + assert default_after is not None, ( + "Default model should not be None — sync should set the new " + "recommended default when the old one is hidden" + ) + assert default_after.name == "gpt-4o-mini", ( + f"Default should be updated to the new recommended model " + f"'gpt-4o-mini', but got '{default_after.name}'" + ) + + finally: + db_session.rollback() + _cleanup_provider(db_session, provider_name) + + def test_sync_updates_default_when_recommended_default_changes( + self, + db_session: Session, + provider_name: str, + ) -> None: + """When the provider owns the CHAT default and a sync arrives with a + different recommended default model (both models still in config), + the global default should be updated to the new recommendation. + + Steps: + 1. Create auto-mode provider with config v1: default=gpt-4o. + 2. Set gpt-4o as the global CHAT default. + 3. Re-sync with config v2: default=gpt-4o-mini (gpt-4o still present). + 4. Verify the CHAT default switched to gpt-4o-mini and both models + remain visible. 
+ """ + config_v1 = _create_mock_llm_recommendations( + provider=LlmProviderNames.OPENAI, + default_model_name="gpt-4o", + additional_models=["gpt-4o-mini"], + ) + config_v2 = _create_mock_llm_recommendations( + provider=LlmProviderNames.OPENAI, + default_model_name="gpt-4o-mini", + additional_models=["gpt-4o"], + ) + + try: + with patch( + "onyx.server.manage.llm.api.fetch_llm_recommendations_from_github", + return_value=config_v1, + ): + put_llm_provider( + llm_provider_upsert_request=LLMProviderUpsertRequest( + name=provider_name, + provider=LlmProviderNames.OPENAI, + api_key="sk-test-key-00000000000000000000000000000000000", + api_key_changed=True, + is_auto_mode=True, + model_configurations=[], + ), + is_creation=True, + _=_create_mock_admin(), + db_session=db_session, + ) + + # Set gpt-4o as the global CHAT default db_session.expire_all() + provider = fetch_existing_llm_provider( + name=provider_name, db_session=db_session + ) + assert provider is not None + update_default_provider(provider.id, "gpt-4o", db_session) + + default_before = fetch_default_llm_model(db_session) + assert default_before is not None + assert default_before.name == "gpt-4o" + + # Re-sync with config v2 (recommended default changed) + db_session.expire_all() + provider = fetch_existing_llm_provider( + name=provider_name, db_session=db_session + ) + assert provider is not None + + changes = sync_auto_mode_models( + db_session=db_session, + provider=provider, + llm_recommendations=config_v2, + ) + assert changes > 0, "Sync should report changes when default switches" + + # Both models should remain visible + db_session.expire_all() + provider = fetch_existing_llm_provider( + name=provider_name, db_session=db_session + ) + assert provider is not None + visibility = { + mc.name: mc.is_visible for mc in provider.model_configurations + } + assert visibility["gpt-4o"] is True + assert visibility["gpt-4o-mini"] is True + + # The CHAT default should now be gpt-4o-mini default_after = 
fetch_default_llm_model(db_session) + assert default_after is not None assert ( - default_after is None or default_after.name != "gpt-4o" - ), "Hidden model should not be returned as the default" + default_after.name == "gpt-4o-mini" + ), f"Default should be updated to 'gpt-4o-mini', got '{default_after.name}'" + + finally: + db_session.rollback() + _cleanup_provider(db_session, provider_name) + + def test_sync_idempotent_when_default_already_matches( + self, + db_session: Session, + provider_name: str, + ) -> None: + """When the provider owns the CHAT default and it already matches the + recommended default, re-syncing should report zero changes. + + This is a regression test for the bug where changes was unconditionally + incremented even when the default was already correct. + """ + config = _create_mock_llm_recommendations( + provider=LlmProviderNames.OPENAI, + default_model_name="gpt-4o", + additional_models=["gpt-4o-mini"], + ) + + try: + with patch( + "onyx.server.manage.llm.api.fetch_llm_recommendations_from_github", + return_value=config, + ): + put_llm_provider( + llm_provider_upsert_request=LLMProviderUpsertRequest( + name=provider_name, + provider=LlmProviderNames.OPENAI, + api_key="sk-test-key-00000000000000000000000000000000000", + api_key_changed=True, + is_auto_mode=True, + model_configurations=[], + ), + is_creation=True, + _=_create_mock_admin(), + db_session=db_session, + ) + + # Set gpt-4o (the recommended default) as global CHAT default + db_session.expire_all() + provider = fetch_existing_llm_provider( + name=provider_name, db_session=db_session + ) + assert provider is not None + update_default_provider(provider.id, "gpt-4o", db_session) + + # First sync to stabilize state + db_session.expire_all() + provider = fetch_existing_llm_provider( + name=provider_name, db_session=db_session + ) + assert provider is not None + sync_auto_mode_models( + db_session=db_session, + provider=provider, + llm_recommendations=config, + ) + + # Second sync — 
default already matches, should be a no-op + db_session.expire_all() + provider = fetch_existing_llm_provider( + name=provider_name, db_session=db_session + ) + assert provider is not None + changes = sync_auto_mode_models( + db_session=db_session, + provider=provider, + llm_recommendations=config, + ) + assert changes == 0, ( + f"Expected 0 changes when default already matches recommended, " + f"got {changes}" + ) + + # Default should still be gpt-4o + default_model = fetch_default_llm_model(db_session) + assert default_model is not None + assert default_model.name == "gpt-4o" finally: db_session.rollback() diff --git a/backend/tests/external_dependency_unit/llm/test_llm_provider_default_model_protection.py b/backend/tests/external_dependency_unit/llm/test_llm_provider_default_model_protection.py new file mode 100644 index 00000000000..f7004ceaeb1 --- /dev/null +++ b/backend/tests/external_dependency_unit/llm/test_llm_provider_default_model_protection.py @@ -0,0 +1,220 @@ +""" +This should act as the main point of reference for testing that default model +logic is consistent. 
+ + - +""" + +from collections.abc import Generator +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from onyx.db.llm import fetch_existing_llm_provider +from onyx.db.llm import remove_llm_provider +from onyx.db.llm import update_default_provider +from onyx.db.llm import update_default_vision_provider +from onyx.db.llm import upsert_llm_provider +from onyx.llm.constants import LlmProviderNames +from onyx.server.manage.llm.models import LLMProviderUpsertRequest +from onyx.server.manage.llm.models import LLMProviderView +from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest + + +def _create_test_provider( + db_session: Session, + name: str, + models: list[ModelConfigurationUpsertRequest] | None = None, +) -> LLMProviderView: + """Helper to create a test LLM provider with multiple models.""" + if models is None: + models = [ + ModelConfigurationUpsertRequest( + name="gpt-4o", is_visible=True, supports_image_input=True + ), + ModelConfigurationUpsertRequest( + name="gpt-4o-mini", is_visible=True, supports_image_input=False + ), + ] + return upsert_llm_provider( + LLMProviderUpsertRequest( + name=name, + provider=LlmProviderNames.OPENAI, + api_key="sk-test-key-00000000000000000000000000000000000", + api_key_changed=True, + model_configurations=models, + ), + db_session=db_session, + ) + + +def _cleanup_provider(db_session: Session, name: str) -> None: + """Helper to clean up a test provider by name.""" + provider = fetch_existing_llm_provider(name=name, db_session=db_session) + if provider: + remove_llm_provider(db_session, provider.id) + + +@pytest.fixture +def provider_name(db_session: Session) -> Generator[str, None, None]: + """Generate a unique provider name for each test, with automatic cleanup.""" + name = f"test-provider-{uuid4().hex[:8]}" + yield name + db_session.rollback() + _cleanup_provider(db_session, name) + + +class TestDefaultModelProtection: + """Tests that the default model cannot be removed or 
hidden.""" + + def test_cannot_remove_default_text_model( + self, + db_session: Session, + provider_name: str, + ) -> None: + """Removing the default text model from a provider should raise ValueError.""" + provider = _create_test_provider(db_session, provider_name) + update_default_provider(provider.id, "gpt-4o", db_session) + + # Try to update the provider without the default model + with pytest.raises(ValueError, match="Cannot remove the default model"): + upsert_llm_provider( + LLMProviderUpsertRequest( + id=provider.id, + name=provider_name, + provider=LlmProviderNames.OPENAI, + api_key="sk-test-key-00000000000000000000000000000000000", + api_key_changed=True, + model_configurations=[ + ModelConfigurationUpsertRequest( + name="gpt-4o-mini", is_visible=True + ), + ], + ), + db_session=db_session, + ) + + def test_cannot_hide_default_text_model( + self, + db_session: Session, + provider_name: str, + ) -> None: + """Setting is_visible=False on the default text model should raise ValueError.""" + provider = _create_test_provider(db_session, provider_name) + update_default_provider(provider.id, "gpt-4o", db_session) + + # Try to hide the default model + with pytest.raises(ValueError, match="Cannot hide the default model"): + upsert_llm_provider( + LLMProviderUpsertRequest( + id=provider.id, + name=provider_name, + provider=LlmProviderNames.OPENAI, + api_key="sk-test-key-00000000000000000000000000000000000", + api_key_changed=True, + model_configurations=[ + ModelConfigurationUpsertRequest( + name="gpt-4o", is_visible=False + ), + ModelConfigurationUpsertRequest( + name="gpt-4o-mini", is_visible=True + ), + ], + ), + db_session=db_session, + ) + + def test_cannot_remove_default_vision_model( + self, + db_session: Session, + provider_name: str, + ) -> None: + """Removing the default vision model from a provider should raise ValueError.""" + provider = _create_test_provider(db_session, provider_name) + # Set gpt-4o as both the text and vision default + 
update_default_provider(provider.id, "gpt-4o", db_session) + update_default_vision_provider(provider.id, "gpt-4o", db_session) + + # Try to remove the default vision model + with pytest.raises(ValueError, match="Cannot remove the default model"): + upsert_llm_provider( + LLMProviderUpsertRequest( + id=provider.id, + name=provider_name, + provider=LlmProviderNames.OPENAI, + api_key="sk-test-key-00000000000000000000000000000000000", + api_key_changed=True, + model_configurations=[ + ModelConfigurationUpsertRequest( + name="gpt-4o-mini", is_visible=True + ), + ], + ), + db_session=db_session, + ) + + def test_can_remove_non_default_model( + self, + db_session: Session, + provider_name: str, + ) -> None: + """Removing a non-default model should succeed.""" + provider = _create_test_provider(db_session, provider_name) + update_default_provider(provider.id, "gpt-4o", db_session) + + # Remove gpt-4o-mini (not default) — should succeed + updated = upsert_llm_provider( + LLMProviderUpsertRequest( + id=provider.id, + name=provider_name, + provider=LlmProviderNames.OPENAI, + api_key="sk-test-key-00000000000000000000000000000000000", + api_key_changed=True, + model_configurations=[ + ModelConfigurationUpsertRequest( + name="gpt-4o", is_visible=True, supports_image_input=True + ), + ], + ), + db_session=db_session, + ) + + model_names = {mc.name for mc in updated.model_configurations} + assert "gpt-4o" in model_names + assert "gpt-4o-mini" not in model_names + + def test_can_hide_non_default_model( + self, + db_session: Session, + provider_name: str, + ) -> None: + """Hiding a non-default model should succeed.""" + provider = _create_test_provider(db_session, provider_name) + update_default_provider(provider.id, "gpt-4o", db_session) + + # Hide gpt-4o-mini (not default) — should succeed + updated = upsert_llm_provider( + LLMProviderUpsertRequest( + id=provider.id, + name=provider_name, + provider=LlmProviderNames.OPENAI, + 
api_key="sk-test-key-00000000000000000000000000000000000", + api_key_changed=True, + model_configurations=[ + ModelConfigurationUpsertRequest( + name="gpt-4o", is_visible=True, supports_image_input=True + ), + ModelConfigurationUpsertRequest( + name="gpt-4o-mini", is_visible=False + ), + ], + ), + db_session=db_session, + ) + + model_visibility = { + mc.name: mc.is_visible for mc in updated.model_configurations + } + assert model_visibility["gpt-4o"] is True + assert model_visibility["gpt-4o-mini"] is False diff --git a/backend/tests/external_dependency_unit/opensearch_migration/test_opensearch_migration_tasks.py b/backend/tests/external_dependency_unit/opensearch_migration/test_opensearch_migration_tasks.py index 983d3c870e9..3a0684d6fec 100644 --- a/backend/tests/external_dependency_unit/opensearch_migration/test_opensearch_migration_tasks.py +++ b/backend/tests/external_dependency_unit/opensearch_migration/test_opensearch_migration_tasks.py @@ -17,6 +17,9 @@ import pytest from sqlalchemy.orm import Session +from onyx.background.celery.tasks.opensearch_migration.constants import ( + GET_VESPA_CHUNKS_SLICE_COUNT, +) from onyx.background.celery.tasks.opensearch_migration.tasks import ( is_continuation_token_done_for_all_slices, ) @@ -236,6 +239,8 @@ def full_deployment_setup() -> Generator[None, None, None]: NOTE: We deliberately duplicate this logic from backend/tests/external_dependency_unit/conftest.py because we need to set opensearch_available just for this module, not the entire test session. + + TODO(ENG-3764)(andrei): Consolidate some of these test fixtures. """ # Patch ENABLE_OPENSEARCH_INDEXING_FOR_ONYX just for this test because we # don't yet want that enabled for all tests. 
@@ -320,9 +325,15 @@ def test_embedding_dimension(db_session: Session) -> Generator[int, None, None]: @pytest.fixture(scope="function") def patch_get_vespa_chunks_page_size() -> Generator[int, None, None]: test_page_size = 5 - with patch( - "onyx.background.celery.tasks.opensearch_migration.tasks.GET_VESPA_CHUNKS_PAGE_SIZE", - test_page_size, + with ( + patch( + "onyx.background.celery.tasks.opensearch_migration.tasks.GET_VESPA_CHUNKS_PAGE_SIZE", + test_page_size, + ), + patch( + "onyx.background.celery.tasks.opensearch_migration.constants.GET_VESPA_CHUNKS_PAGE_SIZE", + test_page_size, + ), ): yield test_page_size # Test runs here. @@ -582,6 +593,175 @@ def test_chunk_migration_resumes_from_continuation_token( document_chunks[document.id][opensearch_chunk.chunk_index], ) + def test_chunk_migration_visits_all_chunks_even_when_batch_size_varies( + self, + db_session: Session, + test_documents: list[Document], + vespa_document_index: VespaDocumentIndex, + opensearch_client: OpenSearchIndexClient, + test_embedding_dimension: int, + clean_migration_tables: None, # noqa: ARG002 + enable_opensearch_indexing_for_onyx: None, # noqa: ARG002 + ) -> None: + """ + Tests that chunk migration works correctly even when the batch size + changes halfway through a migration. + + Simulates task time running out by mocking the locking behavior. + """ + # Precondition. + # Index chunks into Vespa. 
+ document_chunks: dict[str, list[dict[str, Any]]] = { + document.id: [ + _create_raw_document_chunk( + document_id=document.id, + chunk_index=i, + content=f"Test content {i} for {document.id}", + embedding=_generate_test_vector(test_embedding_dimension), + now=datetime.now(), + title=f"Test title {document.id}", + title_embedding=_generate_test_vector(test_embedding_dimension), + ) + for i in range(CHUNK_COUNT) + ] + for document in test_documents + } + all_chunks: list[dict[str, Any]] = [] + for chunks in document_chunks.values(): + all_chunks.extend(chunks) + vespa_document_index.index_raw_chunks(all_chunks) + + # Run the initial batch. To simulate partial progress we will mock the + # redis lock to return True for the first invocation of .owned() and + # False subsequently. + # NOTE: The batch size is currently set to 5 in + # patch_get_vespa_chunks_page_size. + mock_redis_client = Mock() + mock_lock = Mock() + mock_lock.owned.side_effect = [True, False, False] + mock_lock.acquire.return_value = True + mock_redis_client.lock.return_value = mock_lock + with patch( + "onyx.background.celery.tasks.opensearch_migration.tasks.get_redis_client", + return_value=mock_redis_client, + ): + result_1 = migrate_chunks_from_vespa_to_opensearch_task( + tenant_id=get_current_tenant_id() + ) + + assert result_1 is True + # Expire the session cache to see the committed changes from the task. + db_session.expire_all() + + # Verify partial progress was saved. + tenant_record = db_session.query(OpenSearchTenantMigrationRecord).first() + assert tenant_record is not None + partial_chunks_migrated = tenant_record.total_chunks_migrated + assert partial_chunks_migrated > 0 + # page_size applies per slice, so one iteration can fetch up to + # page_size * GET_VESPA_CHUNKS_SLICE_COUNT chunks total. 
+ assert partial_chunks_migrated <= 5 * GET_VESPA_CHUNKS_SLICE_COUNT + assert tenant_record.vespa_visit_continuation_token is not None + # Slices are not necessarily evenly distributed across all document + # chunks so we can't test that every token is non-None, but certainly at + # least one must be. + assert any(json.loads(tenant_record.vespa_visit_continuation_token).values()) + assert tenant_record.migration_completed_at is None + assert tenant_record.approx_chunk_count_in_vespa is not None + + # Under test. + # Now patch the batch size to be some other number, like 2. + mock_redis_client = Mock() + mock_lock = Mock() + mock_lock.owned.side_effect = [True, False, False] + mock_lock.acquire.return_value = True + mock_redis_client.lock.return_value = mock_lock + with ( + patch( + "onyx.background.celery.tasks.opensearch_migration.tasks.GET_VESPA_CHUNKS_PAGE_SIZE", + 2, + ), + patch( + "onyx.background.celery.tasks.opensearch_migration.constants.GET_VESPA_CHUNKS_PAGE_SIZE", + 2, + ), + patch( + "onyx.background.celery.tasks.opensearch_migration.tasks.get_redis_client", + return_value=mock_redis_client, + ), + ): + result_2 = migrate_chunks_from_vespa_to_opensearch_task( + tenant_id=get_current_tenant_id() + ) + + # Postcondition. + assert result_2 is True + # Expire the session cache to see the committed changes from the task. + db_session.expire_all() + + # Verify next partial progress was saved. + tenant_record = db_session.query(OpenSearchTenantMigrationRecord).first() + assert tenant_record is not None + new_partial_chunks_migrated = tenant_record.total_chunks_migrated + assert new_partial_chunks_migrated > partial_chunks_migrated + # page_size applies per slice, so one iteration can fetch up to + # page_size * GET_VESPA_CHUNKS_SLICE_COUNT chunks total. 
+ assert new_partial_chunks_migrated <= (5 + 2) * GET_VESPA_CHUNKS_SLICE_COUNT + assert tenant_record.vespa_visit_continuation_token is not None + # Slices are not necessarily evenly distributed across all document + # chunks so we can't test that every token is non-None, but certainly at + # least one must be. + assert any(json.loads(tenant_record.vespa_visit_continuation_token).values()) + assert tenant_record.migration_completed_at is None + assert tenant_record.approx_chunk_count_in_vespa is not None + + # Under test. + # Run the remainder of the migration. + with ( + patch( + "onyx.background.celery.tasks.opensearch_migration.tasks.GET_VESPA_CHUNKS_PAGE_SIZE", + 2, + ), + patch( + "onyx.background.celery.tasks.opensearch_migration.constants.GET_VESPA_CHUNKS_PAGE_SIZE", + 2, + ), + ): + result_3 = migrate_chunks_from_vespa_to_opensearch_task( + tenant_id=get_current_tenant_id() + ) + + # Postcondition. + assert result_3 is True + # Expire the session cache to see the committed changes from the task. + db_session.expire_all() + + # Verify completion. + tenant_record = db_session.query(OpenSearchTenantMigrationRecord).first() + assert tenant_record is not None + assert tenant_record.total_chunks_migrated > new_partial_chunks_migrated + assert tenant_record.total_chunks_migrated == len(all_chunks) + # Visit is complete: the continuation token column is still populated, + # but every slice's token should be marked done. + assert tenant_record.vespa_visit_continuation_token is not None + assert is_continuation_token_done_for_all_slices( + json.loads(tenant_record.vespa_visit_continuation_token) + ) + assert tenant_record.migration_completed_at is not None + assert tenant_record.approx_chunk_count_in_vespa == len(all_chunks) + + # Verify chunks were indexed in OpenSearch. 
+ for document in test_documents: + opensearch_chunks = _get_document_chunks_from_opensearch( + opensearch_client, document.id, get_current_tenant_id() + ) + assert len(opensearch_chunks) == CHUNK_COUNT + opensearch_chunks.sort(key=lambda x: x.chunk_index) + for opensearch_chunk in opensearch_chunks: + _assert_chunk_matches_vespa_chunk( + opensearch_chunk, + document_chunks[document.id][opensearch_chunk.chunk_index], + ) + def test_chunk_migration_empty_vespa( self, db_session: Session, diff --git a/backend/tests/external_dependency_unit/search_settings/test_search_settings.py b/backend/tests/external_dependency_unit/search_settings/test_search_settings.py index 94d7e139186..6236706e8ca 100644 --- a/backend/tests/external_dependency_unit/search_settings/test_search_settings.py +++ b/backend/tests/external_dependency_unit/search_settings/test_search_settings.py @@ -11,6 +11,7 @@ from onyx.context.search.models import SearchSettingsCreationRequest from onyx.db.enums import EmbeddingPrecision from onyx.db.llm import fetch_default_contextual_rag_model +from onyx.db.llm import fetch_existing_llm_provider from onyx.db.llm import update_default_contextual_model from onyx.db.llm import upsert_llm_provider from onyx.db.models import IndexModelStatus @@ -37,6 +38,8 @@ def _create_llm_provider_and_model( model_name: str, ) -> None: """Insert an LLM provider with a single visible model configuration.""" + if fetch_existing_llm_provider(name=provider_name, db_session=db_session): + return upsert_llm_provider( LLMProviderUpsertRequest( name=provider_name, @@ -146,8 +149,8 @@ def baseline_search_settings( ) -@pytest.mark.skip(reason="Set new search settings is temporarily disabled.") @patch("onyx.db.swap_index.get_all_document_indices") +@patch("onyx.server.manage.search_settings.get_all_document_indices") @patch("onyx.server.manage.search_settings.get_default_document_index") @patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag") 
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler") @@ -155,6 +158,7 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create( mock_index_handler: MagicMock, mock_get_llm: MagicMock, mock_get_doc_index: MagicMock, # noqa: ARG001 + mock_get_all_doc_indices_search_settings: MagicMock, # noqa: ARG001 mock_get_all_doc_indices: MagicMock, baseline_search_settings: None, # noqa: ARG001 db_session: Session, @@ -196,8 +200,8 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create( ) -@pytest.mark.skip(reason="Set new search settings is temporarily disabled.") @patch("onyx.db.swap_index.get_all_document_indices") +@patch("onyx.server.manage.search_settings.get_all_document_indices") @patch("onyx.server.manage.search_settings.get_default_document_index") @patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag") @patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler") @@ -205,6 +209,7 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings( mock_index_handler: MagicMock, mock_get_llm: MagicMock, mock_get_doc_index: MagicMock, # noqa: ARG001 + mock_get_all_doc_indices_search_settings: MagicMock, # noqa: ARG001 mock_get_all_doc_indices: MagicMock, baseline_search_settings: None, # noqa: ARG001 db_session: Session, @@ -266,7 +271,7 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings( ) -@pytest.mark.skip(reason="Set new search settings is temporarily disabled.") +@patch("onyx.server.manage.search_settings.get_all_document_indices") @patch("onyx.server.manage.search_settings.get_default_document_index") @patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag") @patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler") @@ -274,6 +279,7 @@ def test_indexing_pipeline_skips_llm_when_contextual_rag_disabled( mock_index_handler: MagicMock, mock_get_llm: MagicMock, mock_get_doc_index: MagicMock, # noqa: ARG001 + mock_get_all_doc_indices_search_settings: MagicMock, # 
noqa: ARG001 baseline_search_settings: None, # noqa: ARG001 db_session: Session, ) -> None: diff --git a/backend/tests/external_dependency_unit/slack_bot/__init__.py b/backend/tests/external_dependency_unit/slack_bot/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/tests/external_dependency_unit/slack_bot/test_slack_bot_crud.py b/backend/tests/external_dependency_unit/slack_bot/test_slack_bot_crud.py new file mode 100644 index 00000000000..89d6215eedc --- /dev/null +++ b/backend/tests/external_dependency_unit/slack_bot/test_slack_bot_crud.py @@ -0,0 +1,85 @@ +"""Tests that SlackBot CRUD operations return properly typed SensitiveValue fields. + +Regression test for the bug where insert_slack_bot/update_slack_bot returned +objects with raw string tokens instead of SensitiveValue wrappers, causing +'str object has no attribute get_value' errors in SlackBot.from_model(). +""" + +from uuid import uuid4 + +from sqlalchemy.orm import Session + +from onyx.db.slack_bot import insert_slack_bot +from onyx.db.slack_bot import update_slack_bot +from onyx.server.manage.models import SlackBot +from onyx.utils.sensitive import SensitiveValue + + +def _unique(prefix: str) -> str: + return f"{prefix}-{uuid4().hex[:8]}" + + +def test_insert_slack_bot_returns_sensitive_values(db_session: Session) -> None: + bot_token = _unique("xoxb-insert") + app_token = _unique("xapp-insert") + user_token = _unique("xoxp-insert") + + slack_bot = insert_slack_bot( + db_session=db_session, + name=_unique("test-bot-insert"), + enabled=True, + bot_token=bot_token, + app_token=app_token, + user_token=user_token, + ) + + assert isinstance(slack_bot.bot_token, SensitiveValue) + assert isinstance(slack_bot.app_token, SensitiveValue) + assert isinstance(slack_bot.user_token, SensitiveValue) + + assert slack_bot.bot_token.get_value(apply_mask=False) == bot_token + assert slack_bot.app_token.get_value(apply_mask=False) == app_token + assert 
slack_bot.user_token.get_value(apply_mask=False) == user_token + + # Verify from_model works without error + pydantic_bot = SlackBot.from_model(slack_bot) + assert pydantic_bot.bot_token # masked, but not empty + assert pydantic_bot.app_token + + +def test_update_slack_bot_returns_sensitive_values(db_session: Session) -> None: + slack_bot = insert_slack_bot( + db_session=db_session, + name=_unique("test-bot-update"), + enabled=True, + bot_token=_unique("xoxb-update"), + app_token=_unique("xapp-update"), + ) + + new_bot_token = _unique("xoxb-update-new") + new_app_token = _unique("xapp-update-new") + new_user_token = _unique("xoxp-update-new") + + updated = update_slack_bot( + db_session=db_session, + slack_bot_id=slack_bot.id, + name=_unique("test-bot-updated"), + enabled=False, + bot_token=new_bot_token, + app_token=new_app_token, + user_token=new_user_token, + ) + + assert isinstance(updated.bot_token, SensitiveValue) + assert isinstance(updated.app_token, SensitiveValue) + assert isinstance(updated.user_token, SensitiveValue) + + assert updated.bot_token.get_value(apply_mask=False) == new_bot_token + assert updated.app_token.get_value(apply_mask=False) == new_app_token + assert updated.user_token.get_value(apply_mask=False) == new_user_token + + # Verify from_model works without error + pydantic_bot = SlackBot.from_model(updated) + assert pydantic_bot.bot_token + assert pydantic_bot.app_token + assert pydantic_bot.user_token is not None diff --git a/backend/tests/external_dependency_unit/tools/test_oauth_config_crud.py b/backend/tests/external_dependency_unit/tools/test_oauth_config_crud.py index 77917968059..4033577e047 100644 --- a/backend/tests/external_dependency_unit/tools/test_oauth_config_crud.py +++ b/backend/tests/external_dependency_unit/tools/test_oauth_config_crud.py @@ -21,6 +21,8 @@ from onyx.db.oauth_config import get_user_oauth_token from onyx.db.oauth_config import update_oauth_config from onyx.db.oauth_config import upsert_user_oauth_token 
+from onyx.db.tools import delete_tool__no_commit +from onyx.db.tools import update_tool from tests.external_dependency_unit.conftest import create_test_user @@ -146,8 +148,16 @@ def test_update_oauth_config_preserves_secrets(self, db_session: Session) -> Non ) # Secrets should be preserved - assert updated_config.client_id == original_client_id - assert updated_config.client_secret == original_client_secret + assert updated_config.client_id is not None + assert original_client_id is not None + assert updated_config.client_id.get_value( + apply_mask=False + ) == original_client_id.get_value(apply_mask=False) + assert updated_config.client_secret is not None + assert original_client_secret is not None + assert updated_config.client_secret.get_value( + apply_mask=False + ) == original_client_secret.get_value(apply_mask=False) # But name should be updated assert updated_config.name == new_name @@ -171,9 +181,14 @@ def test_update_oauth_config_clear_client_id(self, db_session: Session) -> None: ) # client_id should be cleared (empty string) - assert updated_config.client_id == "" + assert updated_config.client_id is not None + assert updated_config.client_id.get_value(apply_mask=False) == "" # client_secret should be preserved - assert updated_config.client_secret == original_client_secret + assert updated_config.client_secret is not None + assert original_client_secret is not None + assert updated_config.client_secret.get_value( + apply_mask=False + ) == original_client_secret.get_value(apply_mask=False) def test_update_oauth_config_clear_client_secret(self, db_session: Session) -> None: """Test clearing client_secret while preserving client_id""" @@ -188,9 +203,14 @@ def test_update_oauth_config_clear_client_secret(self, db_session: Session) -> N ) # client_secret should be cleared (empty string) - assert updated_config.client_secret == "" + assert updated_config.client_secret is not None + assert updated_config.client_secret.get_value(apply_mask=False) == "" # 
client_id should be preserved - assert updated_config.client_id == original_client_id + assert updated_config.client_id is not None + assert original_client_id is not None + assert updated_config.client_id.get_value( + apply_mask=False + ) == original_client_id.get_value(apply_mask=False) def test_update_oauth_config_clear_both_secrets(self, db_session: Session) -> None: """Test clearing both client_id and client_secret""" @@ -205,8 +225,10 @@ def test_update_oauth_config_clear_both_secrets(self, db_session: Session) -> No ) # Both should be cleared (empty strings) - assert updated_config.client_id == "" - assert updated_config.client_secret == "" + assert updated_config.client_id is not None + assert updated_config.client_id.get_value(apply_mask=False) == "" + assert updated_config.client_secret is not None + assert updated_config.client_secret.get_value(apply_mask=False) == "" def test_update_oauth_config_authorization_url(self, db_session: Session) -> None: """Test updating authorization_url""" @@ -273,7 +295,8 @@ def test_update_oauth_config_multiple_fields(self, db_session: Session) -> None: assert updated_config.token_url == new_token_url assert updated_config.scopes == new_scopes assert updated_config.additional_params == new_params - assert updated_config.client_id == new_client_id + assert updated_config.client_id is not None + assert updated_config.client_id.get_value(apply_mask=False) == new_client_id def test_delete_oauth_config(self, db_session: Session) -> None: """Test deleting an OAuth configuration""" @@ -312,6 +335,85 @@ def test_delete_oauth_config_sets_tool_reference_to_null( # Tool should still exist but oauth_config_id should be NULL assert tool.oauth_config_id is None + def test_update_tool_cleans_up_orphaned_oauth_config( + self, db_session: Session + ) -> None: + """Test that changing a tool's oauth_config_id deletes the old config if no other tool uses it.""" + old_config = _create_test_oauth_config(db_session) + new_config = 
_create_test_oauth_config(db_session) + tool = _create_test_tool_with_oauth(db_session, old_config) + old_config_id = old_config.id + + update_tool( + tool_id=tool.id, + name=None, + description=None, + openapi_schema=None, + custom_headers=None, + user_id=None, + db_session=db_session, + passthrough_auth=None, + oauth_config_id=new_config.id, + ) + + assert tool.oauth_config_id == new_config.id + assert get_oauth_config(old_config_id, db_session) is None + + def test_delete_tool_cleans_up_orphaned_oauth_config( + self, db_session: Session + ) -> None: + """Test that deleting the last tool referencing an OAuthConfig also deletes the config.""" + config = _create_test_oauth_config(db_session) + tool = _create_test_tool_with_oauth(db_session, config) + config_id = config.id + + delete_tool__no_commit(tool.id, db_session) + db_session.commit() + + assert get_oauth_config(config_id, db_session) is None + + def test_update_tool_preserves_shared_oauth_config( + self, db_session: Session + ) -> None: + """Test that updating one tool's oauth_config_id preserves the config when another tool still uses it.""" + shared_config = _create_test_oauth_config(db_session) + new_config = _create_test_oauth_config(db_session) + tool_a = _create_test_tool_with_oauth(db_session, shared_config) + tool_b = _create_test_tool_with_oauth(db_session, shared_config) + shared_config_id = shared_config.id + + # Move tool_a to a new config; tool_b still references shared_config + update_tool( + tool_id=tool_a.id, + name=None, + description=None, + openapi_schema=None, + custom_headers=None, + user_id=None, + db_session=db_session, + passthrough_auth=None, + oauth_config_id=new_config.id, + ) + + assert tool_a.oauth_config_id == new_config.id + assert tool_b.oauth_config_id == shared_config_id + assert get_oauth_config(shared_config_id, db_session) is not None + + def test_delete_tool_preserves_shared_oauth_config( + self, db_session: Session + ) -> None: + """Test that deleting one tool preserves 
the config when another tool still uses it.""" + shared_config = _create_test_oauth_config(db_session) + tool_a = _create_test_tool_with_oauth(db_session, shared_config) + tool_b = _create_test_tool_with_oauth(db_session, shared_config) + shared_config_id = shared_config.id + + delete_tool__no_commit(tool_a.id, db_session) + db_session.commit() + + assert tool_b.oauth_config_id == shared_config_id + assert get_oauth_config(shared_config_id, db_session) is not None + class TestOAuthUserTokenCRUD: """Tests for OAuth user token CRUD operations""" @@ -335,7 +437,8 @@ def test_upsert_user_oauth_token_create(self, db_session: Session) -> None: assert user_token.id is not None assert user_token.oauth_config_id == oauth_config.id assert user_token.user_id == user.id - assert user_token.token_data == token_data + assert user_token.token_data is not None + assert user_token.token_data.get_value(apply_mask=False) == token_data assert user_token.created_at is not None assert user_token.updated_at is not None @@ -365,8 +468,13 @@ def test_upsert_user_oauth_token_update(self, db_session: Session) -> None: # Should be the same token record (updated, not inserted) assert updated_token.id == initial_token_id - assert updated_token.token_data == updated_token_data - assert updated_token.token_data != initial_token_data + assert updated_token.token_data is not None + assert ( + updated_token.token_data.get_value(apply_mask=False) == updated_token_data + ) + assert ( + updated_token.token_data.get_value(apply_mask=False) != initial_token_data + ) def test_get_user_oauth_token(self, db_session: Session) -> None: """Test retrieving a user's OAuth token""" @@ -382,7 +490,8 @@ def test_get_user_oauth_token(self, db_session: Session) -> None: assert retrieved_token is not None assert retrieved_token.id == created_token.id - assert retrieved_token.token_data == token_data + assert retrieved_token.token_data is not None + assert retrieved_token.token_data.get_value(apply_mask=False) == 
token_data def test_get_user_oauth_token_not_found(self, db_session: Session) -> None: """Test retrieving a non-existent user token returns None""" @@ -438,7 +547,8 @@ def test_unique_constraint_on_user_config(self, db_session: Session) -> None: retrieved_token = get_user_oauth_token(oauth_config.id, user.id, db_session) assert retrieved_token is not None assert retrieved_token.id == updated_token.id - assert retrieved_token.token_data == token_data2 + assert retrieved_token.token_data is not None + assert retrieved_token.token_data.get_value(apply_mask=False) == token_data2 def test_cascade_delete_user_tokens_on_config_deletion( self, db_session: Session diff --git a/backend/tests/external_dependency_unit/tools/test_oauth_token_manager.py b/backend/tests/external_dependency_unit/tools/test_oauth_token_manager.py index 77502737068..b4ee624fbcf 100644 --- a/backend/tests/external_dependency_unit/tools/test_oauth_token_manager.py +++ b/backend/tests/external_dependency_unit/tools/test_oauth_token_manager.py @@ -374,8 +374,14 @@ def test_exchange_code_for_token_success( assert call_args[0][0] == oauth_config.token_url assert call_args[1]["data"]["grant_type"] == "authorization_code" assert call_args[1]["data"]["code"] == "auth_code_123" - assert call_args[1]["data"]["client_id"] == oauth_config.client_id - assert call_args[1]["data"]["client_secret"] == oauth_config.client_secret + assert oauth_config.client_id is not None + assert oauth_config.client_secret is not None + assert call_args[1]["data"]["client_id"] == oauth_config.client_id.get_value( + apply_mask=False + ) + assert call_args[1]["data"][ + "client_secret" + ] == oauth_config.client_secret.get_value(apply_mask=False) assert call_args[1]["data"]["redirect_uri"] == "https://example.com/callback" @patch("onyx.auth.oauth_token_manager.requests.post") diff --git a/backend/tests/external_dependency_unit/tools/test_python_tool.py b/backend/tests/external_dependency_unit/tools/test_python_tool.py index 
5e0915b9560..0f03c07e358 100644 --- a/backend/tests/external_dependency_unit/tools/test_python_tool.py +++ b/backend/tests/external_dependency_unit/tools/test_python_tool.py @@ -933,6 +933,7 @@ import pytest from fastapi import UploadFile +from fastapi.background import BackgroundTasks from sqlalchemy.orm import Session from starlette.datastructures import Headers @@ -949,6 +950,7 @@ from onyx.server.query_and_chat.streaming_models import PythonToolDelta from onyx.server.query_and_chat.streaming_models import PythonToolStart from onyx.server.query_and_chat.streaming_models import SectionEnd +from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta from onyx.tools.tool_implementations.python.python_tool import PythonTool from tests.external_dependency_unit.answer.stream_test_builder import StreamTestBuilder from tests.external_dependency_unit.answer.stream_test_utils import create_chat_session @@ -1026,6 +1028,13 @@ def do_POST(self) -> None: else: self._respond_json(404, {"error": "not found"}) + def do_GET(self) -> None: + self._capture("GET", b"") + if self.path == "/health": + self._respond_json(200, {"status": "ok"}) + else: + self._respond_json(404, {"error": "not found"}) + def do_DELETE(self) -> None: self._capture("DELETE", b"") self.send_response(200) @@ -1106,6 +1115,14 @@ def mock_ci_server() -> Generator[MockCodeInterpreterServer, None, None]: server.shutdown() +@pytest.fixture(autouse=True) +def _clear_health_cache() -> None: + """Reset the health check cache before every test.""" + import onyx.tools.tool_implementations.python.code_interpreter_client as mod + + mod._health_cache = {} + + @pytest.fixture() def _attach_python_tool_to_default_persona(db_session: Session) -> None: """Ensure the default persona (id=0) has the PythonTool attached.""" @@ -1139,6 +1156,7 @@ def test_code_interpreter_receives_chat_files( # Upload a test CSV csv_content = b"name,age,city\nAlice,30,NYC\nBob,25,SF\n" result = upload_user_files( + 
bg_tasks=BackgroundTasks(), files=[ UploadFile( file=io.BytesIO(csv_content), @@ -1277,9 +1295,18 @@ def test_code_interpreter_replay_packets_include_code_and_output( ).expect( Packet( placement=create_placement(0), - obj=PythonToolStart(code=code), + obj=ToolCallArgumentDelta( + tool_type="python", + argument_deltas={"code": code}, + ), ), forward=2, + ).expect( + Packet( + placement=create_placement(0), + obj=PythonToolStart(code=code), + ), + forward=False, ).expect( Packet( placement=create_placement(0), diff --git a/backend/tests/integration/common_utils/managers/scim_client.py b/backend/tests/integration/common_utils/managers/scim_client.py new file mode 100644 index 00000000000..e1becbf752b --- /dev/null +++ b/backend/tests/integration/common_utils/managers/scim_client.py @@ -0,0 +1,66 @@ +import requests + +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS + + +class ScimClient: + """HTTP client for making authenticated SCIM v2 requests.""" + + @staticmethod + def _headers(raw_token: str) -> dict[str, str]: + return { + **GENERAL_HEADERS, + "Authorization": f"Bearer {raw_token}", + } + + @staticmethod + def get(path: str, raw_token: str) -> requests.Response: + return requests.get( + f"{API_SERVER_URL}/scim/v2{path}", + headers=ScimClient._headers(raw_token), + timeout=60, + ) + + @staticmethod + def post(path: str, raw_token: str, json: dict) -> requests.Response: + return requests.post( + f"{API_SERVER_URL}/scim/v2{path}", + json=json, + headers=ScimClient._headers(raw_token), + timeout=60, + ) + + @staticmethod + def put(path: str, raw_token: str, json: dict) -> requests.Response: + return requests.put( + f"{API_SERVER_URL}/scim/v2{path}", + json=json, + headers=ScimClient._headers(raw_token), + timeout=60, + ) + + @staticmethod + def patch(path: str, raw_token: str, json: dict) -> requests.Response: + return requests.patch( + f"{API_SERVER_URL}/scim/v2{path}", + 
json=json, + headers=ScimClient._headers(raw_token), + timeout=60, + ) + + @staticmethod + def delete(path: str, raw_token: str) -> requests.Response: + return requests.delete( + f"{API_SERVER_URL}/scim/v2{path}", + headers=ScimClient._headers(raw_token), + timeout=60, + ) + + @staticmethod + def get_no_auth(path: str) -> requests.Response: + return requests.get( + f"{API_SERVER_URL}/scim/v2{path}", + headers=GENERAL_HEADERS, + timeout=60, + ) diff --git a/backend/tests/integration/common_utils/managers/scim_token.py b/backend/tests/integration/common_utils/managers/scim_token.py index 3ea020a07e2..1894ed4f321 100644 --- a/backend/tests/integration/common_utils/managers/scim_token.py +++ b/backend/tests/integration/common_utils/managers/scim_token.py @@ -1,7 +1,6 @@ import requests from tests.integration.common_utils.constants import API_SERVER_URL -from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.test_models import DATestScimToken from tests.integration.common_utils.test_models import DATestUser @@ -51,29 +50,3 @@ def get_active( created_at=data["created_at"], last_used_at=data.get("last_used_at"), ) - - @staticmethod - def get_scim_headers(raw_token: str) -> dict[str, str]: - return { - **GENERAL_HEADERS, - "Authorization": f"Bearer {raw_token}", - } - - @staticmethod - def scim_get( - path: str, - raw_token: str, - ) -> requests.Response: - return requests.get( - f"{API_SERVER_URL}/scim/v2{path}", - headers=ScimTokenManager.get_scim_headers(raw_token), - timeout=60, - ) - - @staticmethod - def scim_get_no_auth(path: str) -> requests.Response: - return requests.get( - f"{API_SERVER_URL}/scim/v2{path}", - headers=GENERAL_HEADERS, - timeout=60, - ) diff --git a/backend/tests/integration/tests/discord_bot/test_discord_bot_db.py b/backend/tests/integration/tests/discord_bot/test_discord_bot_db.py index 2a9abf7083a..c4cb95e9d32 100644 --- a/backend/tests/integration/tests/discord_bot/test_discord_bot_db.py +++ 
b/backend/tests/integration/tests/discord_bot/test_discord_bot_db.py @@ -64,7 +64,8 @@ def test_create_bot_config(self, db_session: Session) -> None: db_session.commit() assert config is not None - assert config.bot_token == "test_token_123" + assert config.bot_token is not None + assert config.bot_token.get_value(apply_mask=False) == "test_token_123" # Cleanup delete_discord_bot_config(db_session) diff --git a/backend/tests/integration/tests/indexing/test_checkpointing.py b/backend/tests/integration/tests/indexing/test_checkpointing.py index 7760234d24e..869272a1848 100644 --- a/backend/tests/integration/tests/indexing/test_checkpointing.py +++ b/backend/tests/integration/tests/indexing/test_checkpointing.py @@ -414,6 +414,24 @@ def test_mock_connector_checkpoint_recovery( ) assert finished_index_attempt.status == IndexingStatus.FAILED + # Pause the connector immediately to prevent check_for_indexing from + # creating automatic retry attempts while we reset the mock server. + # Without this, the INITIAL_INDEXING status causes immediate retries + # that would consume (or fail against) the mock server before we can + # set up the recovery behavior. + CCPairManager.pause_cc_pair(cc_pair, user_performing_action=admin_user) + + # Collect all index attempt IDs created so far (the initial one plus + # any automatic retries that may have started before the pause took effect). 
+ all_prior_attempt_ids: list[int] = [] + index_attempts_page = IndexAttemptManager.get_index_attempt_page( + cc_pair_id=cc_pair.id, + page=0, + page_size=100, + user_performing_action=admin_user, + ) + all_prior_attempt_ids = [ia.id for ia in index_attempts_page.items] + # Verify initial state: both docs should be indexed with get_session_with_current_tenant() as db_session: documents = DocumentManager.fetch_documents_for_cc_pair( @@ -465,17 +483,14 @@ def test_mock_connector_checkpoint_recovery( ) assert response.status_code == 200 - # After the failure, the connector is in repeated error state and paused. - # Set the manual indexing trigger first (while paused), then unpause. - # This ensures the trigger is set before CHECK_FOR_INDEXING runs, which will - # prevent the connector from being re-paused when repeated error state is detected. + # Set the manual indexing trigger, then unpause to allow the recovery run. CCPairManager.run_once( cc_pair, from_beginning=False, user_performing_action=admin_user ) CCPairManager.unpause_cc_pair(cc_pair, user_performing_action=admin_user) recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start( cc_pair_id=cc_pair.id, - index_attempts_to_ignore=[initial_index_attempt.id], + index_attempts_to_ignore=all_prior_attempt_ids, user_performing_action=admin_user, ) IndexAttemptManager.wait_for_index_attempt_completion( diff --git a/backend/tests/integration/tests/indexing/test_repeated_error_state.py b/backend/tests/integration/tests/indexing/test_repeated_error_state.py index 039da0e6148..cea30095fca 100644 --- a/backend/tests/integration/tests/indexing/test_repeated_error_state.py +++ b/backend/tests/integration/tests/indexing/test_repeated_error_state.py @@ -130,8 +130,8 @@ def test_repeated_error_state_detection_and_recovery( # ) break - if time.monotonic() - start_time > 30: - assert False, "CC pair did not enter repeated error state within 30 seconds" + if time.monotonic() - start_time > 90: + assert False, "CC 
pair did not enter repeated error state within 90 seconds" time.sleep(2) diff --git a/backend/tests/integration/tests/llm_provider/test_llm_provider.py b/backend/tests/integration/tests/llm_provider/test_llm_provider.py index 945924068a5..f5690fccc03 100644 --- a/backend/tests/integration/tests/llm_provider/test_llm_provider.py +++ b/backend/tests/integration/tests/llm_provider/test_llm_provider.py @@ -386,6 +386,261 @@ def test_delete_llm_provider( assert provider_data is None +def test_delete_default_llm_provider_rejected(reset: None) -> None: # noqa: ARG001 + """Deleting the default LLM provider should return 400.""" + admin_user = UserManager.create(name="admin_user") + + # Create a provider + response = requests.put( + f"{API_SERVER_URL}/admin/llm/provider?is_creation=true", + headers=admin_user.headers, + json={ + "name": "test-provider-default-delete", + "provider": LlmProviderNames.OPENAI, + "api_key": "sk-000000000000000000000000000000000000000000000000", + "model_configurations": [ + ModelConfigurationUpsertRequest( + name="gpt-4o-mini", is_visible=True + ).model_dump() + ], + "is_public": True, + "groups": [], + }, + ) + assert response.status_code == 200 + created_provider = response.json() + + # Set this provider as the default + set_default_response = requests.post( + f"{API_SERVER_URL}/admin/llm/default", + headers=admin_user.headers, + json={ + "provider_id": created_provider["id"], + "model_name": "gpt-4o-mini", + }, + ) + assert set_default_response.status_code == 200 + + # Attempt to delete the default provider — should be rejected + delete_response = requests.delete( + f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}", + headers=admin_user.headers, + ) + assert delete_response.status_code == 400 + assert "Cannot delete the default LLM provider" in delete_response.json()["detail"] + + # Verify provider still exists + provider_data = _get_provider_by_id(admin_user, created_provider["id"]) + assert provider_data is not None + + +def 
test_delete_non_default_llm_provider_with_default_set( + reset: None, # noqa: ARG001 +) -> None: + """Deleting a non-default provider should succeed even when a default is set.""" + admin_user = UserManager.create(name="admin_user") + + # Create two providers + response_default = requests.put( + f"{API_SERVER_URL}/admin/llm/provider?is_creation=true", + headers=admin_user.headers, + json={ + "name": "default-provider", + "provider": LlmProviderNames.OPENAI, + "api_key": "sk-000000000000000000000000000000000000000000000000", + "model_configurations": [ + ModelConfigurationUpsertRequest( + name="gpt-4o-mini", is_visible=True + ).model_dump() + ], + "is_public": True, + "groups": [], + }, + ) + assert response_default.status_code == 200 + default_provider = response_default.json() + + response_other = requests.put( + f"{API_SERVER_URL}/admin/llm/provider?is_creation=true", + headers=admin_user.headers, + json={ + "name": "other-provider", + "provider": LlmProviderNames.OPENAI, + "api_key": "sk-000000000000000000000000000000000000000000000000", + "model_configurations": [ + ModelConfigurationUpsertRequest( + name="gpt-4o", is_visible=True + ).model_dump() + ], + "is_public": True, + "groups": [], + }, + ) + assert response_other.status_code == 200 + other_provider = response_other.json() + + # Set the first provider as default + set_default_response = requests.post( + f"{API_SERVER_URL}/admin/llm/default", + headers=admin_user.headers, + json={ + "provider_id": default_provider["id"], + "model_name": "gpt-4o-mini", + }, + ) + assert set_default_response.status_code == 200 + + # Delete the non-default provider — should succeed + delete_response = requests.delete( + f"{API_SERVER_URL}/admin/llm/provider/{other_provider['id']}", + headers=admin_user.headers, + ) + assert delete_response.status_code == 200 + + # Verify the non-default provider is gone + provider_data = _get_provider_by_id(admin_user, other_provider["id"]) + assert provider_data is None + + # Verify the 
default provider still exists + default_data = _get_provider_by_id(admin_user, default_provider["id"]) + assert default_data is not None + + +def test_force_delete_default_llm_provider( + reset: None, # noqa: ARG001 +) -> None: + """Force-deleting the default LLM provider should succeed.""" + admin_user = UserManager.create(name="admin_user") + + # Create a provider + response = requests.put( + f"{API_SERVER_URL}/admin/llm/provider?is_creation=true", + headers=admin_user.headers, + json={ + "name": "test-provider-force-delete", + "provider": LlmProviderNames.OPENAI, + "api_key": "sk-000000000000000000000000000000000000000000000000", + "model_configurations": [ + ModelConfigurationUpsertRequest( + name="gpt-4o-mini", is_visible=True + ).model_dump() + ], + "is_public": True, + "groups": [], + }, + ) + assert response.status_code == 200 + created_provider = response.json() + + # Set this provider as the default + set_default_response = requests.post( + f"{API_SERVER_URL}/admin/llm/default", + headers=admin_user.headers, + json={ + "provider_id": created_provider["id"], + "model_name": "gpt-4o-mini", + }, + ) + assert set_default_response.status_code == 200 + + # Attempt to delete without force — should be rejected + delete_response = requests.delete( + f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}", + headers=admin_user.headers, + ) + assert delete_response.status_code == 400 + + # Force delete — should succeed + force_delete_response = requests.delete( + f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}?force=true", + headers=admin_user.headers, + ) + assert force_delete_response.status_code == 200 + + # Verify provider is gone + provider_data = _get_provider_by_id(admin_user, created_provider["id"]) + assert provider_data is None + + +def test_delete_default_vision_provider_clears_vision_default( + reset: None, # noqa: ARG001 +) -> None: + """Deleting the default vision provider should succeed and clear the vision default.""" + 
admin_user = UserManager.create(name="admin_user") + + # Create a text provider and set it as default (so we have a default text provider) + text_response = requests.put( + f"{API_SERVER_URL}/admin/llm/provider?is_creation=true", + headers=admin_user.headers, + json={ + "name": "text-provider", + "provider": LlmProviderNames.OPENAI, + "api_key": "sk-000000000000000000000000000000000000000000000001", + "model_configurations": [ + ModelConfigurationUpsertRequest( + name="gpt-4o-mini", is_visible=True + ).model_dump() + ], + "is_public": True, + "groups": [], + }, + ) + assert text_response.status_code == 200 + text_provider = text_response.json() + _set_default_provider(admin_user, text_provider["id"], "gpt-4o-mini") + + # Create a vision provider and set it as default vision + vision_response = requests.put( + f"{API_SERVER_URL}/admin/llm/provider?is_creation=true", + headers=admin_user.headers, + json={ + "name": "vision-provider", + "provider": LlmProviderNames.OPENAI, + "api_key": "sk-000000000000000000000000000000000000000000000002", + "model_configurations": [ + ModelConfigurationUpsertRequest( + name="gpt-4o", + is_visible=True, + supports_image_input=True, + ).model_dump() + ], + "is_public": True, + "groups": [], + }, + ) + assert vision_response.status_code == 200 + vision_provider = vision_response.json() + _set_default_vision_provider(admin_user, vision_provider["id"], "gpt-4o") + + # Verify vision default is set + data = _get_providers_admin(admin_user) + assert data is not None + _, _, vision_default = _unpack_data(data) + assert vision_default is not None + assert vision_default["provider_id"] == vision_provider["id"] + + # Delete the vision provider — should succeed (only text default is protected) + delete_response = requests.delete( + f"{API_SERVER_URL}/admin/llm/provider/{vision_provider['id']}", + headers=admin_user.headers, + ) + assert delete_response.status_code == 200 + + # Verify the vision provider is gone + provider_data = 
_get_provider_by_id(admin_user, vision_provider["id"]) + assert provider_data is None + + # Verify there is no default vision provider + data = _get_providers_admin(admin_user) + assert data is not None + _, text_default, vision_default = _unpack_data(data) + assert vision_default is None + + # Verify the text default is still intact + assert text_default is not None + assert text_default["provider_id"] == text_provider["id"] + + def test_duplicate_provider_name_rejected(reset: None) -> None: # noqa: ARG001 """Creating a provider with a name that already exists should return 400.""" admin_user = UserManager.create(name="admin_user") @@ -418,7 +673,7 @@ def test_duplicate_provider_name_rejected(reset: None) -> None: # noqa: ARG001 headers=admin_user.headers, json=base_payload, ) - assert response.status_code == 400 + assert response.status_code == 409 assert "already exists" in response.json()["detail"] diff --git a/backend/tests/integration/tests/llm_provider/test_llm_provider_access_control.py b/backend/tests/integration/tests/llm_provider/test_llm_provider_access_control.py index b1ce4bf34e9..f66e19fef24 100644 --- a/backend/tests/integration/tests/llm_provider/test_llm_provider_access_control.py +++ b/backend/tests/integration/tests/llm_provider/test_llm_provider_access_control.py @@ -243,6 +243,116 @@ def test_can_user_access_llm_provider_or_logic( ) +def test_public_provider_with_persona_restrictions( + users: tuple[DATestUser, DATestUser], +) -> None: + """Public providers should still enforce persona restrictions. + + Regression test for the bug where is_public=True caused + can_user_access_llm_provider() to return True immediately, + bypassing persona whitelist checks entirely. 
+ """ + admin_user, _basic_user = users + + with get_session_with_current_tenant() as db_session: + # Public provider with persona restrictions + public_restricted = _create_llm_provider( + db_session, + name="public-persona-restricted", + default_model_name="gpt-4o", + is_public=True, + is_default=True, + ) + + whitelisted_persona = _create_persona( + db_session, + name="whitelisted-persona", + provider_name=public_restricted.name, + ) + non_whitelisted_persona = _create_persona( + db_session, + name="non-whitelisted-persona", + provider_name=public_restricted.name, + ) + + # Only whitelist one persona + db_session.add( + LLMProvider__Persona( + llm_provider_id=public_restricted.id, + persona_id=whitelisted_persona.id, + ) + ) + db_session.flush() + db_session.refresh(public_restricted) + + admin_model = db_session.get(User, admin_user.id) + assert admin_model is not None + admin_group_ids = fetch_user_group_ids(db_session, admin_model) + + # Whitelisted persona — should be allowed + assert can_user_access_llm_provider( + public_restricted, + admin_group_ids, + whitelisted_persona, + ) + + # Non-whitelisted persona — should be denied despite is_public=True + assert not can_user_access_llm_provider( + public_restricted, + admin_group_ids, + non_whitelisted_persona, + ) + + # No persona context (e.g. 
global provider list) — should be denied + # because provider has persona restrictions set + assert not can_user_access_llm_provider( + public_restricted, + admin_group_ids, + persona=None, + ) + + +def test_public_provider_without_persona_restrictions( + users: tuple[DATestUser, DATestUser], +) -> None: + """Public providers with no persona restrictions remain accessible to all.""" + admin_user, basic_user = users + + with get_session_with_current_tenant() as db_session: + public_unrestricted = _create_llm_provider( + db_session, + name="public-unrestricted", + default_model_name="gpt-4o", + is_public=True, + is_default=True, + ) + + any_persona = _create_persona( + db_session, + name="any-persona", + provider_name=public_unrestricted.name, + ) + + admin_model = db_session.get(User, admin_user.id) + basic_model = db_session.get(User, basic_user.id) + assert admin_model is not None + assert basic_model is not None + + admin_group_ids = fetch_user_group_ids(db_session, admin_model) + basic_group_ids = fetch_user_group_ids(db_session, basic_model) + + # Any user, any persona — all allowed + assert can_user_access_llm_provider( + public_unrestricted, admin_group_ids, any_persona + ) + assert can_user_access_llm_provider( + public_unrestricted, basic_group_ids, any_persona + ) + assert can_user_access_llm_provider( + public_unrestricted, admin_group_ids, persona=None + ) + + def test_get_llm_for_persona_falls_back_when_access_denied( users: tuple[DATestUser, DATestUser], ) -> None: diff --git a/backend/tests/integration/tests/llm_workflows/test_nightly_provider_chat_workflow.py b/backend/tests/integration/tests/llm_workflows/test_nightly_provider_chat_workflow.py index 2df6fbeac73..50b9124f049 100644 --- a/backend/tests/integration/tests/llm_workflows/test_nightly_provider_chat_workflow.py +++ b/backend/tests/integration/tests/llm_workflows/test_nightly_provider_chat_workflow.py @@ -42,6 +42,78 @@ class NightlyProviderConfig(BaseModel): strict: bool +def 
_stringify_custom_config_value(value: object) -> str: + if isinstance(value, str): + return value + if isinstance(value, (dict, list)): + return json.dumps(value) + return str(value) + + +def _looks_like_vertex_credentials_payload( + raw_custom_config: dict[object, object], +) -> bool: + normalized_keys = {str(key).strip().lower() for key in raw_custom_config} + provider_specific_keys = { + "vertex_credentials", + "credentials_file", + "vertex_credentials_file", + "google_application_credentials", + "vertex_location", + "location", + "vertex_region", + "region", + } + if normalized_keys & provider_specific_keys: + return False + + normalized_type = str(raw_custom_config.get("type", "")).strip().lower() + if normalized_type not in {"service_account", "external_account"}: + return False + + # Service account JSON usually includes private_key/client_email, while external + # account JSON includes credential_source. Either shape should be accepted. + has_service_account_markers = any( + key in normalized_keys for key in {"private_key", "client_email"} + ) + has_external_account_markers = "credential_source" in normalized_keys + return has_service_account_markers or has_external_account_markers + + +def _normalize_custom_config( + provider: str, raw_custom_config: dict[object, object] +) -> dict[str, str]: + if provider == "vertex_ai" and _looks_like_vertex_credentials_payload( + raw_custom_config + ): + return {"vertex_credentials": json.dumps(raw_custom_config)} + + normalized: dict[str, str] = {} + for raw_key, raw_value in raw_custom_config.items(): + key = str(raw_key).strip() + key_lower = key.lower() + + if provider == "vertex_ai": + if key_lower in { + "vertex_credentials", + "credentials_file", + "vertex_credentials_file", + "google_application_credentials", + }: + key = "vertex_credentials" + elif key_lower in { + "vertex_location", + "location", + "vertex_region", + "region", + }: + key = "vertex_location" + + normalized[key] = 
_stringify_custom_config_value(raw_value) + + return normalized + + def _env_true(env_var: str, default: bool = False) -> bool: value = os.environ.get(env_var) if value is None: @@ -80,7 +152,9 @@ def _load_provider_config() -> NightlyProviderConfig: parsed = json.loads(custom_config_json) if not isinstance(parsed, dict): raise ValueError(f"{_ENV_CUSTOM_CONFIG_JSON} must be a JSON object") - custom_config = {str(key): str(value) for key, value in parsed.items()} + custom_config = _normalize_custom_config( + provider=provider, raw_custom_config=parsed + ) if provider == "ollama_chat" and api_key and not custom_config: custom_config = {"OLLAMA_API_KEY": api_key} @@ -148,6 +222,23 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None: ), ) + if config.provider == "vertex_ai": + has_vertex_credentials = bool( + config.custom_config and config.custom_config.get("vertex_credentials") + ) + if not has_vertex_credentials: + configured_keys = ( + sorted(config.custom_config.keys()) if config.custom_config else [] + ) + _skip_or_fail( + strict=config.strict, + message=( + f"{_ENV_CUSTOM_CONFIG_JSON} must include 'vertex_credentials' " + f"for provider '{config.provider}'. 
" + f"Found keys: {configured_keys}" + ), + ) + def _assert_integration_mode_enabled() -> None: assert ( @@ -193,6 +284,7 @@ def _create_provider_payload( return { "name": provider_name, "provider": provider, + "model": model_name, "api_key": api_key, "api_base": api_base, "api_version": api_version, @@ -208,24 +300,23 @@ def _create_provider_payload( } -def _ensure_provider_is_default(provider_id: int, admin_user: DATestUser) -> None: +def _ensure_provider_is_default( + provider_id: int, model_name: str, admin_user: DATestUser +) -> None: list_response = requests.get( f"{API_SERVER_URL}/admin/llm/provider", headers=admin_user.headers, ) list_response.raise_for_status() - providers = list_response.json() - - current_default = next( - (provider for provider in providers if provider.get("is_default_provider")), - None, + default_text = list_response.json().get("default_text") + assert default_text is not None, "Expected a default provider after setting default" + assert default_text.get("provider_id") == provider_id, ( + f"Expected provider {provider_id} to be default, " + f"found {default_text.get('provider_id')}" ) assert ( - current_default is not None - ), "Expected a default provider after setting provider as default" - assert ( - current_default["id"] == provider_id - ), f"Expected provider {provider_id} to be default, found {current_default['id']}" + default_text.get("model_name") == model_name + ), f"Expected default model {model_name}, found {default_text.get('model_name')}" def _run_chat_assertions( @@ -326,8 +417,9 @@ def _create_and_test_provider_for_model( try: set_default_response = requests.post( - f"{API_SERVER_URL}/admin/llm/provider/{provider_id}/default", + f"{API_SERVER_URL}/admin/llm/default", headers=admin_user.headers, + json={"provider_id": provider_id, "model_name": model_name}, ) assert set_default_response.status_code == 200, ( f"Setting default provider failed for provider={config.provider} " @@ -335,7 +427,9 @@ def 
_create_and_test_provider_for_model( f"{set_default_response.text}" ) - _ensure_provider_is_default(provider_id=provider_id, admin_user=admin_user) + _ensure_provider_is_default( + provider_id=provider_id, model_name=model_name, admin_user=admin_user + ) _run_chat_assertions( admin_user=admin_user, search_tool_id=search_tool_id, diff --git a/backend/tests/integration/tests/no_vectordb/test_no_vectordb_file_lifecycle.py b/backend/tests/integration/tests/no_vectordb/test_no_vectordb_file_lifecycle.py new file mode 100644 index 00000000000..fa552b4d7aa --- /dev/null +++ b/backend/tests/integration/tests/no_vectordb/test_no_vectordb_file_lifecycle.py @@ -0,0 +1,160 @@ +"""Integration test for the full user-file lifecycle in no-vector-DB mode. + +Covers: upload → COMPLETED → unlink from project → delete → gone. + +The entire lifecycle is handled by FastAPI BackgroundTasks (no Celery workers +needed). The conftest-level ``pytestmark`` ensures these tests are skipped +when the server is running with vector DB enabled. 
+""" + +import time +from uuid import UUID + +import requests + +from onyx.db.enums import UserFileStatus +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.managers.project import ProjectManager +from tests.integration.common_utils.test_models import DATestLLMProvider +from tests.integration.common_utils.test_models import DATestUser + +POLL_INTERVAL_SECONDS = 1 +POLL_TIMEOUT_SECONDS = 30 + + +def _poll_file_status( + file_id: UUID, + user: DATestUser, + target_status: UserFileStatus, + timeout: int = POLL_TIMEOUT_SECONDS, +) -> None: + """Poll GET /user/projects/file/{file_id} until the file reaches *target_status*.""" + deadline = time.time() + timeout + while time.time() < deadline: + resp = requests.get( + f"{API_SERVER_URL}/user/projects/file/{file_id}", + headers=user.headers, + ) + if resp.ok: + status = resp.json().get("status") + if status == target_status.value: + return + time.sleep(POLL_INTERVAL_SECONDS) + raise TimeoutError( + f"File {file_id} did not reach {target_status.value} within {timeout}s" + ) + + +def _file_is_gone(file_id: UUID, user: DATestUser, timeout: int = 15) -> None: + """Poll until GET /user/projects/file/{file_id} returns 404.""" + deadline = time.time() + timeout + while time.time() < deadline: + resp = requests.get( + f"{API_SERVER_URL}/user/projects/file/{file_id}", + headers=user.headers, + ) + if resp.status_code == 404: + return + time.sleep(POLL_INTERVAL_SECONDS) + raise TimeoutError( + f"File {file_id} still accessible after {timeout}s (expected 404)" + ) + + +def test_file_upload_process_delete_lifecycle( + reset: None, # noqa: ARG001 + admin_user: DATestUser, + llm_provider: DATestLLMProvider, # noqa: ARG001 +) -> None: + """Full lifecycle: upload → COMPLETED → unlink → delete → 404. + + Validates that the API server handles all background processing + (via FastAPI BackgroundTasks) without any Celery workers running. 
+ """ + project = ProjectManager.create( + name="lifecycle-test", user_performing_action=admin_user + ) + + file_content = b"Integration test file content for lifecycle verification." + upload_result = ProjectManager.upload_files( + project_id=project.id, + files=[("lifecycle.txt", file_content)], + user_performing_action=admin_user, + ) + assert upload_result.user_files, "Expected at least one file in upload response" + + user_file = upload_result.user_files[0] + file_id = user_file.id + + _poll_file_status(file_id, admin_user, UserFileStatus.COMPLETED) + + project_files = ProjectManager.get_project_files(project.id, admin_user) + assert any( + f.id == file_id for f in project_files + ), "File should be listed in project files after processing" + + # Unlink the file from the project so the delete endpoint will proceed + unlink_resp = requests.delete( + f"{API_SERVER_URL}/user/projects/{project.id}/files/{file_id}", + headers=admin_user.headers, + ) + assert ( + unlink_resp.status_code == 204 + ), f"Expected 204 on unlink, got {unlink_resp.status_code}: {unlink_resp.text}" + + delete_resp = requests.delete( + f"{API_SERVER_URL}/user/projects/file/{file_id}", + headers=admin_user.headers, + ) + assert ( + delete_resp.ok + ), f"Delete request failed: {delete_resp.status_code} {delete_resp.text}" + body = delete_resp.json() + assert ( + body["has_associations"] is False + ), f"File still has associations after unlink: {body}" + + _file_is_gone(file_id, admin_user) + + project_files_after = ProjectManager.get_project_files(project.id, admin_user) + assert not any( + f.id == file_id for f in project_files_after + ), "Deleted file should not appear in project files" + + +def test_delete_blocked_while_associated( + reset: None, # noqa: ARG001 + admin_user: DATestUser, + llm_provider: DATestLLMProvider, # noqa: ARG001 +) -> None: + """Deleting a file that still belongs to a project should return + has_associations=True without actually deleting the file.""" + project = 
ProjectManager.create( + name="assoc-test", user_performing_action=admin_user + ) + + upload_result = ProjectManager.upload_files( + project_id=project.id, + files=[("assoc.txt", b"associated file content")], + user_performing_action=admin_user, + ) + file_id = upload_result.user_files[0].id + + _poll_file_status(file_id, admin_user, UserFileStatus.COMPLETED) + + # Attempt to delete while still linked + delete_resp = requests.delete( + f"{API_SERVER_URL}/user/projects/file/{file_id}", + headers=admin_user.headers, + ) + assert delete_resp.ok + body = delete_resp.json() + assert body["has_associations"] is True, "Should report existing associations" + assert project.name in body["project_names"] + + # File should still be accessible + get_resp = requests.get( + f"{API_SERVER_URL}/user/projects/file/{file_id}", + headers=admin_user.headers, + ) + assert get_resp.status_code == 200, "File should still exist after blocked delete" diff --git a/backend/tests/integration/tests/permissions/test_file_connector_permissions.py b/backend/tests/integration/tests/permissions/test_file_connector_permissions.py new file mode 100644 index 00000000000..2b582e636ce --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_file_connector_permissions.py @@ -0,0 +1,234 @@ +import io +import json +import os + +import pytest +import requests + +from onyx.db.enums import AccessType +from onyx.db.models import UserRole +from onyx.server.documents.models import DocumentSource +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.connector import ConnectorManager +from tests.integration.common_utils.managers.credential import CredentialManager +from tests.integration.common_utils.managers.user import DATestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import 
UserGroupManager + + +def _upload_connector_file( + *, + user_performing_action: DATestUser, + file_name: str, + content: bytes, +) -> tuple[str, str]: + headers = user_performing_action.headers.copy() + headers.pop("Content-Type", None) + + response = requests.post( + f"{API_SERVER_URL}/manage/admin/connector/file/upload", + files=[("files", (file_name, io.BytesIO(content), "text/plain"))], + headers=headers, + ) + response.raise_for_status() + payload = response.json() + return payload["file_paths"][0], payload["file_names"][0] + + +def _update_connector_files( + *, + connector_id: int, + user_performing_action: DATestUser, + file_ids_to_remove: list[str], + new_file_name: str, + new_file_content: bytes, +) -> requests.Response: + headers = user_performing_action.headers.copy() + headers.pop("Content-Type", None) + + return requests.post( + f"{API_SERVER_URL}/manage/admin/connector/{connector_id}/files/update", + data={"file_ids_to_remove": json.dumps(file_ids_to_remove)}, + files=[("files", (new_file_name, io.BytesIO(new_file_content), "text/plain"))], + headers=headers, + ) + + +def _list_connector_files( + *, + connector_id: int, + user_performing_action: DATestUser, +) -> requests.Response: + return requests.get( + f"{API_SERVER_URL}/manage/admin/connector/{connector_id}/files", + headers=user_performing_action.headers, + ) + + +@pytest.mark.skipif( + os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true", + reason="Curator and user group tests are enterprise only", +) +@pytest.mark.usefixtures("reset") +def test_only_global_curator_can_update_public_file_connector_files() -> None: + admin_user = UserManager.create(name="admin_user") + + global_curator_creator = UserManager.create(name="global_curator_creator") + global_curator_creator = UserManager.set_role( + user_to_set=global_curator_creator, + target_role=UserRole.GLOBAL_CURATOR, + user_performing_action=admin_user, + ) + + global_curator_editor = 
UserManager.create(name="global_curator_editor") + global_curator_editor = UserManager.set_role( + user_to_set=global_curator_editor, + target_role=UserRole.GLOBAL_CURATOR, + user_performing_action=admin_user, + ) + + curator_user = UserManager.create(name="curator_user") + curator_group = UserGroupManager.create( + name="curator_group", + user_ids=[curator_user.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[curator_group], + user_performing_action=admin_user, + ) + UserGroupManager.set_curator_status( + test_user_group=curator_group, + user_to_set_as_curator=curator_user, + user_performing_action=admin_user, + ) + + initial_file_id, initial_file_name = _upload_connector_file( + user_performing_action=global_curator_creator, + file_name="initial-file.txt", + content=b"initial file content", + ) + + connector = ConnectorManager.create( + user_performing_action=global_curator_creator, + name="public_file_connector", + source=DocumentSource.FILE, + connector_specific_config={ + "file_locations": [initial_file_id], + "file_names": [initial_file_name], + "zip_metadata_file_id": None, + }, + access_type=AccessType.PUBLIC, + groups=[], + ) + credential = CredentialManager.create( + user_performing_action=global_curator_creator, + source=DocumentSource.FILE, + curator_public=True, + groups=[], + name="public_file_connector_credential", + ) + CCPairManager.create( + connector_id=connector.id, + credential_id=credential.id, + user_performing_action=global_curator_creator, + access_type=AccessType.PUBLIC, + groups=[], + name="public_file_connector_cc_pair", + ) + + curator_list_response = _list_connector_files( + connector_id=connector.id, + user_performing_action=curator_user, + ) + curator_list_response.raise_for_status() + curator_list_payload = curator_list_response.json() + assert any(f["file_id"] == initial_file_id for f in curator_list_payload["files"]) + + global_curator_list_response = 
_list_connector_files( + connector_id=connector.id, + user_performing_action=global_curator_editor, + ) + global_curator_list_response.raise_for_status() + global_curator_list_payload = global_curator_list_response.json() + assert any( + f["file_id"] == initial_file_id for f in global_curator_list_payload["files"] + ) + + denied_response = _update_connector_files( + connector_id=connector.id, + user_performing_action=curator_user, + file_ids_to_remove=[initial_file_id], + new_file_name="curator-file.txt", + new_file_content=b"curator updated file", + ) + assert denied_response.status_code == 403 + + allowed_response = _update_connector_files( + connector_id=connector.id, + user_performing_action=global_curator_editor, + file_ids_to_remove=[initial_file_id], + new_file_name="global-curator-file.txt", + new_file_content=b"global curator updated file", + ) + allowed_response.raise_for_status() + + payload = allowed_response.json() + assert initial_file_id not in payload["file_paths"] + assert "global-curator-file.txt" in payload["file_names"] + + creator_group = UserGroupManager.create( + name="creator_group", + user_ids=[global_curator_creator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[creator_group], + user_performing_action=admin_user, + ) + + private_file_id, private_file_name = _upload_connector_file( + user_performing_action=global_curator_creator, + file_name="private-initial-file.txt", + content=b"private initial file content", + ) + + private_connector = ConnectorManager.create( + user_performing_action=global_curator_creator, + name="private_file_connector", + source=DocumentSource.FILE, + connector_specific_config={ + "file_locations": [private_file_id], + "file_names": [private_file_name], + "zip_metadata_file_id": None, + }, + access_type=AccessType.PRIVATE, + groups=[creator_group.id], + ) + private_credential = CredentialManager.create( + 
user_performing_action=global_curator_creator, + source=DocumentSource.FILE, + curator_public=False, + groups=[creator_group.id], + name="private_file_connector_credential", + ) + CCPairManager.create( + connector_id=private_connector.id, + credential_id=private_credential.id, + user_performing_action=global_curator_creator, + access_type=AccessType.PRIVATE, + groups=[creator_group.id], + name="private_file_connector_cc_pair", + ) + + private_denied_response = _update_connector_files( + connector_id=private_connector.id, + user_performing_action=global_curator_editor, + file_ids_to_remove=[private_file_id], + new_file_name="global-curator-private-file.txt", + new_file_content=b"global curator private update", + ) + assert private_denied_response.status_code == 403 diff --git a/backend/tests/integration/tests/scim/test_scim_groups.py b/backend/tests/integration/tests/scim/test_scim_groups.py new file mode 100644 index 00000000000..03b8b8e576d --- /dev/null +++ b/backend/tests/integration/tests/scim/test_scim_groups.py @@ -0,0 +1,552 @@ +"""Integration tests for SCIM group provisioning endpoints. + +Covers the full group lifecycle as driven by an IdP (Okta / Azure AD): +1. Create a group via POST /Groups +2. Retrieve a group via GET /Groups/{id} +3. List, filter, and paginate groups via GET /Groups +4. Replace a group via PUT /Groups/{id} +5. Patch a group (add/remove members, rename) via PATCH /Groups/{id} +6. Delete a group via DELETE /Groups/{id} +7. Error cases: duplicate name, not-found, invalid member IDs + +All tests are parameterized across IdP request styles (Okta sends lowercase +PATCH ops; Entra sends capitalized ops like ``"Replace"``). The server +normalizes both — these tests verify that. + +Auth tests live in test_scim_tokens.py. +User lifecycle tests live in test_scim_users.py. 
+""" + +import pytest +import requests + +from onyx.auth.schemas import UserRole +from tests.integration.common_utils.managers.scim_client import ScimClient +from tests.integration.common_utils.managers.scim_token import ScimTokenManager + + +SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group" +SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User" +SCIM_PATCH_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp" + + +@pytest.fixture(scope="module", params=["okta", "entra"]) +def idp_style(request: pytest.FixtureRequest) -> str: + """Parameterized IdP style — runs every test with both Okta and Entra request formats.""" + return request.param + + +@pytest.fixture(scope="module") +def scim_token(idp_style: str) -> str: + """Create a single SCIM token shared across all tests in this module. + + Creating a new token revokes the previous one, so we create exactly once + per IdP-style run and reuse. Uses UserManager directly to avoid + fixture-scope conflicts with the function-scoped admin_user fixture. 
+ """ + from tests.integration.common_utils.constants import ADMIN_USER_NAME + from tests.integration.common_utils.constants import GENERAL_HEADERS + from tests.integration.common_utils.managers.user import build_email + from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD + from tests.integration.common_utils.managers.user import UserManager + from tests.integration.common_utils.test_models import DATestUser + + try: + admin = UserManager.create(name=ADMIN_USER_NAME) + except Exception: + admin = UserManager.login_as_user( + DATestUser( + id="", + email=build_email(ADMIN_USER_NAME), + password=DEFAULT_PASSWORD, + headers=GENERAL_HEADERS, + role=UserRole.ADMIN, + is_active=True, + ) + ) + + token = ScimTokenManager.create( + name=f"scim-group-tests-{idp_style}", + user_performing_action=admin, + ).raw_token + assert token is not None + return token + + +def _make_group_resource( + display_name: str, + external_id: str | None = None, + members: list[dict] | None = None, +) -> dict: + """Build a minimal SCIM GroupResource payload.""" + resource: dict = { + "schemas": [SCIM_GROUP_SCHEMA], + "displayName": display_name, + } + if external_id is not None: + resource["externalId"] = external_id + if members is not None: + resource["members"] = members + return resource + + +def _make_user_resource(email: str, external_id: str) -> dict: + """Build a minimal SCIM UserResource payload for member creation.""" + return { + "schemas": [SCIM_USER_SCHEMA], + "userName": email, + "externalId": external_id, + "name": {"givenName": "Test", "familyName": "User"}, + "active": True, + } + + +def _make_patch_request(operations: list[dict], idp_style: str = "okta") -> dict: + """Build a SCIM PatchOp payload, applying IdP-specific operation casing. + + Entra sends capitalized operations (e.g. ``"Replace"`` instead of + ``"replace"``). The server's ``normalize_operation`` validator lowercases + them — these tests verify that both casings are accepted. 
+ """ + cased_operations = [] + for operation in operations: + cased = dict(operation) + if idp_style == "entra": + cased["op"] = operation["op"].capitalize() + cased_operations.append(cased) + return { + "schemas": [SCIM_PATCH_SCHEMA], + "Operations": cased_operations, + } + + +def _create_scim_user(token: str, email: str, external_id: str) -> requests.Response: + return ScimClient.post( + "/Users", token, json=_make_user_resource(email, external_id) + ) + + +def _create_scim_group( + token: str, + display_name: str, + external_id: str | None = None, + members: list[dict] | None = None, +) -> requests.Response: + return ScimClient.post( + "/Groups", + token, + json=_make_group_resource(display_name, external_id, members), + ) + + +# ------------------------------------------------------------------ +# Lifecycle: create → get → list → replace → patch → delete +# ------------------------------------------------------------------ + + +def test_create_group(scim_token: str, idp_style: str) -> None: + """POST /Groups creates a group and returns 201.""" + name = f"Engineering {idp_style}" + resp = _create_scim_group(scim_token, name, external_id=f"ext-eng-{idp_style}") + assert resp.status_code == 201 + + body = resp.json() + assert body["displayName"] == name + assert body["externalId"] == f"ext-eng-{idp_style}" + assert body["id"] # integer ID assigned by server + assert body["meta"]["resourceType"] == "Group" + + +def test_create_group_with_members(scim_token: str, idp_style: str) -> None: + """POST /Groups with members populates the member list.""" + user = _create_scim_user( + scim_token, f"grp_member1_{idp_style}@example.com", f"ext-gm-{idp_style}" + ).json() + + resp = _create_scim_group( + scim_token, + f"Backend Team {idp_style}", + external_id=f"ext-backend-{idp_style}", + members=[{"value": user["id"]}], + ) + assert resp.status_code == 201 + + body = resp.json() + member_ids = [m["value"] for m in body["members"]] + assert user["id"] in member_ids + + +def 
test_get_group(scim_token: str, idp_style: str) -> None: + """GET /Groups/{id} returns the group resource including members.""" + user = _create_scim_user( + scim_token, f"grp_get_m_{idp_style}@example.com", f"ext-ggm-{idp_style}" + ).json() + created = _create_scim_group( + scim_token, + f"Frontend Team {idp_style}", + external_id=f"ext-fe-{idp_style}", + members=[{"value": user["id"]}], + ).json() + + resp = ScimClient.get(f"/Groups/{created['id']}", scim_token) + assert resp.status_code == 200 + + body = resp.json() + assert body["id"] == created["id"] + assert body["displayName"] == f"Frontend Team {idp_style}" + assert body["externalId"] == f"ext-fe-{idp_style}" + member_ids = [m["value"] for m in body["members"]] + assert user["id"] in member_ids + + +def test_list_groups(scim_token: str, idp_style: str) -> None: + """GET /Groups returns a ListResponse containing provisioned groups.""" + name = f"DevOps Team {idp_style}" + _create_scim_group(scim_token, name, external_id=f"ext-devops-{idp_style}") + + resp = ScimClient.get("/Groups", scim_token) + assert resp.status_code == 200 + + body = resp.json() + assert body["totalResults"] >= 1 + names = [r["displayName"] for r in body["Resources"]] + assert name in names + + +def test_list_groups_pagination(scim_token: str, idp_style: str) -> None: + """GET /Groups with startIndex and count returns correct pagination.""" + _create_scim_group( + scim_token, f"Page Group A {idp_style}", external_id=f"ext-page-a-{idp_style}" + ) + _create_scim_group( + scim_token, f"Page Group B {idp_style}", external_id=f"ext-page-b-{idp_style}" + ) + + resp = ScimClient.get("/Groups?startIndex=1&count=1", scim_token) + assert resp.status_code == 200 + + body = resp.json() + assert body["startIndex"] == 1 + assert body["itemsPerPage"] == 1 + assert body["totalResults"] >= 2 + assert len(body["Resources"]) == 1 + + +def test_filter_groups_by_display_name(scim_token: str, idp_style: str) -> None: + """GET /Groups?filter=displayName eq 
'...' returns only matching groups.""" + name = f"Unique QA Team {idp_style}" + _create_scim_group(scim_token, name, external_id=f"ext-qa-filter-{idp_style}") + + resp = ScimClient.get(f'/Groups?filter=displayName eq "{name}"', scim_token) + assert resp.status_code == 200 + + body = resp.json() + assert body["totalResults"] == 1 + assert body["Resources"][0]["displayName"] == name + + +def test_filter_groups_by_external_id(scim_token: str, idp_style: str) -> None: + """GET /Groups?filter=externalId eq '...' returns the matching group.""" + ext_id = f"ext-unique-group-id-{idp_style}" + _create_scim_group( + scim_token, f"ExtId Filter Group {idp_style}", external_id=ext_id + ) + + resp = ScimClient.get(f'/Groups?filter=externalId eq "{ext_id}"', scim_token) + assert resp.status_code == 200 + + body = resp.json() + assert body["totalResults"] == 1 + assert body["Resources"][0]["externalId"] == ext_id + + +def test_replace_group(scim_token: str, idp_style: str) -> None: + """PUT /Groups/{id} replaces the group resource.""" + created = _create_scim_group( + scim_token, + f"Original Name {idp_style}", + external_id=f"ext-replace-g-{idp_style}", + ).json() + + user = _create_scim_user( + scim_token, f"grp_replace_m_{idp_style}@example.com", f"ext-grm-{idp_style}" + ).json() + + updated_resource = _make_group_resource( + display_name=f"Renamed Group {idp_style}", + external_id=f"ext-replace-g-{idp_style}", + members=[{"value": user["id"]}], + ) + resp = ScimClient.put(f"/Groups/{created['id']}", scim_token, json=updated_resource) + assert resp.status_code == 200 + + body = resp.json() + assert body["displayName"] == f"Renamed Group {idp_style}" + member_ids = [m["value"] for m in body["members"]] + assert user["id"] in member_ids + + +def test_replace_group_clears_members(scim_token: str, idp_style: str) -> None: + """PUT /Groups/{id} with empty members removes all memberships.""" + user = _create_scim_user( + scim_token, f"grp_clear_m_{idp_style}@example.com", 
f"ext-gcm-{idp_style}" + ).json() + created = _create_scim_group( + scim_token, + f"Clear Members Group {idp_style}", + external_id=f"ext-clear-g-{idp_style}", + members=[{"value": user["id"]}], + ).json() + + assert len(created["members"]) == 1 + + resp = ScimClient.put( + f"/Groups/{created['id']}", + scim_token, + json=_make_group_resource( + f"Clear Members Group {idp_style}", f"ext-clear-g-{idp_style}", members=[] + ), + ) + assert resp.status_code == 200 + assert resp.json()["members"] == [] + + +def test_patch_add_member(scim_token: str, idp_style: str) -> None: + """PATCH /Groups/{id} with op=add adds a member.""" + created = _create_scim_group( + scim_token, + f"Patch Add Group {idp_style}", + external_id=f"ext-patch-add-{idp_style}", + ).json() + user = _create_scim_user( + scim_token, f"grp_patch_add_{idp_style}@example.com", f"ext-gpa-{idp_style}" + ).json() + + resp = ScimClient.patch( + f"/Groups/{created['id']}", + scim_token, + json=_make_patch_request( + [{"op": "add", "path": "members", "value": [{"value": user["id"]}]}], + idp_style, + ), + ) + assert resp.status_code == 200 + + member_ids = [m["value"] for m in resp.json()["members"]] + assert user["id"] in member_ids + + +def test_patch_remove_member(scim_token: str, idp_style: str) -> None: + """PATCH /Groups/{id} with op=remove removes a specific member.""" + user = _create_scim_user( + scim_token, f"grp_patch_rm_{idp_style}@example.com", f"ext-gpr-{idp_style}" + ).json() + created = _create_scim_group( + scim_token, + f"Patch Remove Group {idp_style}", + external_id=f"ext-patch-rm-{idp_style}", + members=[{"value": user["id"]}], + ).json() + assert len(created["members"]) == 1 + + resp = ScimClient.patch( + f"/Groups/{created['id']}", + scim_token, + json=_make_patch_request( + [ + { + "op": "remove", + "path": f'members[value eq "{user["id"]}"]', + } + ], + idp_style, + ), + ) + assert resp.status_code == 200 + assert resp.json()["members"] == [] + + +def 
test_patch_replace_members(scim_token: str, idp_style: str) -> None: + """PATCH /Groups/{id} with op=replace on members swaps the entire list.""" + user_a = _create_scim_user( + scim_token, f"grp_repl_a_{idp_style}@example.com", f"ext-gra-{idp_style}" + ).json() + user_b = _create_scim_user( + scim_token, f"grp_repl_b_{idp_style}@example.com", f"ext-grb-{idp_style}" + ).json() + created = _create_scim_group( + scim_token, + f"Patch Replace Group {idp_style}", + external_id=f"ext-patch-repl-{idp_style}", + members=[{"value": user_a["id"]}], + ).json() + + # Replace member list: swap A for B + resp = ScimClient.patch( + f"/Groups/{created['id']}", + scim_token, + json=_make_patch_request( + [ + { + "op": "replace", + "path": "members", + "value": [{"value": user_b["id"]}], + } + ], + idp_style, + ), + ) + assert resp.status_code == 200 + + member_ids = [m["value"] for m in resp.json()["members"]] + assert user_b["id"] in member_ids + assert user_a["id"] not in member_ids + + +def test_patch_rename_group(scim_token: str, idp_style: str) -> None: + """PATCH /Groups/{id} with op=replace on displayName renames the group.""" + created = _create_scim_group( + scim_token, + f"Old Group Name {idp_style}", + external_id=f"ext-rename-g-{idp_style}", + ).json() + + new_name = f"New Group Name {idp_style}" + resp = ScimClient.patch( + f"/Groups/{created['id']}", + scim_token, + json=_make_patch_request( + [{"op": "replace", "path": "displayName", "value": new_name}], + idp_style, + ), + ) + assert resp.status_code == 200 + assert resp.json()["displayName"] == new_name + + # Confirm via GET + get_resp = ScimClient.get(f"/Groups/{created['id']}", scim_token) + assert get_resp.json()["displayName"] == new_name + + +def test_delete_group(scim_token: str, idp_style: str) -> None: + """DELETE /Groups/{id} removes the group.""" + created = _create_scim_group( + scim_token, + f"Delete Me Group {idp_style}", + external_id=f"ext-del-g-{idp_style}", + ).json() + + resp = 
ScimClient.delete(f"/Groups/{created['id']}", scim_token) + assert resp.status_code == 204 + + # Second DELETE returns 404 (group hard-deleted) + resp2 = ScimClient.delete(f"/Groups/{created['id']}", scim_token) + assert resp2.status_code == 404 + + +def test_delete_group_preserves_members(scim_token: str, idp_style: str) -> None: + """DELETE /Groups/{id} removes memberships but does not deactivate users.""" + user = _create_scim_user( + scim_token, f"grp_del_member_{idp_style}@example.com", f"ext-gdm-{idp_style}" + ).json() + created = _create_scim_group( + scim_token, + f"Delete With Members {idp_style}", + external_id=f"ext-del-wm-{idp_style}", + members=[{"value": user["id"]}], + ).json() + + resp = ScimClient.delete(f"/Groups/{created['id']}", scim_token) + assert resp.status_code == 204 + + # User should still be active and retrievable + user_resp = ScimClient.get(f"/Users/{user['id']}", scim_token) + assert user_resp.status_code == 200 + assert user_resp.json()["active"] is True + + +# ------------------------------------------------------------------ +# Error cases +# ------------------------------------------------------------------ + + +def test_create_group_duplicate_name(scim_token: str, idp_style: str) -> None: + """POST /Groups with an already-taken displayName returns 409.""" + name = f"Dup Name Group {idp_style}" + resp1 = _create_scim_group(scim_token, name, external_id=f"ext-dup-g1-{idp_style}") + assert resp1.status_code == 201 + + resp2 = _create_scim_group(scim_token, name, external_id=f"ext-dup-g2-{idp_style}") + assert resp2.status_code == 409 + + +def test_get_nonexistent_group(scim_token: str) -> None: + """GET /Groups/{bad-id} returns 404.""" + resp = ScimClient.get("/Groups/999999999", scim_token) + assert resp.status_code == 404 + + +def test_create_group_with_invalid_member(scim_token: str, idp_style: str) -> None: + """POST /Groups with a non-existent member UUID returns 400.""" + resp = _create_scim_group( + scim_token, + f"Bad Member 
Group {idp_style}", + external_id=f"ext-bad-m-{idp_style}", + members=[{"value": "00000000-0000-0000-0000-000000000000"}], + ) + assert resp.status_code == 400 + assert "not found" in resp.json()["detail"].lower() + + +def test_patch_add_nonexistent_member(scim_token: str, idp_style: str) -> None: + """PATCH /Groups/{id} adding a non-existent member returns 400.""" + created = _create_scim_group( + scim_token, + f"Patch Bad Member Group {idp_style}", + external_id=f"ext-pbm-{idp_style}", + ).json() + + resp = ScimClient.patch( + f"/Groups/{created['id']}", + scim_token, + json=_make_patch_request( + [ + { + "op": "add", + "path": "members", + "value": [{"value": "00000000-0000-0000-0000-000000000000"}], + } + ], + idp_style, + ), + ) + assert resp.status_code == 400 + assert "not found" in resp.json()["detail"].lower() + + +def test_patch_add_duplicate_member_is_idempotent( + scim_token: str, idp_style: str +) -> None: + """PATCH /Groups/{id} adding an already-present member succeeds silently.""" + user = _create_scim_user( + scim_token, f"grp_dup_add_{idp_style}@example.com", f"ext-gda-{idp_style}" + ).json() + created = _create_scim_group( + scim_token, + f"Idempotent Add Group {idp_style}", + external_id=f"ext-idem-g-{idp_style}", + members=[{"value": user["id"]}], + ).json() + assert len(created["members"]) == 1 + + # Add same member again + resp = ScimClient.patch( + f"/Groups/{created['id']}", + scim_token, + json=_make_patch_request( + [{"op": "add", "path": "members", "value": [{"value": user["id"]}]}], + idp_style, + ), + ) + assert resp.status_code == 200 + assert len(resp.json()["members"]) == 1 # still just one member diff --git a/backend/tests/integration/tests/scim/test_scim_tokens.py b/backend/tests/integration/tests/scim/test_scim_tokens.py index 19c95acbbdc..9476df86ef7 100644 --- a/backend/tests/integration/tests/scim/test_scim_tokens.py +++ b/backend/tests/integration/tests/scim/test_scim_tokens.py @@ -15,6 +15,7 @@ import requests from 
tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.managers.scim_client import ScimClient from tests.integration.common_utils.managers.scim_token import ScimTokenManager from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.test_models import DATestUser @@ -39,7 +40,7 @@ def test_scim_token_lifecycle(admin_user: DATestUser) -> None: assert active == token.model_copy(update={"raw_token": None}) # Token works for SCIM requests - response = ScimTokenManager.scim_get("/Users", token.raw_token) + response = ScimClient.get("/Users", token.raw_token) assert response.status_code == 200 body = response.json() assert "Resources" in body @@ -54,7 +55,7 @@ def test_scim_token_rotation_revokes_previous(admin_user: DATestUser) -> None: ) assert first.raw_token is not None - response = ScimTokenManager.scim_get("/Users", first.raw_token) + response = ScimClient.get("/Users", first.raw_token) assert response.status_code == 200 # Create second token — should revoke first @@ -69,25 +70,22 @@ def test_scim_token_rotation_revokes_previous(admin_user: DATestUser) -> None: assert active == second.model_copy(update={"raw_token": None}) # First token rejected, second works - assert ScimTokenManager.scim_get("/Users", first.raw_token).status_code == 401 - assert ScimTokenManager.scim_get("/Users", second.raw_token).status_code == 200 + assert ScimClient.get("/Users", first.raw_token).status_code == 401 + assert ScimClient.get("/Users", second.raw_token).status_code == 200 def test_scim_request_without_token_rejected( admin_user: DATestUser, # noqa: ARG001 ) -> None: """SCIM endpoints reject requests with no Authorization header.""" - assert ScimTokenManager.scim_get_no_auth("/Users").status_code == 401 + assert ScimClient.get_no_auth("/Users").status_code == 401 def test_scim_request_with_bad_token_rejected( admin_user: DATestUser, # noqa: ARG001 ) -> None: """SCIM endpoints reject requests 
with an invalid token.""" - assert ( - ScimTokenManager.scim_get("/Users", "onyx_scim_bogus_token_value").status_code - == 401 - ) + assert ScimClient.get("/Users", "onyx_scim_bogus_token_value").status_code == 401 def test_non_admin_cannot_create_token( @@ -139,7 +137,7 @@ def test_service_discovery_no_auth_required( ) -> None: """Service discovery endpoints work without any authentication.""" for path in ["/ServiceProviderConfig", "/ResourceTypes", "/Schemas"]: - response = ScimTokenManager.scim_get_no_auth(path) + response = ScimClient.get_no_auth(path) assert response.status_code == 200, f"{path} returned {response.status_code}" @@ -158,7 +156,7 @@ def test_last_used_at_updated_after_scim_request( assert active.last_used_at is None # Make a SCIM request, then verify last_used_at is set - assert ScimTokenManager.scim_get("/Users", token.raw_token).status_code == 200 + assert ScimClient.get("/Users", token.raw_token).status_code == 200 time.sleep(0.5) active_after = ScimTokenManager.get_active(user_performing_action=admin_user) diff --git a/backend/tests/integration/tests/scim/test_scim_users.py b/backend/tests/integration/tests/scim/test_scim_users.py new file mode 100644 index 00000000000..c7c844175c9 --- /dev/null +++ b/backend/tests/integration/tests/scim/test_scim_users.py @@ -0,0 +1,520 @@ +"""Integration tests for SCIM user provisioning endpoints. + +Covers the full user lifecycle as driven by an IdP (Okta / Azure AD): +1. Create a user via POST /Users +2. Retrieve a user via GET /Users/{id} +3. List, filter, and paginate users via GET /Users +4. Replace a user via PUT /Users/{id} +5. Patch a user (deactivate/reactivate) via PATCH /Users/{id} +6. Delete a user via DELETE /Users/{id} +7. Error cases: missing externalId, duplicate email, not-found, seat limit + +All tests are parameterized across IdP request styles: +- **Okta**: lowercase PATCH ops, minimal payloads (core schema only). 
+- **Entra**: capitalized ops (``"Replace"``), enterprise extension data + (department, manager), and structured email arrays. + +The server normalizes both — these tests verify that all IdP-specific fields +are accepted and round-tripped correctly. + +Auth, revoked-token, and service-discovery tests live in test_scim_tokens.py. +""" + +from datetime import datetime +from datetime import timedelta +from datetime import timezone + +import pytest +import redis +import requests + +from ee.onyx.server.license.models import LicenseMetadata +from ee.onyx.server.license.models import LicenseSource +from ee.onyx.server.license.models import PlanType +from onyx.auth.schemas import UserRole +from onyx.configs.app_configs import REDIS_DB_NUMBER +from onyx.configs.app_configs import REDIS_HOST +from onyx.configs.app_configs import REDIS_PORT +from onyx.server.settings.models import ApplicationStatus +from tests.integration.common_utils.managers.scim_client import ScimClient +from tests.integration.common_utils.managers.scim_token import ScimTokenManager + + +SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User" +SCIM_ENTERPRISE_USER_SCHEMA = ( + "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User" +) +SCIM_PATCH_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp" + +_LICENSE_REDIS_KEY = "public:license:metadata" + + +@pytest.fixture(scope="module", params=["okta", "entra"]) +def idp_style(request: pytest.FixtureRequest) -> str: + """Parameterized IdP style — runs every test with both Okta and Entra request formats.""" + return request.param + + +@pytest.fixture(scope="module") +def scim_token(idp_style: str) -> str: + """Create a single SCIM token shared across all tests in this module. + + Creating a new token revokes the previous one, so we create exactly once + per IdP-style run and reuse. Uses UserManager directly to avoid + fixture-scope conflicts with the function-scoped admin_user fixture. 
+ """ + from tests.integration.common_utils.constants import ADMIN_USER_NAME + from tests.integration.common_utils.constants import GENERAL_HEADERS + from tests.integration.common_utils.managers.user import build_email + from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD + from tests.integration.common_utils.managers.user import UserManager + from tests.integration.common_utils.test_models import DATestUser + + try: + admin = UserManager.create(name=ADMIN_USER_NAME) + except Exception: + admin = UserManager.login_as_user( + DATestUser( + id="", + email=build_email(ADMIN_USER_NAME), + password=DEFAULT_PASSWORD, + headers=GENERAL_HEADERS, + role=UserRole.ADMIN, + is_active=True, + ) + ) + + token = ScimTokenManager.create( + name=f"scim-user-tests-{idp_style}", + user_performing_action=admin, + ).raw_token + assert token is not None + return token + + +def _make_user_resource( + email: str, + external_id: str, + given_name: str = "Test", + family_name: str = "User", + active: bool = True, + idp_style: str = "okta", + department: str | None = None, + manager_id: str | None = None, +) -> dict: + """Build a SCIM UserResource payload appropriate for the IdP style. + + Entra sends richer payloads including enterprise extension data (department, + manager), structured email arrays, and the enterprise schema URN. Okta sends + minimal payloads with just core user fields. 
+ """ + resource: dict = { + "schemas": [SCIM_USER_SCHEMA], + "userName": email, + "externalId": external_id, + "name": { + "givenName": given_name, + "familyName": family_name, + }, + "active": active, + } + if idp_style == "entra": + dept = department or "Engineering" + mgr = manager_id or "mgr-ext-001" + resource["schemas"].append(SCIM_ENTERPRISE_USER_SCHEMA) + resource[SCIM_ENTERPRISE_USER_SCHEMA] = { + "department": dept, + "manager": {"value": mgr}, + } + resource["emails"] = [ + {"value": email, "type": "work", "primary": True}, + ] + return resource + + +def _make_patch_request(operations: list[dict], idp_style: str = "okta") -> dict: + """Build a SCIM PatchOp payload, applying IdP-specific operation casing. + + Entra sends capitalized operations (e.g. ``"Replace"`` instead of + ``"replace"``). The server's ``normalize_operation`` validator lowercases + them — these tests verify that both casings are accepted. + """ + cased_operations = [] + for operation in operations: + cased = dict(operation) + if idp_style == "entra": + cased["op"] = operation["op"].capitalize() + cased_operations.append(cased) + return { + "schemas": [SCIM_PATCH_SCHEMA], + "Operations": cased_operations, + } + + +def _create_scim_user( + token: str, + email: str, + external_id: str, + idp_style: str = "okta", +) -> requests.Response: + return ScimClient.post( + "/Users", + token, + json=_make_user_resource(email, external_id, idp_style=idp_style), + ) + + +def _assert_entra_extension( + body: dict, + expected_department: str = "Engineering", + expected_manager: str = "mgr-ext-001", +) -> None: + """Assert that Entra enterprise extension fields round-tripped correctly.""" + assert SCIM_ENTERPRISE_USER_SCHEMA in body["schemas"] + ext = body[SCIM_ENTERPRISE_USER_SCHEMA] + assert ext["department"] == expected_department + assert ext["manager"]["value"] == expected_manager + + +def _assert_entra_emails(body: dict, expected_email: str) -> None: + """Assert that structured email metadata 
round-tripped correctly.""" + emails = body["emails"] + assert len(emails) >= 1 + work_email = next(e for e in emails if e.get("type") == "work") + assert work_email["value"] == expected_email + assert work_email["primary"] is True + + +# ------------------------------------------------------------------ +# Lifecycle: create -> get -> list -> replace -> patch -> delete +# ------------------------------------------------------------------ + + +def test_create_user(scim_token: str, idp_style: str) -> None: + """POST /Users creates a provisioned user and returns 201.""" + email = f"scim_create_{idp_style}@example.com" + ext_id = f"ext-create-{idp_style}" + resp = _create_scim_user(scim_token, email, ext_id, idp_style) + assert resp.status_code == 201 + + body = resp.json() + assert body["userName"] == email + assert body["externalId"] == ext_id + assert body["active"] is True + assert body["id"] # UUID assigned by server + assert body["meta"]["resourceType"] == "User" + assert body["name"]["givenName"] == "Test" + assert body["name"]["familyName"] == "User" + + if idp_style == "entra": + _assert_entra_extension(body) + _assert_entra_emails(body, email) + + +def test_get_user(scim_token: str, idp_style: str) -> None: + """GET /Users/{id} returns the user resource with all stored fields.""" + email = f"scim_get_{idp_style}@example.com" + ext_id = f"ext-get-{idp_style}" + created = _create_scim_user(scim_token, email, ext_id, idp_style).json() + + resp = ScimClient.get(f"/Users/{created['id']}", scim_token) + assert resp.status_code == 200 + + body = resp.json() + assert body["id"] == created["id"] + assert body["userName"] == email + assert body["externalId"] == ext_id + assert body["name"]["givenName"] == "Test" + assert body["name"]["familyName"] == "User" + + if idp_style == "entra": + _assert_entra_extension(body) + _assert_entra_emails(body, email) + + +def test_list_users(scim_token: str, idp_style: str) -> None: + """GET /Users returns a ListResponse containing 
provisioned users.""" + email = f"scim_list_{idp_style}@example.com" + _create_scim_user(scim_token, email, f"ext-list-{idp_style}", idp_style) + + resp = ScimClient.get("/Users", scim_token) + assert resp.status_code == 200 + + body = resp.json() + assert body["totalResults"] >= 1 + emails = [r["userName"] for r in body["Resources"]] + assert email in emails + + +def test_list_users_pagination(scim_token: str, idp_style: str) -> None: + """GET /Users with startIndex and count returns correct pagination.""" + _create_scim_user( + scim_token, + f"scim_page1_{idp_style}@example.com", + f"ext-page-1-{idp_style}", + idp_style, + ) + _create_scim_user( + scim_token, + f"scim_page2_{idp_style}@example.com", + f"ext-page-2-{idp_style}", + idp_style, + ) + + resp = ScimClient.get("/Users?startIndex=1&count=1", scim_token) + assert resp.status_code == 200 + + body = resp.json() + assert body["startIndex"] == 1 + assert body["itemsPerPage"] == 1 + assert body["totalResults"] >= 2 + assert len(body["Resources"]) == 1 + + +def test_filter_users_by_username(scim_token: str, idp_style: str) -> None: + """GET /Users?filter=userName eq '...' 
returns only matching users.""" + email = f"scim_filter_{idp_style}@example.com" + _create_scim_user(scim_token, email, f"ext-filter-{idp_style}", idp_style) + + resp = ScimClient.get(f'/Users?filter=userName eq "{email}"', scim_token) + assert resp.status_code == 200 + + body = resp.json() + assert body["totalResults"] == 1 + assert body["Resources"][0]["userName"] == email + + +def test_replace_user(scim_token: str, idp_style: str) -> None: + """PUT /Users/{id} replaces the user resource including enterprise fields.""" + email = f"scim_replace_{idp_style}@example.com" + ext_id = f"ext-replace-{idp_style}" + created = _create_scim_user(scim_token, email, ext_id, idp_style).json() + + updated_resource = _make_user_resource( + email=email, + external_id=ext_id, + given_name="Updated", + family_name="Name", + idp_style=idp_style, + department="Product", + ) + resp = ScimClient.put(f"/Users/{created['id']}", scim_token, json=updated_resource) + assert resp.status_code == 200 + + body = resp.json() + assert body["name"]["givenName"] == "Updated" + assert body["name"]["familyName"] == "Name" + + if idp_style == "entra": + _assert_entra_extension(body, expected_department="Product") + _assert_entra_emails(body, email) + + +def test_patch_deactivate_user(scim_token: str, idp_style: str) -> None: + """PATCH /Users/{id} with active=false deactivates the user.""" + created = _create_scim_user( + scim_token, + f"scim_deactivate_{idp_style}@example.com", + f"ext-deactivate-{idp_style}", + idp_style, + ).json() + assert created["active"] is True + + resp = ScimClient.patch( + f"/Users/{created['id']}", + scim_token, + json=_make_patch_request( + [{"op": "replace", "path": "active", "value": False}], idp_style + ), + ) + assert resp.status_code == 200 + assert resp.json()["active"] is False + + # Confirm via GET + get_resp = ScimClient.get(f"/Users/{created['id']}", scim_token) + assert get_resp.json()["active"] is False + + +def test_patch_reactivate_user(scim_token: str, 
idp_style: str) -> None: + """PATCH active=true reactivates a previously deactivated user.""" + created = _create_scim_user( + scim_token, + f"scim_reactivate_{idp_style}@example.com", + f"ext-reactivate-{idp_style}", + idp_style, + ).json() + + # Deactivate + deactivate_resp = ScimClient.patch( + f"/Users/{created['id']}", + scim_token, + json=_make_patch_request( + [{"op": "replace", "path": "active", "value": False}], idp_style + ), + ) + assert deactivate_resp.status_code == 200 + assert deactivate_resp.json()["active"] is False + + # Reactivate + resp = ScimClient.patch( + f"/Users/{created['id']}", + scim_token, + json=_make_patch_request( + [{"op": "replace", "path": "active", "value": True}], idp_style + ), + ) + assert resp.status_code == 200 + assert resp.json()["active"] is True + + +def test_delete_user(scim_token: str, idp_style: str) -> None: + """DELETE /Users/{id} deactivates and removes the SCIM mapping.""" + created = _create_scim_user( + scim_token, + f"scim_delete_{idp_style}@example.com", + f"ext-delete-{idp_style}", + idp_style, + ).json() + + resp = ScimClient.delete(f"/Users/{created['id']}", scim_token) + assert resp.status_code == 204 + + # Second DELETE returns 404 per RFC 7644 §3.6 (mapping removed) + resp2 = ScimClient.delete(f"/Users/{created['id']}", scim_token) + assert resp2.status_code == 404 + + +# ------------------------------------------------------------------ +# Error cases +# ------------------------------------------------------------------ + + +def test_create_user_missing_external_id(scim_token: str, idp_style: str) -> None: + """POST /Users without externalId succeeds (RFC 7643: externalId is optional).""" + email = f"scim_no_extid_{idp_style}@example.com" + resp = ScimClient.post( + "/Users", + scim_token, + json={ + "schemas": [SCIM_USER_SCHEMA], + "userName": email, + "active": True, + }, + ) + assert resp.status_code == 201 + body = resp.json() + assert body["userName"] == email + assert body.get("externalId") is 
None + + +def test_create_user_duplicate_email(scim_token: str, idp_style: str) -> None: + """POST /Users with an already-taken email returns 409.""" + email = f"scim_dup_{idp_style}@example.com" + resp1 = _create_scim_user(scim_token, email, f"ext-dup-1-{idp_style}", idp_style) + assert resp1.status_code == 201 + + resp2 = _create_scim_user(scim_token, email, f"ext-dup-2-{idp_style}", idp_style) + assert resp2.status_code == 409 + + +def test_get_nonexistent_user(scim_token: str) -> None: + """GET /Users/{bad-id} returns 404.""" + resp = ScimClient.get("/Users/00000000-0000-0000-0000-000000000000", scim_token) + assert resp.status_code == 404 + + +def test_filter_users_by_external_id(scim_token: str, idp_style: str) -> None: + """GET /Users?filter=externalId eq '...' returns the matching user.""" + ext_id = f"ext-unique-filter-id-{idp_style}" + _create_scim_user( + scim_token, f"scim_extfilter_{idp_style}@example.com", ext_id, idp_style + ) + + resp = ScimClient.get(f'/Users?filter=externalId eq "{ext_id}"', scim_token) + assert resp.status_code == 200 + + body = resp.json() + assert body["totalResults"] == 1 + assert body["Resources"][0]["externalId"] == ext_id + + +# ------------------------------------------------------------------ +# Seat-limit enforcement +# ------------------------------------------------------------------ + + +def _seed_license(r: redis.Redis, seats: int) -> None: + """Write a LicenseMetadata entry into Redis with the given seat cap.""" + now = datetime.now(timezone.utc) + metadata = LicenseMetadata( + tenant_id="public", + organization_name="Test Org", + seats=seats, + used_seats=0, # check_seat_availability recalculates from DB + plan_type=PlanType.ANNUAL, + issued_at=now, + expires_at=now + timedelta(days=365), + status=ApplicationStatus.ACTIVE, + source=LicenseSource.MANUAL_UPLOAD, + ) + r.set(_LICENSE_REDIS_KEY, metadata.model_dump_json(), ex=300) + + +def test_create_user_seat_limit(scim_token: str, idp_style: str) -> None: + """POST 
/Users returns 403 when the seat limit is reached.""" +    r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER) + +    # the token's admin and previously provisioned users already occupy seats; cap at 1 -> full +    _seed_license(r, seats=1) + +    try: +        resp = _create_scim_user( +            scim_token, +            f"scim_blocked_{idp_style}@example.com", +            f"ext-blocked-{idp_style}", +            idp_style, +        ) +        assert resp.status_code == 403 +        assert "seat" in resp.json()["detail"].lower() +    finally: +        r.delete(_LICENSE_REDIS_KEY) + + +def test_reactivate_user_seat_limit(scim_token: str, idp_style: str) -> None: +    """PATCH active=true returns 403 when the seat limit is reached.""" +    # Create and deactivate a user (before license is seeded) +    created = _create_scim_user( +        scim_token, +        f"scim_reactivate_blocked_{idp_style}@example.com", +        f"ext-reactivate-blocked-{idp_style}", +        idp_style, +    ).json() +    assert created["active"] is True + +    deactivate_resp = ScimClient.patch( +        f"/Users/{created['id']}", +        scim_token, +        json=_make_patch_request( +            [{"op": "replace", "path": "active", "value": False}], idp_style +        ), +    ) +    assert deactivate_resp.status_code == 200 +    assert deactivate_resp.json()["active"] is False + +    r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER) + +    # Seed license capped at current active users -> reactivation should fail +    _seed_license(r, seats=1) + +    try: +        resp = ScimClient.patch( +            f"/Users/{created['id']}", +            scim_token, +            json=_make_patch_request( +                [{"op": "replace", "path": "active", "value": True}], idp_style +            ), +        ) +        assert resp.status_code == 403 +        assert "seat" in resp.json()["detail"].lower() +    finally: +        r.delete(_LICENSE_REDIS_KEY) diff --git a/backend/tests/integration/tests/search_settings/test_search_settings.py b/backend/tests/integration/tests/search_settings/test_search_settings.py index 02b7b1e768f..b6debb519ad 100644 --- a/backend/tests/integration/tests/search_settings/test_search_settings.py +++ 
b/backend/tests/integration/tests/search_settings/test_search_settings.py @@ -1,4 +1,3 @@ -import pytest import requests from tests.integration.common_utils.constants import API_SERVER_URL @@ -365,7 +364,6 @@ def test_update_contextual_rag_missing_model_name( assert "Provider name and model name are required" in response.json()["detail"] -@pytest.mark.skip(reason="Set new search settings is temporarily disabled.") def test_set_new_search_settings_with_contextual_rag( reset: None, # noqa: ARG001 admin_user: DATestUser, @@ -394,7 +392,6 @@ def test_set_new_search_settings_with_contextual_rag( _cancel_new_embedding(admin_user) -@pytest.mark.skip(reason="Set new search settings is temporarily disabled.") def test_set_new_search_settings_without_contextual_rag( reset: None, # noqa: ARG001 admin_user: DATestUser, @@ -419,7 +416,6 @@ def test_set_new_search_settings_without_contextual_rag( _cancel_new_embedding(admin_user) -@pytest.mark.skip(reason="Set new search settings is temporarily disabled.") def test_set_new_then_update_inference_settings( reset: None, # noqa: ARG001 admin_user: DATestUser, @@ -457,7 +453,6 @@ def test_set_new_then_update_inference_settings( _cancel_new_embedding(admin_user) -@pytest.mark.skip(reason="Set new search settings is temporarily disabled.") def test_set_new_search_settings_replaces_previous_secondary( reset: None, # noqa: ARG001 admin_user: DATestUser, diff --git a/backend/tests/integration/tests/users/test_slack_user_deactivation.py b/backend/tests/integration/tests/users/test_slack_user_deactivation.py new file mode 100644 index 00000000000..d96f5f34a86 --- /dev/null +++ b/backend/tests/integration/tests/users/test_slack_user_deactivation.py @@ -0,0 +1,121 @@ +"""Integration tests for Slack user deactivation and reactivation via admin endpoints. 
+ +Verifies that: +- Slack users can be deactivated by admins +- Deactivated Slack users can be reactivated by admins +- Reactivation is blocked when the seat limit is reached +""" + +from datetime import datetime +from datetime import timedelta +from datetime import timezone + +import redis +import requests + +from ee.onyx.server.license.models import LicenseMetadata +from ee.onyx.server.license.models import LicenseSource +from ee.onyx.server.license.models import PlanType +from onyx.auth.schemas import UserRole +from onyx.configs.app_configs import REDIS_DB_NUMBER +from onyx.configs.app_configs import REDIS_HOST +from onyx.configs.app_configs import REDIS_PORT +from onyx.server.settings.models import ApplicationStatus +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.test_models import DATestUser + +_LICENSE_REDIS_KEY = "public:license:metadata" + + +def _seed_license(r: redis.Redis, seats: int) -> None: +    now = datetime.now(timezone.utc) +    metadata = LicenseMetadata( +        tenant_id="public", +        organization_name="Test Org", +        seats=seats, +        used_seats=0, +        plan_type=PlanType.ANNUAL, +        issued_at=now, +        expires_at=now + timedelta(days=365), +        status=ApplicationStatus.ACTIVE, +        source=LicenseSource.MANUAL_UPLOAD, +    ) +    r.set(_LICENSE_REDIS_KEY, metadata.model_dump_json(), ex=300) + + +def _clear_license(r: redis.Redis) -> None: +    r.delete(_LICENSE_REDIS_KEY) + + +def _redis() -> redis.Redis: +    return redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER) + + +def _get_user_is_active(email: str, admin_user: DATestUser) -> bool: +    """Look up a user's is_active flag via the admin users list endpoint.""" +    result = UserManager.get_user_page( +        user_performing_action=admin_user, +        search_query=email, +    ) +    matching = [u for u in result.items if u.email == email] +    assert len(matching) == 1, f"Expected exactly 1 user with email {email}" +    return matching[0].is_active + + +def 
test_slack_user_deactivate_and_reactivate(reset: None) -> None: # noqa: ARG001 + """Admin can deactivate and then reactivate a Slack user.""" + admin_user = UserManager.create(name="admin_user") + + slack_user = UserManager.create(name="slack_test_user") + slack_user = UserManager.set_role( + user_to_set=slack_user, + target_role=UserRole.SLACK_USER, + user_performing_action=admin_user, + explicit_override=True, + ) + + # Deactivate the Slack user + UserManager.set_status( + slack_user, target_status=False, user_performing_action=admin_user + ) + assert _get_user_is_active(slack_user.email, admin_user) is False + + # Reactivate the Slack user + UserManager.set_status( + slack_user, target_status=True, user_performing_action=admin_user + ) + assert _get_user_is_active(slack_user.email, admin_user) is True + + +def test_slack_user_reactivation_blocked_by_seat_limit( + reset: None, # noqa: ARG001 +) -> None: + """Reactivating a deactivated Slack user returns 402 when seats are full.""" + r = _redis() + + admin_user = UserManager.create(name="admin_user") + + slack_user = UserManager.create(name="slack_test_user") + slack_user = UserManager.set_role( + user_to_set=slack_user, + target_role=UserRole.SLACK_USER, + user_performing_action=admin_user, + explicit_override=True, + ) + + UserManager.set_status( + slack_user, target_status=False, user_performing_action=admin_user + ) + + # License allows 1 seat — only admin counts + _seed_license(r, seats=1) + + try: + response = requests.patch( + url=f"{API_SERVER_URL}/manage/admin/activate-user", + json={"user_email": slack_user.email}, + headers=admin_user.headers, + ) + assert response.status_code == 402 + finally: + _clear_license(r) diff --git a/backend/tests/unit/build/test_rewrite_asset_paths.py b/backend/tests/unit/build/test_rewrite_asset_paths.py new file mode 100644 index 00000000000..7eb7eb3ce49 --- /dev/null +++ b/backend/tests/unit/build/test_rewrite_asset_paths.py @@ -0,0 +1,256 @@ +"""Unit tests for webapp 
proxy path rewriting/injection.""" + +from types import SimpleNamespace +from typing import cast +from typing import Literal +from uuid import UUID + +import httpx +import pytest +from fastapi import Request +from sqlalchemy.orm import Session + +from onyx.server.features.build.api import api +from onyx.server.features.build.api.api import _inject_hmr_fixer +from onyx.server.features.build.api.api import _rewrite_asset_paths +from onyx.server.features.build.api.api import _rewrite_proxy_response_headers + +SESSION_ID = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" +BASE = f"/api/build/sessions/{SESSION_ID}/webapp" + + +def rewrite(html: str) -> str: + return _rewrite_asset_paths(html.encode(), SESSION_ID).decode() + + +def inject(html: str) -> str: + return _inject_hmr_fixer(html.encode(), SESSION_ID).decode() + + +class TestNextjsPathRewriting: + def test_rewrites_bare_next_script_src(self) -> None: + html = '