diff --git a/src/cost-attribution.ts b/src/cost-attribution.ts new file mode 100644 index 0000000..d49ba48 --- /dev/null +++ b/src/cost-attribution.ts @@ -0,0 +1,178 @@ +// ─── Cost Attribution ───────────────────────────────────────── +// Maps tool calls to credit costs and enforces quotas via edge-auth. +// Every tool call is metered: check quota before, consume after success. +// Cost data flows to the audit pipeline for billing dashboards. + +import type { Tier, AuthServiceRpc } from './types.js'; +import type { AuditArtifact } from './audit.js'; + +// ─── Credit cost per tool call ────────────────────────────────── +// Costs are in "credits" — 1 credit = 1 unit of the tier's allocation. +// Expensive operations (image gen, deploy) cost more. +export interface ToolCost { + /** Base credit cost for the tool call */ + baseCost: number; + /** Feature key for quota tracking (maps to edge-auth quota.feature) */ + feature: string; +} + +const TOOL_COSTS: Record = { + // img-forge: costs depend on quality tier (resolved at call time) + 'image_generate': { baseCost: 5, feature: 'mcp.image_generate' }, + 'image_list_models': { baseCost: 0, feature: 'mcp.image_list_models' }, + 'image_check_job': { baseCost: 0, feature: 'mcp.image_check_job' }, + + // TarotScript scaffold + 'scaffold_create': { baseCost: 2, feature: 'mcp.scaffold_create' }, + 'scaffold_classify': { baseCost: 0, feature: 'mcp.scaffold_classify' }, + 'scaffold_status': { baseCost: 0, feature: 'mcp.scaffold_status' }, + 'scaffold_publish': { baseCost: 3, feature: 'mcp.scaffold_publish' }, + 'scaffold_deploy': { baseCost: 5, feature: 'mcp.scaffold_deploy' }, + 'scaffold_import': { baseCost: 1, feature: 'mcp.scaffold_import' }, + + // Flow tools + 'flow_create': { baseCost: 2, feature: 'mcp.flow_create' }, + 'flow_status': { baseCost: 0, feature: 'mcp.flow_status' }, + 'flow_summary': { baseCost: 0, feature: 'mcp.flow_summary' }, + 'flow_quality': { baseCost: 0, feature: 'mcp.flow_quality' }, + 'flow_governance': { baseCost: 0, feature: 'mcp.flow_governance' }, + + // Visual QA + 'visual_screenshot': { baseCost: 1, feature: 'mcp.visual_screenshot' }, + 'visual_analyze': { baseCost: 2, feature: 'mcp.visual_analyze' }, + 'visual_pages': { baseCost: 0, feature: 'mcp.visual_pages' }, +}; + +// Quality tier multipliers for image_generate +const IMAGE_QUALITY_MULTIPLIER: Record = { + draft: 1, + standard: 1, + premium: 3, + ultra: 5, + ultra_plus: 8, +}; + +/** + * Resolve the credit cost for a tool call, factoring in quality tier for images. + */ +export function resolveToolCost( + toolName: string, + args?: Record, +): ToolCost { + const base = TOOL_COSTS[toolName]; + if (!base) { + // Unknown tools cost 1 credit by default (conservative) + return { baseCost: 1, feature: `mcp.${toolName}` }; + } + + // Apply quality multiplier for image_generate + if (toolName === 'image_generate' && args?.quality_tier) { + const multiplier = IMAGE_QUALITY_MULTIPLIER[args.quality_tier as string] ?? 1; + return { ...base, baseCost: base.baseCost * multiplier }; + } + + return base; +} + +/** + * Check if a tool call is free (cost = 0). Free calls skip quota enforcement. + */ +export function isFreeTool(toolName: string): boolean { + const cost = TOOL_COSTS[toolName]; + return cost !== undefined && cost.baseCost === 0; +} + +export interface QuotaCheckResult { + allowed: boolean; + reservationId?: string; + remaining?: number; + error?: string; +} + +/** + * Check and reserve quota for a tool call via edge-auth RPC. + * Returns a reservation ID that must be committed or refunded after the call. + */ +export async function reserveQuota( + authService: AuthServiceRpc, + tenantId: string, + userId: string, + toolName: string, + args?: Record, +): Promise { + const cost = resolveToolCost(toolName, args); + + // Free tools don't consume quota + if (cost.baseCost === 0) { + return { allowed: true }; + } + + try { + const result = await authService.consumeQuota({ + tenantId, + userId, + feature: cost.feature, + amount: cost.baseCost, + }); + + if (!result.success) { + return { + allowed: false, + error: result.error ?? 'Quota exceeded', + remaining: result.remaining, + }; + } + + return { + allowed: true, + reservationId: result.reservationId, + remaining: result.remaining, + }; + } catch (err) { + // Quota service unavailable — fail open for read-only tools, closed for mutations + const isReadOnly = cost.baseCost <= 0; + if (isReadOnly) { + return { allowed: true }; + } + return { + allowed: false, + error: 'Quota service unavailable', + }; + } +} + +/** + * Commit or refund a quota reservation based on tool call outcome. + */ +export async function settleQuota( + authService: AuthServiceRpc, + reservationId: string | undefined, + success: boolean, +): Promise { + if (!reservationId) return; + + try { + await authService.commitOrRefundQuota( + reservationId, + success ? 'success' : 'failed', + ); + } catch { + // Best-effort — don't fail the tool call if settlement fails. + // The reservation will auto-expire in edge-auth. + console.error(`[cost] Failed to settle reservation ${reservationId}`); + } +} + +/** + * Build cost attribution data for the audit artifact. + */ +export function buildCostAttribution( + toolName: string, + args?: Record, +): { feature: string; creditCost: number } { + const cost = resolveToolCost(toolName, args); + return { + feature: cost.feature, + creditCost: cost.baseCost, + }; +} diff --git a/src/gateway.ts b/src/gateway.ts index bbd0b3d..a50ab3d 100644 --- a/src/gateway.ts +++ b/src/gateway.ts @@ -12,6 +12,8 @@ import { materializeScaffold } from './scaffold-materializer.js'; import { publishToGitHub } from './scaffold-publish.js'; import { classifyIntention, type IntentClassification } from './intent-classifier.js'; import { logDivergence } from './divergence-logger.js'; +import { checkRateLimit, rateLimitHeaders, type RateLimitResult } from './rate-limiter.js'; +import { reserveQuota, settleQuota, buildCostAttribution, isFreeTool } from './cost-attribution.js'; const MCP_PROTOCOL_VERSION = '2025-03-26'; const JSON_RPC_PARSE_ERROR = -32700; @@ -1010,6 +1012,28 @@ async function handlePost(request: Request, env: GatewayEnv, oauthProps?: OAuthP ); } + // Rate limiting — check before processing + const rateLimitKey = authResult.tenantId ?? authResult.userId ?? 'unknown'; + const rlResult = await checkRateLimit(env.RATELIMIT_KV, rateLimitKey, authResult.tier); + if (!rlResult.allowed) { + audit({ + trace_id: generateTraceId(), + principal: authResult.userId ?? 'unknown', + tenant: authResult.tenantId ?? 'unknown', + tool: 'rate_limit', + risk_level: 'UNKNOWN', + policy_decision: 'DENY', + redacted_input_summary: '{}', + outcome: 'auth_denied', + timestamp: new Date().toISOString(), + }, env); + return jsonResponse( + { error: 'Rate limit exceeded', code: 'RATE_LIMITED' }, + 429, + rateLimitHeaders(rlResult), + ); + } + // Validate Accept header const accept = request.headers.get('Accept') ?? ''; if (!accept.includes('application/json') && !accept.includes('*/*') && accept !== '') { @@ -1065,7 +1089,18 @@ async function handlePost(request: Request, env: GatewayEnv, oauthProps?: OAuthP // ─── tools/list ───────────────────────────────────────── if (rpcMethod === 'tools/list') { // KV handles session expiration via expirationTtl — no manual pruning needed - const tools = buildAggregatedCatalog(); + let tools = buildAggregatedCatalog(); + + // Scope-based filtering: only show tools the session has access to + const hasGenerate = session.scopes.includes('generate'); + if (!hasGenerate) { + // Read-only scope — filter out mutation tools + tools = tools.filter(t => { + const risk = getToolRiskLevel(t.name); + return risk === 'READ_ONLY'; + }); + } + return rpcResult(rpcId, { tools }); } @@ -1108,6 +1143,23 @@ async function handlePost(request: Request, env: GatewayEnv, oauthProps?: OAuthP return rpcError(rpcId, JSON_RPC_METHOD_NOT_FOUND, `Unknown tool: ${toolName}`); } + // Scope enforcement: mutation tools require 'generate' scope + if (risk !== 'READ_ONLY' && !session.scopes.includes('generate')) { + audit({ + trace_id: traceId, + principal: session.userId ?? 'unknown', + tenant: session.tenantId ?? 'unknown', + tool: toolName, + risk_level: risk, + policy_decision: 'DENY', + redacted_input_summary: summarizeInput(toolArgs), + outcome: 'auth_denied', + timestamp: new Date().toISOString(), + }, env); + return rpcError(rpcId, JSON_RPC_INVALID_PARAMS, + `Tool "${toolName}" requires the "generate" scope. Your API key only has: ${session.scopes.join(', ')}`); + } + // Validate arguments are object-shaped const argValidation = validateToolArguments(toolArgs, { type: 'object' }); if (!argValidation.valid) { @@ -1142,8 +1194,65 @@ async function handlePost(request: Request, env: GatewayEnv, oauthProps?: OAuthP return rpcError(rpcId, JSON_RPC_INVALID_PARAMS, tierDenied); } + // ─── Cost attribution: reserve quota before tool call ──── + const costInfo = buildCostAttribution(toolName, toolArgs as Record | undefined); + let quotaReservation: { reservationId?: string } = {}; + + if (!isFreeTool(toolName) && session.tenantId) { + const quotaResult = await reserveQuota( + env.AUTH_SERVICE, + session.tenantId, + session.userId ?? '', + toolName, + toolArgs as Record | undefined, + ); + + if (!quotaResult.allowed) { + audit({ + trace_id: traceId, + principal: session.userId ?? 'unknown', + tenant: session.tenantId ?? 'unknown', + tool: toolName, + risk_level: risk, + policy_decision: 'DENY', + redacted_input_summary: summarizeInput(toolArgs), + outcome: 'tier_denied', + timestamp: new Date().toISOString(), + }, env); + return rpcError(rpcId, JSON_RPC_INVALID_PARAMS, + `Quota exceeded for ${toolName}. ${quotaResult.error ?? 'Upgrade your plan for more credits.'}`); + } + + quotaReservation = { reservationId: quotaResult.reservationId }; + } + const result = await proxyToolCall(env, toolName, toolArgs, session, traceId); - return rpcResult(rpcId, result); + + // ─── Cost attribution: settle quota after tool call ────── + const toolSucceeded = !result.isError; + await settleQuota(env.AUTH_SERVICE, quotaReservation.reservationId, toolSucceeded); + + // Enrich audit queue event with cost data + queueAuditEvent(env.PLATFORM_EVENTS_QUEUE, { + trace_id: traceId, + principal: session.userId ?? 'unknown', + tenant: session.tenantId ?? 'unknown', + tool: toolName, + risk_level: risk, + policy_decision: 'ALLOW', + redacted_input_summary: summarizeInput(toolArgs), + outcome: toolSucceeded ? 'success' : 'error', + timestamp: new Date().toISOString(), + latency_ms: 0, // latency is tracked in proxyToolCall's own audit + }); + + // Add rate limit headers to successful responses + const response = rpcResult(rpcId, result); + const rlHeaders = rateLimitHeaders(rlResult); + for (const [k, v] of Object.entries(rlHeaders)) { + response.headers.set(k, v); + } + return response; } return rpcError(rpcId, JSON_RPC_METHOD_NOT_FOUND, `Unknown method: ${rpcMethod}`); diff --git a/src/rate-limiter.ts b/src/rate-limiter.ts new file mode 100644 index 0000000..83b7594 --- /dev/null +++ b/src/rate-limiter.ts @@ -0,0 +1,108 @@ +// ─── Rate Limiter ───────────────────────────────────────────── +// Sliding window rate limiting per API key / tenant. +// Uses KV with TTL for window expiration — no external dependencies. +// Returns 429 with Retry-After header when limit exceeded. + +import type { Tier } from './types.js'; + +export interface RateLimitConfig { + /** Max requests per window */ + limit: number; + /** Window size in seconds */ + windowSeconds: number; +} + +// Per-tier rate limits — configurable, conservative defaults +const TIER_LIMITS: Record = { + free: { limit: 20, windowSeconds: 60 }, + hobby: { limit: 60, windowSeconds: 60 }, + pro: { limit: 300, windowSeconds: 60 }, + enterprise: { limit: 1000, windowSeconds: 60 }, +}; + +export interface RateLimitResult { + allowed: boolean; + /** Requests remaining in current window */ + remaining: number; + /** Total limit for this window */ + limit: number; + /** Seconds until window resets */ + retryAfterSeconds: number; +} + +interface WindowState { + count: number; + windowStart: number; +} + +const RATE_LIMIT_PREFIX = 'rl:'; + +/** + * Check and increment rate limit for a given key (API key ID, tenant ID, etc.) + * Uses a simple fixed-window approach with KV TTL for auto-cleanup. + */ +export async function checkRateLimit( + kv: KVNamespace, + key: string, + tier: Tier, + config?: RateLimitConfig, +): Promise { + const { limit, windowSeconds } = config ?? TIER_LIMITS[tier] ?? TIER_LIMITS.free; + const now = Math.floor(Date.now() / 1000); + const windowStart = now - (now % windowSeconds); + const kvKey = `${RATE_LIMIT_PREFIX}${key}:${windowStart}`; + + // Read current window count + const raw = await kv.get(kvKey); + let state: WindowState; + + if (raw) { + state = JSON.parse(raw) as WindowState; + } else { + state = { count: 0, windowStart }; + } + + const retryAfterSeconds = windowSeconds - (now - windowStart); + + if (state.count >= limit) { + return { + allowed: false, + remaining: 0, + limit, + retryAfterSeconds, + }; + } + + // Increment + state.count += 1; + // Write back with TTL — window auto-expires + await kv.put(kvKey, JSON.stringify(state), { + expirationTtl: windowSeconds + 10, // small buffer past window end + }); + + return { + allowed: true, + remaining: limit - state.count, + limit, + retryAfterSeconds, + }; +} + +/** + * Build standard rate limit response headers. + */ +export function rateLimitHeaders(result: RateLimitResult): Record { + const headers: Record = { + 'X-RateLimit-Limit': String(result.limit), + 'X-RateLimit-Remaining': String(result.remaining), + 'X-RateLimit-Reset': String(Math.floor(Date.now() / 1000) + result.retryAfterSeconds), + }; + if (!result.allowed) { + headers['Retry-After'] = String(result.retryAfterSeconds); + } + return headers; +} + +export function getRateLimitConfig(tier: Tier): RateLimitConfig { + return TIER_LIMITS[tier] ?? TIER_LIMITS.free; +} diff --git a/src/types.ts b/src/types.ts index 5d91711..45e5dd4 100644 --- a/src/types.ts +++ b/src/types.ts @@ -54,6 +54,35 @@ export interface AuthServiceRpc { name?: string; error?: string; }>; + + // ─── Quota (cost attribution) ──────────────────────────────── + checkQuota(params: { + tenantId: string; + userId?: string; + feature: string; + amount?: number; + }): Promise<{ + allowed: boolean; + remaining: number; + limit: number; + resetsAt?: string; + }>; + consumeQuota(params: { + tenantId: string; + userId?: string; + feature: string; + amount: number; + idempotencyKey?: string; + }): Promise<{ + success: boolean; + reservationId?: string; + remaining?: number; + error?: string; + }>; + commitOrRefundQuota( + reservationId: string, + outcome: 'success' | 'failed', + ): Promise; } // ─── Backend RPC surface (what product workers expose) ──────── @@ -98,6 +127,9 @@ export interface GatewayEnv { OAUTH_PROVIDER: OAuthHelpers; OAUTH_KV: KVNamespace; + // Rate limiting + RATELIMIT_KV: KVNamespace; + // Secrets SERVICE_BINDING_SECRET: string; diff --git a/test/audit.test.ts b/test/audit.test.ts index 78173eb..fe9ab95 100644 --- a/test/audit.test.ts +++ b/test/audit.test.ts @@ -176,6 +176,9 @@ describe('gateway audit integration', () => { registerUser: async (_n: string, _e: string, _p: string) => ({ valid: false }), provisionTenant: async (_p: { userId: string; source: string }) => ({ tenantId: '', userId: '', tier: 'free', delinquent: false, createdAt: '' }), exchangeSocialCode: async (_c: string) => ({ valid: false }), + checkQuota: async () => ({ allowed: true, remaining: 100, limit: 500 }), + consumeQuota: async () => ({ success: true, reservationId: 'res-1', remaining: 99 }), + commitOrRefundQuota: async () => {}, }; } @@ -206,6 +209,16 @@ describe('gateway audit integration', () => { getWithMetadata: async () => ({ value: null, metadata: null, cacheStatus: null }), } as unknown as KVNamespace; })(), + RATELIMIT_KV: (() => { + const store = new Map(); + return { + get: async (key: string) => store.get(key) ?? null, + put: async (key: string, value: string) => { store.set(key, value); }, + delete: async (key: string) => { store.delete(key); }, + list: async () => ({ keys: [], list_complete: true, cacheStatus: null }), + getWithMetadata: async () => ({ value: null, metadata: null, cacheStatus: null }), + } as unknown as KVNamespace; + })(), PLATFORM_EVENTS_QUEUE: { send: async () => {} } as unknown as Queue, SERVICE_BINDING_SECRET: 'test-secret', API_BASE_URL: 'https://mcp.stackbilt.dev', diff --git a/test/auth.test.ts b/test/auth.test.ts index 3527cc7..027f51f 100644 --- a/test/auth.test.ts +++ b/test/auth.test.ts @@ -16,6 +16,9 @@ function mockAuthService(overrides?: Partial): AuthServiceRpc { registerUser: async () => ({ valid: false }), provisionTenant: async () => ({ tenantId: '', userId: '', tier: 'free', delinquent: false, createdAt: '' }), exchangeSocialCode: async () => ({ valid: false }), + checkQuota: async () => ({ allowed: true, remaining: 100, limit: 500 }), + consumeQuota: async () => ({ success: true, reservationId: 'res-1', remaining: 99 }), + commitOrRefundQuota: async () => {}, ...overrides, }; } diff --git a/test/cost-attribution.test.ts b/test/cost-attribution.test.ts new file mode 100644 index 0000000..b3dcea2 --- /dev/null +++ b/test/cost-attribution.test.ts @@ -0,0 +1,144 @@ +import { describe, it, expect, vi } from 'vitest'; +import { + resolveToolCost, + isFreeTool, + reserveQuota, + settleQuota, + buildCostAttribution, +} from '../src/cost-attribution.js'; +import type { AuthServiceRpc } from '../src/types.js'; + +function mockAuthService(overrides?: Partial): AuthServiceRpc { + return { + validateApiKey: async () => ({ valid: true, tenant_id: 't', tier: 'pro', scopes: [] }), + validateJwt: async () => ({ valid: true, tenant_id: 't', user_id: 'u', tier: 'pro', scopes: [] }), + authenticateUser: async () => ({ valid: false }), + registerUser: async () => ({ valid: false }), + provisionTenant: async () => ({ tenantId: '', userId: '', tier: 'free', delinquent: false, createdAt: '' }), + exchangeSocialCode: async () => ({ valid: false }), + checkQuota: async () => ({ allowed: true, remaining: 100, limit: 500 }), + consumeQuota: async () => ({ success: true, reservationId: 'res-1', remaining: 99 }), + commitOrRefundQuota: async () => {}, + ...overrides, + }; +} + +describe('resolveToolCost', () => { + it('returns base cost for known tools', () => { + const cost = resolveToolCost('scaffold_create'); + expect(cost.baseCost).toBe(2); + expect(cost.feature).toBe('mcp.scaffold_create'); + }); + + it('returns 0 cost for read-only tools', () => { + expect(resolveToolCost('scaffold_status').baseCost).toBe(0); + expect(resolveToolCost('image_list_models').baseCost).toBe(0); + expect(resolveToolCost('flow_status').baseCost).toBe(0); + }); + + it('applies quality multiplier for image_generate', () => { + const draft = resolveToolCost('image_generate', { quality_tier: 'draft' }); + const ultra = resolveToolCost('image_generate', { quality_tier: 'ultra' }); + expect(ultra.baseCost).toBeGreaterThan(draft.baseCost); + }); + + it('returns default cost for unknown tools', () => { + const cost = resolveToolCost('unknown_tool'); + expect(cost.baseCost).toBe(1); + expect(cost.feature).toBe('mcp.unknown_tool'); + }); +}); + +describe('isFreeTool', () => { + it('returns true for zero-cost tools', () => { + expect(isFreeTool('scaffold_status')).toBe(true); + expect(isFreeTool('image_list_models')).toBe(true); + }); + + it('returns false for paid tools', () => { + expect(isFreeTool('image_generate')).toBe(false); + expect(isFreeTool('scaffold_create')).toBe(false); + }); + + it('returns false for unknown tools', () => { + expect(isFreeTool('nonexistent')).toBe(false); + }); +}); + +describe('reserveQuota', () => { + it('reserves quota via auth service', async () => { + const auth = mockAuthService(); + const result = await reserveQuota(auth, 'tenant-1', 'user-1', 'scaffold_create'); + expect(result.allowed).toBe(true); + expect(result.reservationId).toBe('res-1'); + }); + + it('skips quota for free tools', async () => { + const auth = mockAuthService(); + const result = await reserveQuota(auth, 'tenant-1', 'user-1', 'scaffold_status'); + expect(result.allowed).toBe(true); + expect(result.reservationId).toBeUndefined(); + }); + + it('rejects when quota is exceeded', async () => { + const auth = mockAuthService({ + consumeQuota: async () => ({ success: false, error: 'Quota exceeded', remaining: 0 }), + }); + const result = await reserveQuota(auth, 'tenant-1', 'user-1', 'image_generate'); + expect(result.allowed).toBe(false); + expect(result.error).toContain('Quota exceeded'); + }); + + it('fails closed on auth service error for mutation tools', async () => { + const auth = mockAuthService({ + consumeQuota: async () => { throw new Error('RPC timeout'); }, + }); + const result = await reserveQuota(auth, 'tenant-1', 'user-1', 'scaffold_create'); + expect(result.allowed).toBe(false); + expect(result.error).toContain('unavailable'); + }); +}); + +describe('settleQuota', () => { + it('commits on success', async () => { + const commitFn = vi.fn(); + const auth = mockAuthService({ commitOrRefundQuota: commitFn }); + await settleQuota(auth, 'res-1', true); + expect(commitFn).toHaveBeenCalledWith('res-1', 'success'); + }); + + it('refunds on failure', async () => { + const commitFn = vi.fn(); + const auth = mockAuthService({ commitOrRefundQuota: commitFn }); + await settleQuota(auth, 'res-1', false); + expect(commitFn).toHaveBeenCalledWith('res-1', 'failed'); + }); + + it('skips settlement when no reservation ID', async () => { + const commitFn = vi.fn(); + const auth = mockAuthService({ commitOrRefundQuota: commitFn }); + await settleQuota(auth, undefined, true); + expect(commitFn).not.toHaveBeenCalled(); + }); + + it('does not throw on settlement failure', async () => { + const auth = mockAuthService({ + commitOrRefundQuota: async () => { throw new Error('RPC error'); }, + }); + // Should not throw + await expect(settleQuota(auth, 'res-1', true)).resolves.toBeUndefined(); + }); +}); + +describe('buildCostAttribution', () => { + it('returns feature and cost for known tools', () => { + const attr = buildCostAttribution('image_generate', { quality_tier: 'premium' }); + expect(attr.feature).toBe('mcp.image_generate'); + expect(attr.creditCost).toBe(15); // 5 * 3 (premium multiplier) + }); + + it('returns base cost for tools without args', () => { + const attr = buildCostAttribution('scaffold_create'); + expect(attr.creditCost).toBe(2); + }); +}); diff --git a/test/gateway.test.ts b/test/gateway.test.ts index 9161f0d..3d35fd3 100644 --- a/test/gateway.test.ts +++ b/test/gateway.test.ts @@ -11,6 +11,9 @@ function mockAuthService(tier: string = 'pro'): AuthServiceRpc { registerUser: async () => ({ valid: false }), provisionTenant: async () => ({ tenantId: '', userId: '', tier: 'free', delinquent: false, createdAt: '' }), exchangeSocialCode: async () => ({ valid: false }), + checkQuota: async () => ({ allowed: true, remaining: 100, limit: 500 }), + consumeQuota: async () => ({ success: true, reservationId: 'res-1', remaining: 99 }), + commitOrRefundQuota: async () => {}, }; } @@ -42,6 +45,7 @@ function makeEnv(overrides?: Partial): GatewayEnv { IMG_FORGE: mockFetcher({ jsonrpc: '2.0', id: 1, result: { content: [{ type: 'text', text: 'image generated' }] } }), OAUTH_PROVIDER: {} as any, OAUTH_KV: mockKV(), + RATELIMIT_KV: mockKV(), PLATFORM_EVENTS_QUEUE: { send: async () => {} } as unknown as Queue, SERVICE_BINDING_SECRET: 'test-secret', API_BASE_URL: 'https://mcp.stackbilt.dev', diff --git a/test/oauth-handler.test.ts b/test/oauth-handler.test.ts index 32c6b9d..9ce1959 100644 --- a/test/oauth-handler.test.ts +++ b/test/oauth-handler.test.ts @@ -83,6 +83,9 @@ function mockAuthService(overrides?: Partial) { email: 'social@example.com', name: 'Social User', })), + checkQuota: vi.fn(async () => ({ allowed: true, remaining: 100, limit: 500 })), + consumeQuota: vi.fn(async () => ({ success: true, reservationId: 'res-1', remaining: 99 })), + commitOrRefundQuota: vi.fn(async () => {}), ...overrides, }; } @@ -94,6 +97,7 @@ function makeEnv(overrides?: Partial): GatewayEnv { IMG_FORGE: {} as Fetcher, OAUTH_PROVIDER: mockOAuthProvider() as unknown as GatewayEnv['OAUTH_PROVIDER'], OAUTH_KV: mockKV(), + RATELIMIT_KV: mockKV(), PLATFORM_EVENTS_QUEUE: { send: async () => {} } as unknown as Queue, SERVICE_BINDING_SECRET: TEST_SECRET, API_BASE_URL: TEST_API_BASE_URL, diff --git a/test/rate-limiter.test.ts b/test/rate-limiter.test.ts new file mode 100644 index 0000000..743d19a --- /dev/null +++ b/test/rate-limiter.test.ts @@ -0,0 +1,90 @@ +import { describe, it, expect } from 'vitest'; +import { checkRateLimit, rateLimitHeaders, getRateLimitConfig } from '../src/rate-limiter.js'; + +function mockKV(): KVNamespace { + const store = new Map(); + return { + get: async (key: string) => store.get(key) ?? null, + put: async (key: string, value: string) => { store.set(key, value); }, + delete: async (key: string) => { store.delete(key); }, + list: async () => ({ keys: [], list_complete: true, cacheStatus: null }), + getWithMetadata: async () => ({ value: null, metadata: null, cacheStatus: null }), + } as unknown as KVNamespace; +} + +describe('checkRateLimit', () => { + it('allows first request', async () => { + const kv = mockKV(); + const result = await checkRateLimit(kv, 'tenant-1', 'free'); + expect(result.allowed).toBe(true); + expect(result.remaining).toBeGreaterThan(0); + }); + + it('decrements remaining with each request', async () => { + const kv = mockKV(); + const r1 = await checkRateLimit(kv, 'tenant-1', 'free'); + const r2 = await checkRateLimit(kv, 'tenant-1', 'free'); + expect(r2.remaining).toBe(r1.remaining - 1); + }); + + it('rejects when limit is exhausted', async () => { + const kv = mockKV(); + // Use a tiny limit + const config = { limit: 2, windowSeconds: 60 }; + await checkRateLimit(kv, 'tenant-1', 'free', config); + await checkRateLimit(kv, 'tenant-1', 'free', config); + const result = await checkRateLimit(kv, 'tenant-1', 'free', config); + expect(result.allowed).toBe(false); + expect(result.remaining).toBe(0); + }); + + it('tracks separate keys independently', async () => { + const kv = mockKV(); + const config = { limit: 1, windowSeconds: 60 }; + await checkRateLimit(kv, 'tenant-1', 'free', config); + const result = await checkRateLimit(kv, 'tenant-2', 'free', config); + expect(result.allowed).toBe(true); + }); + + it('higher tiers get higher limits', () => { + const free = getRateLimitConfig('free'); + const pro = getRateLimitConfig('pro'); + const enterprise = getRateLimitConfig('enterprise'); + expect(pro.limit).toBeGreaterThan(free.limit); + expect(enterprise.limit).toBeGreaterThan(pro.limit); + }); +}); + +describe('rateLimitHeaders', () => { + it('includes standard rate limit headers', () => { + const headers = rateLimitHeaders({ + allowed: true, + remaining: 42, + limit: 60, + retryAfterSeconds: 30, + }); + expect(headers['X-RateLimit-Limit']).toBe('60'); + expect(headers['X-RateLimit-Remaining']).toBe('42'); + expect(headers['X-RateLimit-Reset']).toBeDefined(); + }); + + it('adds Retry-After when not allowed', () => { + const headers = rateLimitHeaders({ + allowed: false, + remaining: 0, + limit: 60, + retryAfterSeconds: 45, + }); + expect(headers['Retry-After']).toBe('45'); + }); + + it('omits Retry-After when allowed', () => { + const headers = rateLimitHeaders({ + allowed: true, + remaining: 10, + limit: 60, + retryAfterSeconds: 30, + }); + expect(headers['Retry-After']).toBeUndefined(); + }); +}); diff --git a/wrangler.toml b/wrangler.toml index 5a0b71f..1e74033 100644 --- a/wrangler.toml +++ b/wrangler.toml @@ -42,6 +42,11 @@ service = "n8n-transpiler" binding = "OAUTH_KV" id = "9c165be8754749e3b543458ae8e596db" +# Rate limiting — sliding window counters per API key / tenant +[[kv_namespaces]] +binding = "RATELIMIT_KV" +id = "240065d87b05466ab7b5527e3552817b" + # Custom domain — taken over from img-forge-mcp (ADR-039) [[routes]] pattern = "mcp.stackbilt.dev"