Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 178 additions & 0 deletions src/cost-attribution.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
// ─── Cost Attribution ─────────────────────────────────────────
// Maps tool calls to credit costs and enforces quotas via edge-auth.
// Every priced tool call is metered: quota is reserved before the call, then committed or refunded once the outcome is known.
// Cost data flows to the audit pipeline for billing dashboards.

import type { Tier, AuthServiceRpc } from './types.js';
import type { AuditArtifact } from './audit.js';

// ─── Credit cost per tool call ──────────────────────────────────
// Costs are in "credits" — 1 credit = 1 unit of the tier's allocation.
// Expensive operations (image gen, deploy) cost more.
/**
 * Price tag attached to a tool call: how many credits it costs and which
 * quota feature the charge is booked against.
 */
export interface ToolCost {
  /** Credits charged for the call. */
  baseCost: number;
  /** Quota feature key used for tracking (maps to edge-auth quota.feature). */
  feature: string;
}

/**
 * Static price list keyed by tool name. Each entry carries the base credit
 * cost and the quota feature key (`mcp.<tool name>`). A base cost of 0 marks
 * the tool as free.
 */
const TOOL_COSTS: Record<string, ToolCost> = {
  // ── img-forge ── image_generate's base cost is scaled by quality tier at call time
  image_generate: { baseCost: 5, feature: 'mcp.image_generate' },
  image_list_models: { baseCost: 0, feature: 'mcp.image_list_models' },
  image_check_job: { baseCost: 0, feature: 'mcp.image_check_job' },

  // ── TarotScript scaffold ──
  scaffold_create: { baseCost: 2, feature: 'mcp.scaffold_create' },
  scaffold_classify: { baseCost: 0, feature: 'mcp.scaffold_classify' },
  scaffold_status: { baseCost: 0, feature: 'mcp.scaffold_status' },
  scaffold_publish: { baseCost: 3, feature: 'mcp.scaffold_publish' },
  scaffold_deploy: { baseCost: 5, feature: 'mcp.scaffold_deploy' },
  scaffold_import: { baseCost: 1, feature: 'mcp.scaffold_import' },

  // ── Flow tools ──
  flow_create: { baseCost: 2, feature: 'mcp.flow_create' },
  flow_status: { baseCost: 0, feature: 'mcp.flow_status' },
  flow_summary: { baseCost: 0, feature: 'mcp.flow_summary' },
  flow_quality: { baseCost: 0, feature: 'mcp.flow_quality' },
  flow_governance: { baseCost: 0, feature: 'mcp.flow_governance' },

  // ── Visual QA ──
  visual_screenshot: { baseCost: 1, feature: 'mcp.visual_screenshot' },
  visual_analyze: { baseCost: 2, feature: 'mcp.visual_analyze' },
  visual_pages: { baseCost: 0, feature: 'mcp.visual_pages' },
};

/**
 * Factor applied to the base cost of `image_generate`, keyed by the
 * `quality_tier` argument of the call.
 */
const IMAGE_QUALITY_MULTIPLIER: Record<string, number> = {
  // Same price as the base cost
  draft: 1,
  standard: 1,
  // Progressively more expensive tiers
  premium: 3,
  ultra: 5,
  ultra_plus: 8,
};

/**
 * Resolve the credit cost for a tool call, factoring in quality tier for images.
 *
 * @param toolName - Name of the tool being invoked.
 * @param args - Tool arguments. Only `quality_tier` is read, and only for `image_generate`.
 * @returns The cost entry for the tool. Tools missing from the cost table are
 *   charged 1 credit under the feature `mcp.<toolName>`. For `image_generate`
 *   with a `quality_tier`, the base cost is multiplied by the tier's multiplier;
 *   a tier that is not a string or not a known tier uses a multiplier of 1.
 */
export function resolveToolCost(
  toolName: string,
  args?: Record<string, unknown>,
): ToolCost {
  // Both tables are plain objects indexed with caller-supplied strings. A bare
  // index would also resolve inherited keys ("constructor", "toString",
  // "__proto__") to Object.prototype members, yielding a bogus cost entry or a
  // NaN amount — so only own properties are accepted.
  const hasOwn = (table: object, key: string): boolean =>
    Object.prototype.hasOwnProperty.call(table, key);

  const base = hasOwn(TOOL_COSTS, toolName) ? TOOL_COSTS[toolName] : undefined;
  if (!base) {
    // Unknown tools cost 1 credit by default (conservative)
    return { baseCost: 1, feature: `mcp.${toolName}` };
  }

  // Apply quality multiplier for image_generate
  if (toolName === 'image_generate' && args?.quality_tier) {
    const tier = args.quality_tier;
    const known = typeof tier === 'string' && hasOwn(IMAGE_QUALITY_MULTIPLIER, tier);
    // Non-string or unrecognised tiers are billed at the base rate.
    const multiplier = (known ? IMAGE_QUALITY_MULTIPLIER[tier as string] : undefined) ?? 1;
    return { ...base, baseCost: base.baseCost * multiplier };
  }

  return base;
}

/**
 * Tell whether a tool is listed in the cost table with a base cost of zero.
 * Tools that are not in the table are never reported as free.
 * Free calls skip quota enforcement.
 */
export function isFreeTool(toolName: string): boolean {
  return TOOL_COSTS[toolName]?.baseCost === 0;
}

/**
 * Outcome of a quota check performed before a tool call.
 */
export interface QuotaCheckResult {
  /** Whether the tool call may proceed. */
  allowed: boolean;
  /** Reservation handle reported by the quota service, when one was issued. */
  reservationId?: string;
  /** Remaining quota as reported by the quota service, when available. */
  remaining?: number;
  /** Reason for the denial when `allowed` is false. */
  error?: string;
}

/**
 * Check and reserve quota for a tool call via edge-auth RPC.
 * Returns a reservation ID that must be committed or refunded after the call.
 *
 * @param authService - RPC client used to consume quota.
 * @param tenantId - Tenant the charge is booked against.
 * @param userId - User issuing the tool call.
 * @param toolName - Name of the tool being invoked.
 * @param args - Tool arguments, forwarded to cost resolution.
 * @returns `allowed: true` without a reservation for free tools; the quota
 *   service's verdict for priced tools; `allowed: false` when the computed
 *   cost is invalid or the quota service cannot be reached.
 */
export async function reserveQuota(
  authService: AuthServiceRpc,
  tenantId: string,
  userId: string,
  toolName: string,
  args?: Record<string, unknown>,
): Promise<QuotaCheckResult> {
  const cost = resolveToolCost(toolName, args);

  // Free tools don't consume quota, so the quota service is never contacted.
  if (cost.baseCost === 0) {
    return { allowed: true };
  }

  // Never forward a NaN, infinite, or negative amount to the billing backend.
  if (!Number.isFinite(cost.baseCost) || cost.baseCost < 0) {
    console.error(`[cost] Invalid credit cost resolved for ${toolName}`);
    return { allowed: false, error: 'Invalid tool cost' };
  }

  try {
    const result = await authService.consumeQuota({
      tenantId,
      userId,
      feature: cost.feature,
      amount: cost.baseCost,
    });

    if (!result.success) {
      return {
        allowed: false,
        error: result.error ?? 'Quota exceeded',
        remaining: result.remaining,
      };
    }

    return {
      allowed: true,
      reservationId: result.reservationId,
      remaining: result.remaining,
    };
  } catch (err: unknown) {
    // Only priced calls reach this point (free ones returned above), so the
    // call is always refused when the quota service is unavailable.
    const detail = err instanceof Error ? err.message : String(err);
    console.error(`[cost] Quota check failed for ${cost.feature}: ${detail}`);
    return {
      allowed: false,
      error: 'Quota service unavailable',
    };
  }
}

/**
 * Finalize a quota reservation once the tool call outcome is known.
 *
 * @param authService - RPC client used to commit or refund the reservation.
 * @param reservationId - Reservation to settle; nothing happens when it is missing or empty.
 * @param success - `true` settles the reservation as 'success', `false` as 'failed'.
 * @returns Resolves once settlement was attempted; never rejects on RPC failure.
 */
export async function settleQuota(
  authService: AuthServiceRpc,
  reservationId: string | undefined,
  success: boolean,
): Promise<void> {
  // No reservation was issued, so there is nothing to settle.
  if (reservationId) {
    const outcome = success ? 'success' : 'failed';
    try {
      await authService.commitOrRefundQuota(reservationId, outcome);
    } catch {
      // Settlement is best-effort: a failure here must not fail the tool call.
      // The reservation will auto-expire in edge-auth.
      console.error(`[cost] Failed to settle reservation ${reservationId}`);
    }
  }
}

/**
 * Produce the cost fields recorded with a tool call for auditing.
 *
 * @param toolName - Name of the tool being invoked.
 * @param args - Tool arguments, forwarded to cost resolution.
 * @returns The quota feature key and the resolved credit cost of the call.
 */
export function buildCostAttribution(
  toolName: string,
  args?: Record<string, unknown>,
): { feature: string; creditCost: number } {
  const { feature, baseCost: creditCost } = resolveToolCost(toolName, args);
  return { feature, creditCost };
}
113 changes: 111 additions & 2 deletions src/gateway.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import { materializeScaffold } from './scaffold-materializer.js';
import { publishToGitHub } from './scaffold-publish.js';
import { classifyIntention, type IntentClassification } from './intent-classifier.js';
import { logDivergence } from './divergence-logger.js';
import { checkRateLimit, rateLimitHeaders, type RateLimitResult } from './rate-limiter.js';
import { reserveQuota, settleQuota, buildCostAttribution, isFreeTool } from './cost-attribution.js';

const MCP_PROTOCOL_VERSION = '2025-03-26';
const JSON_RPC_PARSE_ERROR = -32700;
Expand Down Expand Up @@ -1010,6 +1012,28 @@ async function handlePost(request: Request, env: GatewayEnv, oauthProps?: OAuthP
);
}

// Rate limiting — check before processing
const rateLimitKey = authResult.tenantId ?? authResult.userId ?? 'unknown';
const rlResult = await checkRateLimit(env.RATELIMIT_KV, rateLimitKey, authResult.tier);
if (!rlResult.allowed) {
audit({
trace_id: generateTraceId(),
principal: authResult.userId ?? 'unknown',
tenant: authResult.tenantId ?? 'unknown',
tool: 'rate_limit',
risk_level: 'UNKNOWN',
policy_decision: 'DENY',
redacted_input_summary: '{}',
outcome: 'auth_denied',
timestamp: new Date().toISOString(),
}, env);
return jsonResponse(
{ error: 'Rate limit exceeded', code: 'RATE_LIMITED' },
429,
rateLimitHeaders(rlResult),
);
}

// Validate Accept header
const accept = request.headers.get('Accept') ?? '';
if (!accept.includes('application/json') && !accept.includes('*/*') && accept !== '') {
Expand Down Expand Up @@ -1065,7 +1089,18 @@ async function handlePost(request: Request, env: GatewayEnv, oauthProps?: OAuthP
// ─── tools/list ─────────────────────────────────────────
if (rpcMethod === 'tools/list') {
// KV handles session expiration via expirationTtl — no manual pruning needed
const tools = buildAggregatedCatalog();
let tools = buildAggregatedCatalog();

// Scope-based filtering: only show tools the session has access to
const hasGenerate = session.scopes.includes('generate');
if (!hasGenerate) {
// Read-only scope — filter out mutation tools
tools = tools.filter(t => {
const risk = getToolRiskLevel(t.name);
return risk === 'READ_ONLY';
});
}

return rpcResult(rpcId, { tools });
}

Expand Down Expand Up @@ -1108,6 +1143,23 @@ async function handlePost(request: Request, env: GatewayEnv, oauthProps?: OAuthP
return rpcError(rpcId, JSON_RPC_METHOD_NOT_FOUND, `Unknown tool: ${toolName}`);
}

// Scope enforcement: mutation tools require 'generate' scope
if (risk !== 'READ_ONLY' && !session.scopes.includes('generate')) {
audit({
trace_id: traceId,
principal: session.userId ?? 'unknown',
tenant: session.tenantId ?? 'unknown',
tool: toolName,
risk_level: risk,
policy_decision: 'DENY',
redacted_input_summary: summarizeInput(toolArgs),
outcome: 'auth_denied',
timestamp: new Date().toISOString(),
}, env);
return rpcError(rpcId, JSON_RPC_INVALID_PARAMS,
`Tool "${toolName}" requires the "generate" scope. Your API key only has: ${session.scopes.join(', ')}`);
}

// Validate arguments are object-shaped
const argValidation = validateToolArguments(toolArgs, { type: 'object' });
if (!argValidation.valid) {
Expand Down Expand Up @@ -1142,8 +1194,65 @@ async function handlePost(request: Request, env: GatewayEnv, oauthProps?: OAuthP
return rpcError(rpcId, JSON_RPC_INVALID_PARAMS, tierDenied);
}

// ─── Cost attribution: reserve quota before tool call ────
const costInfo = buildCostAttribution(toolName, toolArgs as Record<string, unknown> | undefined);
let quotaReservation: { reservationId?: string } = {};

if (!isFreeTool(toolName) && session.tenantId) {
const quotaResult = await reserveQuota(
env.AUTH_SERVICE,
session.tenantId,
session.userId ?? '',
toolName,
toolArgs as Record<string, unknown> | undefined,
);

if (!quotaResult.allowed) {
audit({
trace_id: traceId,
principal: session.userId ?? 'unknown',
tenant: session.tenantId ?? 'unknown',
tool: toolName,
risk_level: risk,
policy_decision: 'DENY',
redacted_input_summary: summarizeInput(toolArgs),
outcome: 'tier_denied',
timestamp: new Date().toISOString(),
}, env);
return rpcError(rpcId, JSON_RPC_INVALID_PARAMS,
`Quota exceeded for ${toolName}. ${quotaResult.error ?? 'Upgrade your plan for more credits.'}`);
}

quotaReservation = { reservationId: quotaResult.reservationId };
}

const result = await proxyToolCall(env, toolName, toolArgs, session, traceId);
return rpcResult(rpcId, result);

// ─── Cost attribution: settle quota after tool call ──────
const toolSucceeded = !result.isError;
await settleQuota(env.AUTH_SERVICE, quotaReservation.reservationId, toolSucceeded);

// Enrich audit queue event with cost data
queueAuditEvent(env.PLATFORM_EVENTS_QUEUE, {
trace_id: traceId,
principal: session.userId ?? 'unknown',
tenant: session.tenantId ?? 'unknown',
tool: toolName,
risk_level: risk,
policy_decision: 'ALLOW',
redacted_input_summary: summarizeInput(toolArgs),
outcome: toolSucceeded ? 'success' : 'error',
timestamp: new Date().toISOString(),
latency_ms: 0, // latency is tracked in proxyToolCall's own audit
});

// Add rate limit headers to successful responses
const response = rpcResult(rpcId, result);
const rlHeaders = rateLimitHeaders(rlResult);
for (const [k, v] of Object.entries(rlHeaders)) {
response.headers.set(k, v);
}
return response;
}

return rpcError(rpcId, JSON_RPC_METHOD_NOT_FOUND, `Unknown method: ${rpcMethod}`);
Expand Down
Loading