From c73aced270308539c408f79f9df890492909a5a1 Mon Sep 17 00:00:00 2001 From: sid mohan Date: Fri, 27 Feb 2026 17:33:11 -0800 Subject: [PATCH] feat(admin): add optional admin dashboard and receipts listing --- README.md | 10 + cmd/datafog-api/main.go | 35 ++- docs/FRONTEND.md | 1 + docs/admin.html | 354 +++++++++++++++++++++++ docs/contracts/datafog-api-contract.md | 43 +++ go.mod | 3 +- go.sum | 2 + internal/policy/policy.go | 62 +++- internal/receipts/store.go | 341 +++++++++++++++++++--- internal/receipts/store_test.go | 148 ++++++++++ internal/scan/detector.go | 100 +++++-- internal/scan/ner.go | 95 +++++- internal/server/admin.go | 35 +++ internal/server/metrics.go | 166 +++++++++++ internal/server/server.go | 352 +++++++++++++--------- internal/server/server_benchmark_test.go | 3 + internal/server/server_test.go | 90 ++++++ internal/shim/events.go | 152 ++++++++-- internal/shim/events_test.go | 75 +++++ internal/types/ttlcache/ttlcache.go | 202 +++++++++++++ 20 files changed, 2045 insertions(+), 224 deletions(-) create mode 100644 docs/admin.html create mode 100644 internal/server/admin.go create mode 100644 internal/server/metrics.go create mode 100644 internal/shim/events_test.go create mode 100644 internal/types/ttlcache/ttlcache.go diff --git a/README.md b/README.md index e8156ac..281974d 100644 --- a/README.md +++ b/README.md @@ -103,6 +103,8 @@ If you set `DATAFOG_API_TOKEN`, send it on every request using: | `DATAFOG_FGPROF` | `false` | Add `/debug/fgprof` endpoint to the profiling server | | `DATAFOG_ENABLE_DEMO` | *(unset)* | Enable `/demo*` endpoints | | `DATAFOG_DEMO_HTML` | `docs/demo.html` | Path to demo HTML | +| `DATAFOG_ENABLE_ADMIN_UI` | *(unset)* | Enable read-only `GET /admin` | +| `DATAFOG_ADMIN_HTML` | `docs/admin.html` | Path to static admin dashboard HTML | Duration values use Go duration syntax, for example `1s`, `500ms`, `2m`. @@ -118,8 +120,10 @@ Base URL defaults to `http://localhost:8080`. 
| `POST` | `/v1/decide` | Evaluate an action + findings and get a decision | | `POST` | `/v1/transform` | Apply requested transform mode(s) | | `POST` | `/v1/anonymize` | Apply irreversible anonymization | +| `GET` | `/v1/receipts` | List recent decision receipts | | `GET` | `/v1/receipts/{id}` | Read a decision receipt | | `GET` | `/v1/events` | List recent decision events | +| `GET` | `/admin` | Read-only operational dashboard (requires DATAFOG_ENABLE_ADMIN_UI) | | `GET` | `/metrics` | In-process metrics counters | Optional demo routes (only when demo mode is enabled): @@ -192,6 +196,12 @@ curl -X POST http://localhost:8080/v1/transform \ }' ``` +### List receipts (admin/read-only) + +```sh +curl 'http://localhost:8080/v1/receipts?limit=20&decision=allow' +``` + ### Fetch a receipt ```sh diff --git a/cmd/datafog-api/main.go b/cmd/datafog-api/main.go index a441030..3c2ffdf 100644 --- a/cmd/datafog-api/main.go +++ b/cmd/datafog-api/main.go @@ -33,6 +33,7 @@ func main() { rateLimitRPS := getenvInt("DATAFOG_RATE_LIMIT_RPS", 0) shutdownTimeout := getenvDuration("DATAFOG_SHUTDOWN_TIMEOUT", 10*time.Second) enableDemo := getenv("DATAFOG_ENABLE_DEMO", "") != "" || hasFlag("--enable-demo") + enableAdmin := getenvBool("DATAFOG_ENABLE_ADMIN_UI", false) || hasFlag("--enable-admin-ui") eventsPath := getenv("DATAFOG_EVENTS_PATH", "") pprofAddr := getenv("DATAFOG_PPROF_ADDR", "") fgprofEnabled := getenvBool("DATAFOG_FGPROF", false) @@ -60,25 +61,48 @@ func main() { h.SetEventReader(eventReader) } - var handler http.Handler + var demo *server.DemoHandler if enableDemo { // Create a shim gate backed by a local HTTP decision client client := shim.NewHTTPDecisionClient("http://127.0.0.1"+addr, apiToken) gate := shim.NewGate(client, shim.WithEventSink(eventSink)) demoHTMLPath := getenv("DATAFOG_DEMO_HTML", "docs/demo.html") - demo, err := server.NewDemoHandler(gate, h, demoHTMLPath) + demo, err = server.NewDemoHandler(gate, h, demoHTMLPath) if err != nil { log.Fatalf("init demo: %v", 
err) } defer demo.Cleanup() + } + var admin *server.AdminHandler + if enableAdmin { + adminHTMLPath := getenv("DATAFOG_ADMIN_HTML", "docs/admin.html") + admin, err = server.NewAdminHandler(adminHTMLPath) + if err != nil { + log.Fatalf("init admin UI: %v", err) + } + } + + var handler http.Handler + switch { + case demo != nil && admin != nil: + handler = h.HandlerWithDemoAndAdmin(demo, admin) + case demo != nil: handler = h.HandlerWithDemo(demo) - log.Printf("demo mode enabled — /demo/exec, /demo/write-file, /demo/read-file available") - } else { + case admin != nil: + handler = h.HandlerWithAdmin(admin) + default: handler = h.Handler() } + if enableDemo { + log.Printf("demo mode enabled — /demo/exec, /demo/write-file, /demo/read-file available") + } + if enableAdmin { + log.Printf("admin UI enabled — /admin available") + } + var pprofSrv *http.Server if pprofAddr != "" { pprofSrv = startProfilingServer(pprofAddr, fgprofEnabled, log.Default()) @@ -129,6 +153,9 @@ func main() { log.Printf("forced close failed: %v", closeErr) } } + if err := h.Shutdown(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.Printf("application shutdown helper failed: %v", err) + } if err := <-done; err != nil && !errors.Is(err, http.ErrServerClosed) { log.Printf("server stopped with error: %v", err) } diff --git a/docs/FRONTEND.md b/docs/FRONTEND.md index 7a649b7..b47564f 100644 --- a/docs/FRONTEND.md +++ b/docs/FRONTEND.md @@ -9,6 +9,7 @@ DataFog API is a backend-first project. There is no React/Vue/Next.js applicatio - Primary user-facing UI is API-first: clients interact through HTTP endpoints. - Optional demo assets are static HTML in `docs/demo.html` and rendered by `GET /demo` when demo mode is enabled. +- Optional admin dashboard is static HTML in `docs/admin.html` and rendered by `GET /admin` when admin UI mode is enabled. 
## Conventions diff --git a/docs/admin.html b/docs/admin.html new file mode 100644 index 0000000..309f8bd --- /dev/null +++ b/docs/admin.html @@ -0,0 +1,354 @@ + + + + + + DataFog Admin + + + +

DataFog Admin

+

Read-only operational dashboard for inspecting decisions and running policy checks.

+ +
+
+

Connection

+
+
Unknown
+ +
+ + +
+ + +
+

Token stored only in browser localStorage for this page.

+
+
+ +
+

Service Metrics Snapshot

+
Loading...
+
+ +
+

Scan / Decide

+ + + + + + + + +
No payload yet.
+
+ +
+

Receipt lookup

+ +
+ +
+ +
No payload yet.
+
+ +
+

Recent Events

+
+ + +
+ +
No payload yet.
+
+ +
+

Recent Receipts

+
+ + +
+ +
No payload yet.
+
+
+ + + + diff --git a/docs/contracts/datafog-api-contract.md b/docs/contracts/datafog-api-contract.md index 1543e09..461cf1b 100644 --- a/docs/contracts/datafog-api-contract.md +++ b/docs/contracts/datafog-api-contract.md @@ -337,6 +337,36 @@ Returns persisted decision receipts. } ``` +### `GET /v1/receipts` + +Returns a bounded list of recent receipts. + +Query params: + +- `limit` (`1..1000`, default `100`) +- `after` (RFC3339 timestamp, strictly greater-than) +- `before` (RFC3339 timestamp, strictly less-than) +- `decision` (`allow|allow_with_redaction|transform|deny`) +- `action_type` (match against `action.type`) + +Results are sorted newest-first by `timestamp`. + +```json +{ + "receipts": [ + { + "receipt_id": "string", + "timestamp": "RFC3339 timestamp", + "decision": "allow|allow_with_redaction|transform|deny", + "action": { + "type": "string" + } + } + ], + "total": 1 +} +``` + ### `GET /v1/events` Returns decision events when `DATAFOG_EVENTS_PATH` is configured (or another reader is set). @@ -377,6 +407,19 @@ Query params: If no events are configured or none match, return `{"events":[],"total":0}`. +## Optional admin dashboard (v1) + +The following route is available when `DATAFOG_ENABLE_ADMIN_UI` or `--enable-admin-ui` is set: + +- `GET /admin` + +It serves a read-only static dashboard for operations and diagnostics. If enabled, a lightweight HTML asset is served from: + +- `docs/admin.html` by default, +- or the path specified by `DATAFOG_ADMIN_HTML`. + +Supported auth follows normal API middleware behavior; requesters may use the same token and request ID model as all other routes. + ## Idempotency - Supported endpoints: `POST /v1/scan`, `POST /v1/decide`, `POST /v1/transform`, `POST /v1/anonymize`. 
diff --git a/go.mod b/go.mod index 2028ce1..fc00857 100644 --- a/go.mod +++ b/go.mod @@ -1,11 +1,12 @@ module github.com/datafog/datafog-api -go 1.22 +go 1.24.0 require ( github.com/felixge/fgprof v0.9.5 go.uber.org/automaxprocs v1.6.0 go.uber.org/goleak v1.3.0 + golang.org/x/sync v0.19.0 ) require github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7 // indirect diff --git a/go.sum b/go.sum index e80acc3..6bf2fac 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,8 @@ go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/policy/policy.go b/internal/policy/policy.go index c50f4c6..9da5cf4 100644 --- a/internal/policy/policy.go +++ b/internal/policy/policy.go @@ -131,7 +131,7 @@ func ValidatePolicy(policy models.Policy) error { if len(errors) == 0 { return nil } - return fmt.Errorf(strings.Join(errors, "; ")) + return fmt.Errorf("%s", strings.Join(errors, "; ")) } var allowedModes = map[models.TransformMode]struct{}{ @@ -164,6 +164,58 @@ func Evaluate(policy models.Policy, ctx DecisionContext) DecisionResult { // EvaluateSorted evaluates policy decisions assuming rules are already sorted // by priority descending. 
+type PolicyIndex struct { + byAction map[string][]models.Rule +} + +// BuildPolicyIndex precomputes policy lookup tables for faster action-based filtering. +// Policy rules are expected to already be in evaluation order. +func BuildPolicyIndex(policy models.Policy) *PolicyIndex { + index := &PolicyIndex{byAction: make(map[string][]models.Rule, len(policy.Rules))} + for _, rule := range policy.Rules { + if len(rule.Match.ActionTypes) == 0 { + index.byAction["*"] = append(index.byAction["*"], rule) + continue + } + for _, actionType := range rule.Match.ActionTypes { + normalized := strings.ToLower(strings.TrimSpace(actionType)) + if normalized == "" { + continue + } + index.byAction[normalized] = append(index.byAction[normalized], rule) + } + } + return index +} + +func (idx *PolicyIndex) rulesForAction(actionType string) []models.Rule { + if idx == nil { + return nil + } + + normalizedActionType := strings.ToLower(strings.TrimSpace(actionType)) + actionRules := idx.byAction[normalizedActionType] + wildcard := idx.byAction["*"] + matched := make([]models.Rule, 0, len(actionRules)+len(wildcard)) + matched = append(matched, actionRules...) + matched = append(matched, wildcard...) + + if len(matched) == 0 { + return nil + } + + seen := make(map[string]struct{}, len(matched)) + ordered := make([]models.Rule, 0, len(matched)) + for _, rule := range matched { + if _, ok := seen[rule.ID]; ok { + continue + } + seen[rule.ID] = struct{}{} + ordered = append(ordered, rule) + } + return ordered +} + type DecisionContext struct { Action models.ActionMeta Findings []models.ScanFinding @@ -177,6 +229,11 @@ type DecisionResult struct { } func EvaluateSorted(policy models.Policy, ctx DecisionContext) DecisionResult { + return EvaluateWithIndex(policy, nil, ctx) +} + +// EvaluateWithIndex is the hot-path implementation that can use a precomputed policy index. 
+func EvaluateWithIndex(policy models.Policy, index *PolicyIndex, ctx DecisionContext) DecisionResult { if ctx.Action.Type == "" { return DecisionResult{ Decision: models.DecisionDeny, @@ -192,6 +249,9 @@ func EvaluateSorted(policy models.Policy, ctx DecisionContext) DecisionResult { } rules := policy.Rules + if index != nil { + rules = index.rulesForAction(ctx.Action.Type) + } hasFindings := map[string]struct{}{} for _, f := range ctx.Findings { diff --git a/internal/receipts/store.go b/internal/receipts/store.go index d8e453b..962ed00 100644 --- a/internal/receipts/store.go +++ b/internal/receipts/store.go @@ -5,9 +5,11 @@ import ( "crypto/rand" "encoding/hex" "encoding/json" + "errors" "fmt" "os" "path/filepath" + "sort" "strings" "sync" "time" @@ -16,18 +18,54 @@ import ( "github.com/datafog/datafog-api/internal/policy" ) -const maxReceiptLineBytes = 1024 * 1024 -const defaultReceiptFileMode = 0o600 -const defaultReceiptDirMode = 0o750 +const ( + maxReceiptLineBytes = 1024 * 1024 + + defaultReceiptFileMode = 0o600 + defaultReceiptDirMode = 0o750 + + defaultReceiptWriteQueueSize = 1024 + defaultReceiptQueueTimeout = 5 * time.Second + defaultReceiptFlushInterval = 250 * time.Millisecond + defaultReceiptBatchWrites = 32 +) type ReceiptStore struct { - mu sync.RWMutex - filePath string - receipts map[string]models.Receipt - maxEntries int - entryCount int + mu sync.RWMutex + filePath string + receipts map[string]models.Receipt + maxEntries int + entryCount int + writeQueue chan queuedReceipt + flushInterval time.Duration + batchWrites int + queueTimeout time.Duration + writeDelay time.Duration + closed bool + + closeCh chan struct{} + closedCh chan struct{} + workerWg sync.WaitGroup + writer *bufio.Writer + file *os.File +} + +type queuedReceipt struct { + receipt models.Receipt + rotate bool +} + +type ListQuery struct { + Limit int + Decision string + ActionType string + After *time.Time + Before *time.Time } +var errStoreClosed = errors.New("receipt store is 
closed") +var errReceiptWriteQueueFull = errors.New("receipt write queue is full") + // MaxEntries sets the maximum number of receipts before rotation. // 0 means no limit (default). func MaxEntries(n int) func(*ReceiptStore) { @@ -36,6 +74,31 @@ func MaxEntries(n int) func(*ReceiptStore) { } } +// MaxWriteQueueSize sets the write queue buffer size. +func MaxWriteQueueSize(n int) func(*ReceiptStore) { + return func(s *ReceiptStore) { + if n < 0 { + n = 0 + } + s.writeQueue = make(chan queuedReceipt, n) + } +} + +// WriteQueueTimeout sets the amount of time Save waits when queue is full before returning. +func WriteQueueTimeout(timeout time.Duration) func(*ReceiptStore) { + return func(s *ReceiptStore) { + s.queueTimeout = timeout + } +} + +// WriteDelay introduces an artificial delay before each persisted write. +// It can be used to support deterministically testing backpressure behavior. +func WriteDelay(delay time.Duration) func(*ReceiptStore) { + return func(s *ReceiptStore) { + s.writeDelay = delay + } +} + func NewReceiptStore(filePath string, opts ...func(*ReceiptStore)) (*ReceiptStore, error) { if filePath == "" { filePath = "datafog_receipts.jsonl" @@ -52,20 +115,49 @@ func NewReceiptStore(filePath string, opts ...func(*ReceiptStore)) (*ReceiptStor } store := &ReceiptStore{ - filePath: filePath, - receipts: map[string]models.Receipt{}, + filePath: filePath, + receipts: map[string]models.Receipt{}, + flushInterval: defaultReceiptFlushInterval, + batchWrites: defaultReceiptBatchWrites, + queueTimeout: defaultReceiptQueueTimeout, + writeQueue: make(chan queuedReceipt, defaultReceiptWriteQueueSize), + closeCh: make(chan struct{}), + closedCh: make(chan struct{}), } + for _, opt := range opts { opt(store) } - f, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDONLY, defaultReceiptFileMode) // #nosec G304 -- receipt path is validated from startup configuration. 
- if err != nil { - return nil, err + + if store.flushInterval <= 0 { + store.flushInterval = defaultReceiptFlushInterval + } + if store.batchWrites <= 0 { + store.batchWrites = defaultReceiptBatchWrites + } + if store.queueTimeout <= 0 { + store.queueTimeout = defaultReceiptQueueTimeout + } + if store.writeQueue == nil { + store.writeQueue = make(chan queuedReceipt, defaultReceiptWriteQueueSize) } - f.Close() + if store.closeCh == nil { + store.closeCh = make(chan struct{}) + } + if store.closedCh == nil { + store.closedCh = make(chan struct{}) + } + if err := store.loadExistingReceipts(); err != nil { return nil, err } + + if err := store.openWriter(); err != nil { + return nil, err + } + + store.workerWg.Add(1) + go store.writeLoop() return store, nil } @@ -91,51 +183,171 @@ func (s *ReceiptStore) NewReceipt(req models.DecideRequest, decision models.Deci func (s *ReceiptStore) Save(receipt models.Receipt) (models.Receipt, error) { s.mu.Lock() - defer s.mu.Unlock() - + if s.closed { + s.mu.Unlock() + return receipt, errStoreClosed + } if receipt.ReceiptID == "" { receipt.ReceiptID = newID() } - // Rotate if we've hit the max + rotate := false if s.maxEntries > 0 && s.entryCount >= s.maxEntries { - if err := s.rotateLocked(); err != nil { - return models.Receipt{}, fmt.Errorf("receipt rotation failed: %w", err) + s.receipts = map[string]models.Receipt{} + s.entryCount = 0 + rotate = true + } + + s.receipts[receipt.ReceiptID] = receipt + s.entryCount++ + s.mu.Unlock() + + writeTask := queuedReceipt{receipt: receipt, rotate: rotate} + select { + case s.writeQueue <- writeTask: + return receipt, nil + case <-s.closeCh: + return receipt, errStoreClosed + case <-time.After(s.queueTimeout): + return receipt, errReceiptWriteQueueFull + } +} + +func (s *ReceiptStore) Close() error { + s.mu.Lock() + if s.closed { + s.mu.Unlock() + <-s.closedCh + return nil + } + s.closed = true + close(s.closeCh) + s.mu.Unlock() + + s.workerWg.Wait() + <-s.closedCh + return nil +} + +// 
rotateAndOpen archives the current receipts file and starts fresh. +func (s *ReceiptStore) rotateAndOpen() error { + if err := s.flushWriter(); err != nil { + return err + } + if err := s.closeWriter(); err != nil { + return err + } + + archivePath := s.filePath + "." + time.Now().UTC().Format("20060102T150405Z") + if err := os.Rename(s.filePath, archivePath); err != nil && !os.IsNotExist(err) { + return err + } + + return s.openWriter() +} + +func (s *ReceiptStore) writeLoop() { + defer s.workerWg.Done() + defer close(s.closedCh) + + ticker := time.NewTicker(s.flushInterval) + defer ticker.Stop() + + pendingWrites := 0 + for { + select { + case <-ticker.C: + if pendingWrites > 0 { + _ = s.flushWriter() + pendingWrites = 0 + } + case writeTask := <-s.writeQueue: + if err := s.applyWriteTask(writeTask); err != nil { + // Drop task-specific persistence errors for now; in-memory state remains. + // Callers can observe filesystem health by checking write throughput externally. + } + pendingWrites++ + if pendingWrites >= s.batchWrites { + _ = s.flushWriter() + pendingWrites = 0 + } + case <-s.closeCh: + for { + select { + case writeTask := <-s.writeQueue: + if err := s.applyWriteTask(writeTask); err != nil { + } + pendingWrites++ + default: + goto drained + } + } + drained: + if pendingWrites > 0 { + _ = s.flushWriter() + } + _ = s.closeWriter() + return } } +} - data, err := json.Marshal(receipt) - if err != nil { - return models.Receipt{}, err +func (s *ReceiptStore) applyWriteTask(writeTask queuedReceipt) error { + if writeTask.rotate { + if err := s.rotateAndOpen(); err != nil { + return err + } } - f, err := os.OpenFile(s.filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, defaultReceiptFileMode) // #nosec G304 -- receipt path is validated from startup configuration. 
+ if s.writeDelay > 0 { + time.Sleep(s.writeDelay) + } + + payload, err := json.Marshal(writeTask.receipt) if err != nil { - return models.Receipt{}, err + return err } - defer f.Close() - if _, err := f.Write(appendWithLine(data)); err != nil { - return models.Receipt{}, err + next := appendWithLine(payload) + _, err = s.writer.Write(next) + return err +} + +func (s *ReceiptStore) flushWriter() error { + if s.writer == nil || s.file == nil { + return nil } - if err := f.Sync(); err != nil { - return models.Receipt{}, err + if err := s.writer.Flush(); err != nil { + return err } - - s.receipts[receipt.ReceiptID] = receipt - s.entryCount++ - return receipt, nil + return s.file.Sync() } -// rotateLocked archives the current receipts file and starts fresh. -// Must be called with s.mu held. -func (s *ReceiptStore) rotateLocked() error { - archivePath := s.filePath + "." + time.Now().UTC().Format("20060102T150405Z") - if err := os.Rename(s.filePath, archivePath); err != nil && !os.IsNotExist(err) { +func (s *ReceiptStore) openWriter() error { + file, err := os.OpenFile(s.filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, defaultReceiptFileMode) // #nosec G304 -- receipt path is validated from startup configuration. + if err != nil { return err } - s.receipts = map[string]models.Receipt{} - s.entryCount = 0 + + s.file = file + s.writer = bufio.NewWriter(file) + return nil +} + +func (s *ReceiptStore) closeWriter() error { + if s.writer != nil { + if err := s.writer.Flush(); err != nil { + return err + } + s.writer = nil + } + if s.file != nil { + if err := s.file.Close(); err != nil { + s.file = nil + return err + } + s.file = nil + } return nil } @@ -146,9 +358,56 @@ func (s *ReceiptStore) Count() int { return len(s.receipts) } +// List returns receipts sorted newest-first and optional filters. 
+func (s *ReceiptStore) List(q ListQuery) ([]models.Receipt, int) { + s.mu.RLock() + defer s.mu.RUnlock() + + limit := q.Limit + if limit <= 0 { + limit = 100 + } + if limit > 1000 { + limit = 1000 + } + + results := make([]models.Receipt, 0, len(s.receipts)) + for _, receipt := range s.receipts { + if q.Decision != "" && !strings.EqualFold(string(receipt.Decision), q.Decision) { + continue + } + if q.ActionType != "" && !strings.EqualFold(receipt.Action.Type, q.ActionType) { + continue + } + if q.After != nil && !receipt.Timestamp.After(*q.After) { + continue + } + if q.Before != nil && !receipt.Timestamp.Before(*q.Before) { + continue + } + results = append(results, receipt) + } + + sort.Slice(results, func(i, j int) bool { + if results[i].Timestamp.Equal(results[j].Timestamp) { + return results[i].ReceiptID < results[j].ReceiptID + } + return results[i].Timestamp.After(results[j].Timestamp) + }) + + total := len(results) + if len(results) > limit { + results = results[:limit] + } + return results, total +} + func (s *ReceiptStore) loadExistingReceipts() error { f, err := os.OpenFile(s.filePath, os.O_RDONLY, defaultReceiptFileMode) // #nosec G304 -- receipt path is validated from startup configuration. 
if err != nil { + if os.IsNotExist(err) { + return nil + } return err } defer f.Close() @@ -157,6 +416,8 @@ func (s *ReceiptStore) loadExistingReceipts() error { scanner.Buffer(make([]byte, 64*1024), maxReceiptLineBytes) badLines := 0 + s.mu.Lock() + defer s.mu.Unlock() for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if line == "" { diff --git a/internal/receipts/store_test.go b/internal/receipts/store_test.go index 4fc0dda..97765c9 100644 --- a/internal/receipts/store_test.go +++ b/internal/receipts/store_test.go @@ -2,9 +2,11 @@ package receipts import ( "encoding/json" + "errors" "os" "strings" "testing" + "time" "github.com/datafog/datafog-api/internal/models" "github.com/datafog/datafog-api/internal/policy" @@ -16,6 +18,9 @@ func TestReceiptStoreSaveAndGet(t *testing.T) { if err != nil { t.Fatalf("new store failed: %v", err) } + t.Cleanup(func() { + _ = store.Close() + }) req := models.DecideRequest{RequestID: "r1", Action: models.ActionMeta{Type: "file.read", Resource: "x"}} result := policy.DecisionResult{Decision: models.DecisionAllow, MatchedRules: []string{"allow-1"}} @@ -63,6 +68,9 @@ func TestReceiptStoreLoadsExistingReceipts(t *testing.T) { if err != nil { t.Fatalf("new store failed: %v", err) } + t.Cleanup(func() { + _ = store.Close() + }) got, ok := store.Get("receipt-seeded") if !ok { t.Fatalf("expected to load existing receipt") @@ -92,6 +100,9 @@ func TestReceiptStoreSkipsCorruptReceiptLines(t *testing.T) { if err != nil { t.Fatalf("expected corrupt+good receipts to load, got %v", err) } + t.Cleanup(func() { + _ = store.Close() + }) if got, ok := store.Get("receipt-good"); !ok { t.Fatalf("expected good receipt loaded") } else if got.ReceiptID != "receipt-good" { @@ -141,6 +152,9 @@ func TestReceiptStoreLoadsLargeReceiptLine(t *testing.T) { if err != nil { t.Fatalf("new store failed: %v", err) } + t.Cleanup(func() { + _ = store.Close() + }) got, ok := store.Get("receipt-large") if !ok { t.Fatalf("expected to load large receipt") @@ 
-149,3 +163,137 @@ func TestReceiptStoreLoadsLargeReceiptLine(t *testing.T) { t.Fatalf("expected loaded receipt id receipt-large, got %q", got.ReceiptID) } } + +func TestReceiptStoreSaveQueueSaturationReturnsBackpressure(t *testing.T) { + path := t.TempDir() + "/receipts.jsonl" + + store, err := NewReceiptStore( + path, + MaxWriteQueueSize(0), + WriteQueueTimeout(5*time.Millisecond), + WriteDelay(20*time.Millisecond), + ) + if err != nil { + t.Fatalf("new store failed: %v", err) + } + t.Cleanup(func() { + _ = store.Close() + }) + + if _, err := store.Save(models.Receipt{ReceiptID: "full-1"}); err != nil { + t.Fatalf("expected first save to enqueue, got %v", err) + } + if _, err := store.Save(models.Receipt{ReceiptID: "full-2"}); !errors.Is(err, errReceiptWriteQueueFull) { + t.Fatalf("expected queue saturation error, got %v", err) + } + if got, ok := store.Get("full-1"); !ok || got.ReceiptID != "full-1" { + t.Fatalf("expected first receipt to remain in-memory, got ok=%v id=%q", ok, got.ReceiptID) + } + if got, ok := store.Get("full-2"); !ok || got.ReceiptID != "full-2" { + t.Fatalf("expected second receipt to remain in-memory despite write backpressure, got ok=%v id=%q", ok, got.ReceiptID) + } +} + +func TestReceiptStoreListFiltersAndLimit(t *testing.T) { + path := t.TempDir() + "/receipts.jsonl" + store, err := NewReceiptStore(path) + if err != nil { + t.Fatalf("new store failed: %v", err) + } + t.Cleanup(func() { + _ = store.Close() + }) + + base := time.Date(2026, 1, 2, 3, 4, 5, 0, time.UTC) + if _, err := store.Save(models.Receipt{ReceiptID: "a-old", Decision: models.DecisionAllow, Action: models.ActionMeta{Type: "file.write"}, Timestamp: base.Add(-2 * time.Hour)}); err != nil { + t.Fatalf("save failed: %v", err) + } + if _, err := store.Save(models.Receipt{ReceiptID: "b-middle", Decision: models.DecisionAllow, Action: models.ActionMeta{Type: "shell.exec"}, Timestamp: base.Add(-time.Minute)}); err != nil { + t.Fatalf("save failed: %v", err) + } + if _, err := 
store.Save(models.Receipt{ReceiptID: "c-new", Decision: models.DecisionDeny, Action: models.ActionMeta{Type: "file.write"}, Timestamp: base}); err != nil { + t.Fatalf("save failed: %v", err) + } + + entries, total := store.List(ListQuery{Limit: 2, Decision: "allow"}) + if total != 2 { + t.Fatalf("expected 2 total allow matches, got %d", total) + } + if len(entries) != 2 { + t.Fatalf("expected 2 returned entries with default limit 2, got %d", len(entries)) + } + if entries[0].ReceiptID != "b-middle" || entries[1].ReceiptID != "a-old" { + t.Fatalf("expected ordered newest-to-oldest allow receipts, got %s, %s", entries[0].ReceiptID, entries[1].ReceiptID) + } + + limitedEntries, limitedTotal := store.List(ListQuery{Decision: "allow", Limit: 1}) + if limitedTotal != 2 { + t.Fatalf("expected 2 total allow matches for limited query, got %d", limitedTotal) + } + if len(limitedEntries) != 1 { + t.Fatalf("expected 1 returned entry with limit 1, got %d", len(limitedEntries)) + } + if limitedEntries[0].ReceiptID != "b-middle" { + t.Fatalf("expected newest allow receipt first, got %s", limitedEntries[0].ReceiptID) + } + + before := base + actionEntries, total := store.List(ListQuery{ActionType: "file.write", Before: &before, Limit: 10}) + if total != 1 { + t.Fatalf("expected 1 file.write receipt before %s, got %d", before, total) + } + if len(actionEntries) != 1 { + t.Fatalf("expected 1 returned entry, got %d", len(actionEntries)) + } +} + +func TestReceiptStoreCloseDrainsQueuedWrites(t *testing.T) { + path := t.TempDir() + "/receipts.jsonl" + store, err := NewReceiptStore(path) + if err != nil { + t.Fatalf("new store failed: %v", err) + } + + receipts := []models.Receipt{ + {ReceiptID: "flush-1", Decision: models.DecisionAllow}, + {ReceiptID: "flush-2", Decision: models.DecisionAllow}, + {ReceiptID: "flush-3", Decision: models.DecisionAllow}, + } + + for i, receipt := range receipts { + if _, err := store.Save(receipt); err != nil { + t.Fatalf("save %d failed: %v", i, err) + } 
+ } + if err := store.Close(); err != nil { + t.Fatalf("close failed: %v", err) + } + + reopened, err := NewReceiptStore(path) + if err != nil { + t.Fatalf("reopen after close failed: %v", err) + } + t.Cleanup(func() { + _ = reopened.Close() + }) + + for _, id := range []string{"flush-1", "flush-2", "flush-3"} { + if _, ok := reopened.Get(id); !ok { + t.Fatalf("expected receipt %q to be flushed to disk and reloaded", id) + } + } +} + +func TestReceiptStoreSaveAfterCloseReturnsError(t *testing.T) { + path := t.TempDir() + "/receipts.jsonl" + store, err := NewReceiptStore(path) + if err != nil { + t.Fatalf("new store failed: %v", err) + } + if err := store.Close(); err != nil { + t.Fatalf("close failed: %v", err) + } + if _, err := store.Save(models.Receipt{ReceiptID: "closed"}); !errors.Is(err, errStoreClosed) { + t.Fatalf("expected closed store error, got %v", err) + } +} diff --git a/internal/scan/detector.go b/internal/scan/detector.go index 93dfc3c..213764b 100644 --- a/internal/scan/detector.go +++ b/internal/scan/detector.go @@ -1,7 +1,11 @@ package scan import ( + "context" "regexp" + "runtime" + + "golang.org/x/sync/errgroup" "github.com/datafog/datafog-api/internal/models" ) @@ -77,43 +81,93 @@ var defaultScanEntityTypes = []string{ func ScanText(text string, entityFilter []string) []models.ScanFinding { requested := requestedEntitySet(entityFilter) + if text == "" { + return nil + } - findings := make([]models.ScanFinding, 0) - - // Phase 1: Regex engine (fast, always available) + targets := make([]string, 0, len(defaultScanEntityTypes)) for _, entityType := range defaultScanEntityTypes { if shouldRunEntityType(requested, entityType) { - pattern := DefaultEntityPatterns[entityType] - idxs := pattern.Re.FindAllStringIndex(text, -1) - for _, idx := range idxs { - if len(idx) != 2 || idx[0] < 0 || idx[1] < idx[0] { - continue - } - value := text[idx[0]:idx[1]] - if pattern.Validate != nil && !pattern.Validate(value) { - continue - } - findings = 
append(findings, models.ScanFinding{ - EntityType: entityType, - Value: value, - Start: idx[0], - End: idx[1], - Confidence: DefaultEntityConfidences[entityType], - }) + targets = append(targets, entityType) + } + } + + chunked := make([][]models.ScanFinding, len(targets)) + if len(targets) > 0 { + g, _ := errgroup.WithContext(context.Background()) + limit := scanWorkers() + if limit > 1 { + g.SetLimit(limit) + } + for i, entityType := range targets { + i := i + entityType := entityType + pattern, ok := DefaultEntityPatterns[entityType] + if !ok || pattern.Re == nil { + continue } + g.Go(func() error { + local := findMatchesForPattern(text, entityType, pattern) + chunked[i] = local + return nil + }) } + _ = g.Wait() + } + + findings := make([]models.ScanFinding, 0) + for _, chunk := range chunked { + if len(chunk) == 0 { + continue + } + findings = append(findings, chunk...) } if !shouldRunNERForFilter(requested) { return findings } - - // Phase 2: NER engine (heuristic, when enabled) findings = append(findings, scanNERWithFilter(text, requested)...) - return findings } +func findMatchesForPattern(text string, entityType string, pattern EntityPattern) []models.ScanFinding { + if text == "" || pattern.Re == nil { + return nil + } + + idxs := pattern.Re.FindAllStringIndex(text, -1) + if len(idxs) == 0 { + return nil + } + + results := make([]models.ScanFinding, 0, len(idxs)) + for _, idx := range idxs { + if len(idx) != 2 || idx[0] < 0 || idx[1] < idx[0] { + continue + } + value := text[idx[0]:idx[1]] + if pattern.Validate != nil && !pattern.Validate(value) { + continue + } + results = append(results, models.ScanFinding{ + EntityType: entityType, + Value: value, + Start: idx[0], + End: idx[1], + Confidence: DefaultEntityConfidences[entityType], + }) + } + return results +} + +func scanWorkers() int { + workers := runtime.GOMAXPROCS(0) + if workers < 1 { + return 1 + } + return workers +} + // luhnValid implements the Luhn algorithm to validate credit card numbers. 
// normalizeTokenMap returns a copy of src whose keys have surrounding
// whitespace trimmed and are lower-cased, so later lookups can use
// canonical tokens regardless of how the source table was written.
func normalizeTokenMap(src map[string]bool) map[string]bool {
	out := make(map[string]bool, len(src))
	for key := range src {
		canonical := strings.TrimSpace(key)
		out[strings.ToLower(canonical)] = true
	}
	return out
}
// hasRequestedNERFilter reports whether the requested entity filter
// includes at least one NER-produced entity type. An empty filter means
// "no restriction", so NER always runs in that case.
func hasRequestedNERFilter(requested map[string]struct{}) bool {
	if len(requested) == 0 {
		return true
	}
	for _, entity := range [...]string{"person", "organization", "location"} {
		if _, ok := requested[entity]; ok {
			return true
		}
	}
	return false
}
admin dashboard page. +type AdminHandler struct { + adminHTML []byte +} + +// NewAdminHandler creates an admin dashboard handler backed by the given HTML asset. +func NewAdminHandler(htmlPath string) (*AdminHandler, error) { + if htmlPath == "" { + htmlPath = "docs/admin.html" + } + + html, err := os.ReadFile(htmlPath) + if err != nil { + return nil, err + } + + return &AdminHandler{adminHTML: html}, nil +} + +// Register adds the admin endpoint to the given mux. +func (d *AdminHandler) Register(mux *http.ServeMux) { + mux.HandleFunc("/admin", d.handleAdminPage) +} + +func (d *AdminHandler) handleAdminPage(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.Write(d.adminHTML) +} diff --git a/internal/server/metrics.go b/internal/server/metrics.go new file mode 100644 index 0000000..94fd750 --- /dev/null +++ b/internal/server/metrics.go @@ -0,0 +1,166 @@ +package server + +import ( + "fmt" + "hash/fnv" + "sync" + "sync/atomic" + "time" +) + +const metricShardCount = 16 + +type metricCounterMap struct { + shards [metricShardCount]metricCounterShard +} + +type metricCounterShard struct { + mu sync.RWMutex + counters map[string]*atomic.Int64 +} + +func newMetricCounterMap() *metricCounterMap { + m := &metricCounterMap{} + for i := 0; i < metricShardCount; i++ { + m.shards[i].counters = make(map[string]*atomic.Int64) + } + return m +} + +func (m *metricCounterMap) add(key string, delta int64) { + sh := m.shard(key) + counter := sh.loadCounter(key) + if counter == nil { + counter = sh.initCounter(key) + } + counter.Add(delta) +} + +func (m *metricCounterMap) load(key string) int64 { + sh := m.shard(key) + sh.mu.RLock() + counter := sh.counters[key] + sh.mu.RUnlock() + if counter == nil { + return 0 + } + return counter.Load() +} + +func (m *metricCounterMap) snapshot() map[string]int64 { + out := make(map[string]int64) + for i := 0; i < metricShardCount; i++ { + sh := &m.shards[i] + sh.mu.RLock() + for key, counter 
:= range sh.counters { + out[key] = counter.Load() + } + sh.mu.RUnlock() + } + return out +} + +func (m *metricCounterMap) shard(key string) *metricCounterShard { + h := fnv.New32a() + _, _ = h.Write([]byte(key)) + return &m.shards[h.Sum32()%metricShardCount] +} + +func (s *metricCounterShard) loadCounter(key string) *atomic.Int64 { + s.mu.RLock() + counter := s.counters[key] + s.mu.RUnlock() + return counter +} + +func (s *metricCounterShard) initCounter(key string) *atomic.Int64 { + s.mu.Lock() + defer s.mu.Unlock() + if existing := s.counters[key]; existing != nil { + return existing + } + counter := &atomic.Int64{} + s.counters[key] = counter + return counter +} + +func (s *metricCounterMap) len() int { + count := 0 + for i := 0; i < metricShardCount; i++ { + sh := &s.shards[i] + sh.mu.RLock() + count += len(sh.counters) + sh.mu.RUnlock() + } + return count +} + +func intKey(value int) string { + return fmt.Sprintf("%d", value) +} + +type requestMetrics struct { + totalCount atomic.Int64 + errorCount atomic.Int64 + totalLatencyNs atomic.Int64 + pathHits *metricCounterMap + methodHits *metricCounterMap + statusHits *metricCounterMap + pathLatencyNs *metricCounterMap + pathLatencyHits *metricCounterMap +} + +func newRequestMetrics() *requestMetrics { + return &requestMetrics{ + pathHits: newMetricCounterMap(), + methodHits: newMetricCounterMap(), + statusHits: newMetricCounterMap(), + pathLatencyNs: newMetricCounterMap(), + pathLatencyHits: newMetricCounterMap(), + } +} + +func (m *requestMetrics) record(method string, route string, status int, latencyNs int64) { + m.totalCount.Add(1) + m.totalLatencyNs.Add(latencyNs) + if status >= 400 { + m.errorCount.Add(1) + } + m.methodHits.add(method, 1) + m.pathHits.add(route, 1) + m.statusHits.add(intKey(status), 1) + m.pathLatencyNs.add(route, latencyNs) + m.pathLatencyHits.add(route, 1) +} + +func (m *requestMetrics) snapshot() (total int64, errorCount int64, avgLatencyMs float64, byMethod, byPath, byStatus 
map[string]int64, byPathLatency map[string]float64) { + count := m.totalCount.Load() + errorCount = m.errorCount.Load() + latencyNs := m.totalLatencyNs.Load() + if count > 0 { + avgLatencyMs = float64(latencyNs) / float64(count) / float64(1e6) + if avgLatencyMs <= 0 { + avgLatencyMs = float64(time.Nanosecond) / float64(time.Millisecond) + } + } + + byMethod = m.methodHits.snapshot() + byPath = m.pathHits.snapshot() + byStatus = m.statusHits.snapshot() + + pathNs := m.pathLatencyNs.snapshot() + pathHits := m.pathLatencyHits.snapshot() + byPathLatency = make(map[string]float64) + for path, ns := range pathNs { + hits := pathHits[path] + if hits == 0 { + continue + } + v := float64(ns) / float64(hits) / float64(1e6) + if v <= 0 { + v = float64(time.Nanosecond) / float64(time.Millisecond) + } + byPathLatency[path] = v + } + return count, errorCount, avgLatencyMs, byMethod, byPath, byStatus, byPathLatency +} diff --git a/internal/server/server.go b/internal/server/server.go index fb1e64c..f34c246 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,6 +1,7 @@ package server import ( + "compress/gzip" "context" "crypto/rand" "crypto/sha256" @@ -25,30 +26,25 @@ import ( "github.com/datafog/datafog-api/internal/scan" "github.com/datafog/datafog-api/internal/shim" "github.com/datafog/datafog-api/internal/transform" + "github.com/datafog/datafog-api/internal/types/ttlcache" ) type Server struct { - policy models.Policy - store *receipts.ReceiptStore - eventReader shim.EventReader - apiToken string - rateLimiter *tokenBucket - startedAt time.Time - logger *log.Logger - mu sync.Mutex - statsMu sync.Mutex - decisions map[string]idempotentDecision - scans map[string]idempotentCachedResponse - transforms map[string]idempotentCachedResponse - anonymizes map[string]idempotentCachedResponse - totalCount int64 - errorCount int64 - statusHits map[int]int64 - pathHits map[string]int64 - methodHits map[string]int64 - totalLatencyNs int64 - pathLatencyNs 
map[string]int64 - pathLatencyCounts map[string]int64 + policy models.Policy + policyIndex *policy.PolicyIndex + store *receipts.ReceiptStore + eventReader shim.EventReader + eventSink shim.DecisionEventSink + apiToken string + rateLimiter *tokenBucket + startedAt time.Time + logger *log.Logger + + decisions *ttlcache.Cache[idempotentDecisionResponse] + scans *ttlcache.Cache[idempotentCachedResponse] + transforms *ttlcache.Cache[idempotentCachedResponse] + anonymizes *ttlcache.Cache[idempotentCachedResponse] + metrics *requestMetrics } type requestIDContextKey struct{} @@ -70,11 +66,67 @@ func (w *responseStatusWriter) Write(body []byte) (int, error) { return w.ResponseWriter.Write(body) } +type gzipResponseWriter struct { + ResponseWriter http.ResponseWriter + writer *gzip.Writer + wroteHeader bool + status int +} + +func (w *gzipResponseWriter) Header() http.Header { + return w.ResponseWriter.Header() +} + +func (w *gzipResponseWriter) WriteHeader(statusCode int) { + if w.wroteHeader { + return + } + w.wroteHeader = true + w.status = statusCode + + w.ResponseWriter.Header().Del("Content-Length") + w.ResponseWriter.Header().Set("Content-Encoding", "gzip") + w.ResponseWriter.Header().Set("Vary", "Accept-Encoding") + if w.writer == nil { + writer := gzip.NewWriter(w.ResponseWriter) + w.writer = writer + } + w.ResponseWriter.WriteHeader(statusCode) +} + +func (w *gzipResponseWriter) Write(body []byte) (int, error) { + if w.status == 0 { + w.WriteHeader(http.StatusOK) + } + if w.writer == nil { + w.writer = gzip.NewWriter(w.ResponseWriter) + } + return w.writer.Write(body) +} + +func (w *gzipResponseWriter) Close() error { + if w.writer == nil { + return nil + } + return w.writer.Close() +} + +func shouldCompressResponse(r *http.Request) bool { + enc := strings.ToLower(r.Header.Get("Accept-Encoding")) + if enc == "" { + return false + } + return strings.Contains(enc, "gzip") +} + const ( maxRequestBodyBytes int64 = 1024 * 1024 // 1 MiB + + defaultIdempotencyCacheSize 
= 4096 + defaultIdempotencyTTL = 30 * time.Minute ) -type idempotentDecision struct { +type idempotentDecisionResponse struct { requestHash string response models.DecideResponse } @@ -103,31 +155,77 @@ func New(policyData models.Policy, store *receipts.ReceiptStore, logger *log.Log } policyData = policy.NormalizeForEvaluation(policyData) return &Server{ - policy: policyData, - store: store, - apiToken: apiToken, - rateLimiter: newTokenBucket(rateLimitRPS), - startedAt: time.Now().UTC(), - logger: logger, - decisions: map[string]idempotentDecision{}, - scans: map[string]idempotentCachedResponse{}, - transforms: map[string]idempotentCachedResponse{}, - anonymizes: map[string]idempotentCachedResponse{}, - statusHits: map[int]int64{}, - pathHits: map[string]int64{}, - methodHits: map[string]int64{}, - totalLatencyNs: 0, - pathLatencyNs: map[string]int64{}, - pathLatencyCounts: map[string]int64{}, + policy: policyData, + policyIndex: policy.BuildPolicyIndex(policyData), + store: store, + apiToken: apiToken, + rateLimiter: newTokenBucket(rateLimitRPS), + startedAt: time.Now().UTC(), + logger: logger, + decisions: ttlcache.New[idempotentDecisionResponse](defaultIdempotencyCacheSize, defaultIdempotencyTTL), + scans: ttlcache.New[idempotentCachedResponse](defaultIdempotencyCacheSize, defaultIdempotencyTTL), + transforms: ttlcache.New[idempotentCachedResponse](defaultIdempotencyCacheSize, defaultIdempotencyTTL), + anonymizes: ttlcache.New[idempotentCachedResponse](defaultIdempotencyCacheSize, defaultIdempotencyTTL), + metrics: newRequestMetrics(), } } func (s *Server) SetEventReader(reader shim.EventReader) { s.eventReader = reader + if sink, ok := reader.(shim.DecisionEventSink); ok { + s.eventSink = sink + } +} + +func (s *Server) Shutdown(ctx context.Context) error { + if ctx == nil { + ctx = context.Background() + } + done := make(chan error, 1) + go func() { + if s.decisions != nil { + s.decisions.Close() + } + if s.scans != nil { + s.scans.Close() + } + if s.transforms 
!= nil { + s.transforms.Close() + } + if s.anonymizes != nil { + s.anonymizes.Close() + } + if s.eventSink != nil { + if closer, ok := s.eventSink.(interface{ Close() error }); ok { + _ = closer.Close() + } + } + if s.store != nil { + _ = s.store.Close() + } + done <- nil + }() + + select { + case err := <-done: + return err + case <-ctx.Done(): + return ctx.Err() + } } // HandlerWithDemo returns the HTTP handler with optional demo endpoints registered. func (s *Server) HandlerWithDemo(demo *DemoHandler) http.Handler { + return s.HandlerWithDemoAndAdmin(demo, nil) +} + +// HandlerWithAdmin returns the HTTP handler with optional admin endpoints registered. +func (s *Server) HandlerWithAdmin(admin *AdminHandler) http.Handler { + return s.HandlerWithDemoAndAdmin(nil, admin) +} + +// HandlerWithDemoAndAdmin returns the HTTP handler with optional demo and admin endpoints. +func (s *Server) HandlerWithDemoAndAdmin(demo *DemoHandler, admin *AdminHandler) http.Handler { mux := http.NewServeMux() mux.HandleFunc("/health", s.handleHealth) mux.HandleFunc("/v1/policy/version", s.handlePolicyVersion) @@ -135,17 +233,21 @@ func (s *Server) HandlerWithDemo(demo *DemoHandler) http.Handler { mux.HandleFunc("/v1/decide", s.handleDecide) mux.HandleFunc("/v1/transform", s.handleTransform) mux.HandleFunc("/v1/anonymize", s.handleAnonymize) + mux.HandleFunc("/v1/receipts", s.handleReceipts) mux.HandleFunc("/v1/receipts/", s.handleReceipt) mux.HandleFunc("/v1/events", s.handleEvents) mux.HandleFunc("/metrics", s.handleMetrics) if demo != nil { demo.Register(mux) } + if admin != nil { + admin.Register(mux) + } return s.wrapMiddleware(mux) } func (s *Server) Handler() http.Handler { - return s.HandlerWithDemo(nil) + return s.HandlerWithDemoAndAdmin(nil, nil) } func (s *Server) wrapMiddleware(mux *http.ServeMux) http.Handler { @@ -169,25 +271,34 @@ func (s *Server) wrapMiddleware(mux *http.ServeMux) http.Handler { r = r.WithContext(context.WithValue(r.Context(), requestIDContextKey{}, 
reqID)) w.Header().Set("X-Request-ID", reqID) - responseWriter := &responseStatusWriter{ResponseWriter: w} + statusWriter := &responseStatusWriter{ResponseWriter: w} + responseWriter := http.ResponseWriter(statusWriter) + if shouldCompressResponse(r) { + responseWriter = &gzipResponseWriter{ResponseWriter: statusWriter} + } startedAt := time.Now() handler, pattern := mux.Handler(r) defer func() { latency := time.Since(startedAt) if rec := recover(); rec != nil { - responseWriter.status = http.StatusInternalServerError + statusWriter.status = http.StatusInternalServerError s.logger.Printf("request panic request_id=%s method=%s path=%s err=%v", reqID, r.Method, r.URL.Path, rec) - s.respondError(responseWriter, http.StatusInternalServerError, models.APIError{Code: "internal_error", Message: "internal server error", RequestID: reqID}) + s.respondError(statusWriter, http.StatusInternalServerError, models.APIError{Code: "internal_error", Message: "internal server error", RequestID: reqID}) } - if responseWriter.status == 0 { - responseWriter.status = http.StatusOK + if statusWriter.status == 0 { + statusWriter.status = http.StatusOK } if pattern == "" { - s.recordRequestMetrics(r.Method, "/_not_found", responseWriter.status, latency) + s.recordRequestMetrics(r.Method, "/_not_found", statusWriter.status, latency) } else { - s.recordRequestMetrics(r.Method, canonicalizedRoute(pattern, r.URL.Path), responseWriter.status, latency) + s.recordRequestMetrics(r.Method, canonicalizedRoute(pattern, r.URL.Path), statusWriter.status, latency) + } + s.logger.Printf("request complete request_id=%s method=%s path=%s status=%d latency_ms=%d", reqID, r.Method, r.URL.Path, statusWriter.status, latency.Milliseconds()) + if gzipWriter, ok := responseWriter.(*gzipResponseWriter); ok { + if err := gzipWriter.Close(); err != nil { + s.logger.Printf("gzip close failed request_id=%s: %v", reqID, err) + } } - s.logger.Printf("request complete request_id=%s method=%s path=%s status=%d 
latency_ms=%d", reqID, r.Method, r.URL.Path, responseWriter.status, latency.Milliseconds()) }() if !s.authorized(r) { @@ -286,20 +397,7 @@ func canonicalizedRoute(pattern string, path string) string { } func (s *Server) recordRequestMetrics(method string, route string, status int, latency time.Duration) { - s.statsMu.Lock() - defer s.statsMu.Unlock() - s.totalCount++ - s.methodHits[method]++ - s.pathHits[route]++ - s.statusHits[status]++ - if status >= 400 { - s.errorCount++ - } - - latencyNs := latency.Nanoseconds() - s.totalLatencyNs += latencyNs - s.pathLatencyNs[route] += latencyNs - s.pathLatencyCounts[route]++ + s.metrics.record(method, route, status, latency.Nanoseconds()) } func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { @@ -353,15 +451,12 @@ func (s *Server) handleScan(w http.ResponseWriter, r *http.Request) { s.respondError(w, http.StatusBadRequest, models.APIError{Code: "invalid_request", Message: "unable to hash request payload", Details: err.Error(), RequestID: requestID(r)}) return } - s.mu.Lock() - existing, ok := s.scans[req.IdempotencyKey] - s.mu.Unlock() - if ok { - if existing.requestHash != reqHash { + if cached, ok := s.scans.Get(req.IdempotencyKey); ok { + if cached.requestHash != reqHash { s.respondError(w, http.StatusConflict, models.APIError{Code: "idempotency_conflict", Message: "different request payload for same idempotency_key", RequestID: requestID(r)}) return } - s.respondRaw(w, existing.status, existing.body) + s.respondRaw(w, cached.status, cached.body) return } } @@ -381,13 +476,11 @@ func (s *Server) handleScan(w http.ResponseWriter, r *http.Request) { return } hash, _ := hashScanRequest(req) - s.mu.Lock() - s.scans[req.IdempotencyKey] = idempotentCachedResponse{ + s.scans.Set(req.IdempotencyKey, idempotentCachedResponse{ requestHash: hash, body: body, status: http.StatusOK, - } - s.mu.Unlock() + }) s.respondRaw(w, http.StatusOK, body) return } @@ -420,15 +513,12 @@ func (s *Server) handleDecide(w 
http.ResponseWriter, r *http.Request) { s.respondError(w, http.StatusBadRequest, models.APIError{Code: "invalid_request", Message: "unable to hash request payload", Details: err.Error(), RequestID: requestID(r)}) return } - s.mu.Lock() - existing, ok := s.decisions[req.IdempotencyKey] - s.mu.Unlock() - if ok { - if existing.requestHash != reqHash { + if cached, ok := s.decisions.Get(req.IdempotencyKey); ok { + if cached.requestHash != reqHash { s.respondError(w, http.StatusConflict, models.APIError{Code: "idempotency_conflict", Message: "different request payload for same idempotency_key", RequestID: requestID(r)}) return } - s.respond(w, http.StatusOK, existing.response) + s.respond(w, http.StatusOK, cached.response) return } } @@ -437,7 +527,7 @@ func (s *Server) handleDecide(w http.ResponseWriter, r *http.Request) { if len(findings) == 0 && req.Text != "" { findings = scan.ScanText(req.Text, nil) } - result := policy.EvaluateSorted(s.policy, policy.DecisionContext{Action: req.Action, Findings: findings}) + result := policy.EvaluateWithIndex(s.policy, s.policyIndex, policy.DecisionContext{Action: req.Action, Findings: findings}) actionHash, err := hashDecideAction(req.Action) if err != nil { s.respondError(w, http.StatusInternalServerError, models.APIError{Code: "hash_error", Message: "unable to hash action", Details: err.Error(), RequestID: requestID(r)}) @@ -480,12 +570,10 @@ func (s *Server) handleDecide(w http.ResponseWriter, r *http.Request) { } if req.IdempotencyKey != "" { hash, _ := hashDecideRequest(req) - s.mu.Lock() - s.decisions[req.IdempotencyKey] = idempotentDecision{ + s.decisions.Set(req.IdempotencyKey, idempotentDecisionResponse{ requestHash: hash, response: res, - } - s.mu.Unlock() + }) } s.respond(w, http.StatusOK, res) } @@ -540,15 +628,12 @@ func (s *Server) handleTransform(w http.ResponseWriter, r *http.Request) { s.respondError(w, http.StatusBadRequest, models.APIError{Code: "invalid_request", Message: "unable to hash request payload", 
Details: err.Error(), RequestID: requestID(r)}) return } - s.mu.Lock() - existing, ok := s.transforms[req.IdempotencyKey] - s.mu.Unlock() - if ok { - if existing.requestHash != reqHash { + if cached, ok := s.transforms.Get(req.IdempotencyKey); ok { + if cached.requestHash != reqHash { s.respondError(w, http.StatusConflict, models.APIError{Code: "idempotency_conflict", Message: "different request payload for same idempotency_key", RequestID: requestID(r)}) return } - s.respondRaw(w, existing.status, existing.body) + s.respondRaw(w, cached.status, cached.body) return } } @@ -588,13 +673,11 @@ func (s *Server) handleTransform(w http.ResponseWriter, r *http.Request) { return } hash, _ := hashTransformRequest(req) - s.mu.Lock() - s.transforms[req.IdempotencyKey] = idempotentCachedResponse{ + s.transforms.Set(req.IdempotencyKey, idempotentCachedResponse{ requestHash: hash, body: body, status: http.StatusOK, - } - s.mu.Unlock() + }) s.respondRaw(w, http.StatusOK, body) return } @@ -627,15 +710,12 @@ func (s *Server) handleAnonymize(w http.ResponseWriter, r *http.Request) { s.respondError(w, http.StatusBadRequest, models.APIError{Code: "invalid_request", Message: "unable to hash request payload", Details: err.Error(), RequestID: requestID(r)}) return } - s.mu.Lock() - existing, ok := s.anonymizes[req.IdempotencyKey] - s.mu.Unlock() - if ok { - if existing.requestHash != reqHash { + if cached, ok := s.anonymizes.Get(req.IdempotencyKey); ok { + if cached.requestHash != reqHash { s.respondError(w, http.StatusConflict, models.APIError{Code: "idempotency_conflict", Message: "different request payload for same idempotency_key", RequestID: requestID(r)}) return } - s.respondRaw(w, existing.status, existing.body) + s.respondRaw(w, cached.status, cached.body) return } } @@ -671,13 +751,11 @@ func (s *Server) handleAnonymize(w http.ResponseWriter, r *http.Request) { return } hash, _ := hashAnonymizeRequest(req) - s.mu.Lock() - s.anonymizes[req.IdempotencyKey] = 
idempotentCachedResponse{ + s.anonymizes.Set(req.IdempotencyKey, idempotentCachedResponse{ requestHash: hash, body: body, status: http.StatusOK, - } - s.mu.Unlock() + }) s.respondRaw(w, http.StatusOK, body) return } @@ -742,47 +820,51 @@ func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request) { } func (s *Server) snapshotMetrics() metricsResponse { - s.statsMu.Lock() - defer s.statsMu.Unlock() - byStatus := map[string]int64{} - for status, count := range s.statusHits { - byStatus[strconv.Itoa(status)] = count + total, errors, avgLatency, byMethod, byPath, byStatus, byPathLatency := s.metrics.snapshot() + return metricsResponse{ + TotalRequests: total, + ErrorRequests: errors, + ByStatus: byStatus, + ByPath: byPath, + ByMethod: byMethod, + ByPathLatency: byPathLatency, + AvgLatencyMs: avgLatency, + StartedAt: s.startedAt.Format(time.RFC3339), + UptimeSeconds: time.Since(s.startedAt).Seconds(), } +} - byPath := map[string]int64{} - for path, count := range s.pathHits { - byPath[path] = count +func (s *Server) handleReceipts(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + s.respondError(w, http.StatusMethodNotAllowed, models.APIError{Code: "method_not_allowed", Message: "method must be GET", RequestID: requestID(r)}) + return } - byMethod := map[string]int64{} - for method, count := range s.methodHits { - byMethod[method] = count + q := receipts.ListQuery{Limit: 100} + if limit := strings.TrimSpace(r.URL.Query().Get("limit")); limit != "" { + if n, err := strconv.Atoi(limit); err == nil && n > 0 && n <= 1000 { + q.Limit = n + } } - - byPathLatency := map[string]float64{} - for path, count := range s.pathLatencyCounts { - if count == 0 { - continue + if after := strings.TrimSpace(r.URL.Query().Get("after")); after != "" { + if t, err := time.Parse(time.RFC3339, after); err == nil { + q.After = &t } - byPathLatency[path] = float64(s.pathLatencyNs[path]) / float64(count) / float64(time.Millisecond) } - - avgLatencyMs := 0.0 - 
if s.totalCount > 0 { - avgLatencyMs = float64(s.totalLatencyNs) / float64(s.totalCount) / float64(time.Millisecond) + if before := strings.TrimSpace(r.URL.Query().Get("before")); before != "" { + if t, err := time.Parse(time.RFC3339, before); err == nil { + q.Before = &t + } } - - return metricsResponse{ - TotalRequests: s.totalCount, - ErrorRequests: s.errorCount, - ByStatus: byStatus, - ByPath: byPath, - ByMethod: byMethod, - ByPathLatency: byPathLatency, - AvgLatencyMs: avgLatencyMs, - StartedAt: s.startedAt.Format(time.RFC3339), - UptimeSeconds: time.Since(s.startedAt).Seconds(), + if decision := strings.TrimSpace(r.URL.Query().Get("decision")); decision != "" { + q.Decision = decision } + if actionType := strings.TrimSpace(r.URL.Query().Get("action_type")); actionType != "" { + q.ActionType = actionType + } + + entries, total := s.store.List(q) + s.respond(w, http.StatusOK, map[string]interface{}{"receipts": entries, "total": total}) } func (s *Server) handleReceipt(w http.ResponseWriter, r *http.Request) { @@ -793,7 +875,7 @@ func (s *Server) handleReceipt(w http.ResponseWriter, r *http.Request) { id := strings.TrimPrefix(r.URL.Path, "/v1/receipts/") if id == "" || strings.Contains(id, "/") { - s.respondError(w, http.StatusNotFound, models.APIError{Code: "not_found", Message: "receipt id missing"}) + s.respondError(w, http.StatusNotFound, models.APIError{Code: "not_found", Message: "receipt id missing", RequestID: requestID(r)}) return } receipt, ok := s.store.Get(id) diff --git a/internal/server/server_benchmark_test.go b/internal/server/server_benchmark_test.go index 318359f..25eb97e 100644 --- a/internal/server/server_benchmark_test.go +++ b/internal/server/server_benchmark_test.go @@ -54,6 +54,9 @@ func benchmarkServer(b *testing.B) *http.Server { if err != nil { b.Fatalf("new store: %v", err) } + b.Cleanup(func() { + _ = store.Close() + }) h := New(testPolicy(), store, log.New(io.Discard, "", 0), "", 0) return &http.Server{Handler: h.Handler()} } diff 
--git a/internal/server/server_test.go b/internal/server/server_test.go index 5181c6f..ef6f59b 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "os" "strings" "testing" "time" @@ -41,6 +42,9 @@ func makeServerWithTokenAndRateLimit(t *testing.T, apiToken string, rateLimitRPS if err != nil { t.Fatalf("new store: %v", err) } + t.Cleanup(func() { + _ = store.Close() + }) h := New(testPolicy(), store, nil, apiToken, rateLimitRPS) return &http.Server{Handler: h.Handler()} } @@ -137,6 +141,88 @@ func TestTokenAuth(t *testing.T) { }) } +func TestAdminAndReceiptListEndpoints(t *testing.T) { + tmp := t.TempDir() + storePath := tmp + "/receipts.jsonl" + store, err := receipts.NewReceiptStore(storePath) + if err != nil { + t.Fatalf("new store: %v", err) + } + t.Cleanup(func() { + _ = store.Close() + }) + + if _, err := store.Save(models.Receipt{ReceiptID: "r1", Decision: models.DecisionAllow, Action: models.ActionMeta{Type: "file.write"}, Timestamp: time.Date(2026, 1, 2, 3, 4, 5, 0, time.UTC)}); err != nil { + t.Fatalf("save receipt failed: %v", err) + } + if _, err := store.Save(models.Receipt{ReceiptID: "r2", Decision: models.DecisionDeny, Action: models.ActionMeta{Type: "shell.exec"}, Timestamp: time.Date(2026, 1, 2, 3, 5, 0, 0, time.UTC)}); err != nil { + t.Fatalf("save receipt failed: %v", err) + } + + h := New(testPolicy(), store, nil, "", 0) + srv := &http.Server{Handler: h.Handler()} + + req := httptest.NewRequest(http.MethodPost, "/v1/receipts", nil) + resp := httptest.NewRecorder() + srv.Handler.ServeHTTP(resp, req) + assertJSONError(t, resp, http.StatusMethodNotAllowed, "method_not_allowed") + + listReq := httptest.NewRequest(http.MethodGet, "/v1/receipts?limit=10", nil) + listResp := httptest.NewRecorder() + srv.Handler.ServeHTTP(listResp, listReq) + if listResp.Code != http.StatusOK { + t.Fatalf("expected 200 for receipt list, got %d", listResp.Code) + } + var 
got struct { + Receipts []models.Receipt `json:"receipts"` + Total int `json:"total"` + } + if err := json.NewDecoder(listResp.Body).Decode(&got); err != nil { + t.Fatalf("decode failed: %v", err) + } + if got.Total != 2 { + t.Fatalf("expected 2 total receipts, got %d", got.Total) + } + if len(got.Receipts) != 2 { + t.Fatalf("expected 2 receipts in payload, got %d", len(got.Receipts)) + } + if got.Receipts[0].ReceiptID != "r2" { + t.Fatalf("expected newest receipt first, got %s", got.Receipts[0].ReceiptID) + } + + filteredReq := httptest.NewRequest(http.MethodGet, "/v1/receipts?decision=allow&limit=10", nil) + filteredResp := httptest.NewRecorder() + h.Handler().ServeHTTP(filteredResp, filteredReq) + if filteredResp.Code != http.StatusOK { + t.Fatalf("expected 200 for filtered receipt list, got %d", filteredResp.Code) + } + if err := json.NewDecoder(filteredResp.Body).Decode(&got); err != nil { + t.Fatalf("decode filtered payload failed: %v", err) + } + if got.Total != 1 { + t.Fatalf("expected 1 allow receipt, got %d", got.Total) + } + + adminHTMLPath := tmp + "/admin.html" + if err := os.WriteFile(adminHTMLPath, []byte("admin-ui-ok"), 0o644); err != nil { + t.Fatalf("write admin html: %v", err) + } + admin, err := NewAdminHandler(adminHTMLPath) + if err != nil { + t.Fatalf("new admin handler: %v", err) + } + adminServer := &http.Server{Handler: h.HandlerWithAdmin(admin)} + adminReq := httptest.NewRequest(http.MethodGet, "/admin", nil) + adminResp := httptest.NewRecorder() + adminServer.Handler.ServeHTTP(adminResp, adminReq) + if adminResp.Code != http.StatusOK { + t.Fatalf("expected admin endpoint 200, got %d", adminResp.Code) + } + if strings.TrimSpace(adminResp.Body.String()) != "admin-ui-ok" { + t.Fatalf("expected custom admin html body") + } +} + func TestRateLimit(t *testing.T) { server := makeServerWithTokenAndRateLimit(t, "", 2) @@ -719,6 +805,7 @@ func TestValidateMethodAndBadInputs(t *testing.T) { {name: "transform", method: http.MethodGet, path: 
"/v1/transform", wantStatus: http.StatusMethodNotAllowed}, {name: "anonymize", method: http.MethodGet, path: "/v1/anonymize", wantStatus: http.StatusMethodNotAllowed}, {name: "receipts", method: http.MethodPost, path: "/v1/receipts/abc", wantStatus: http.StatusMethodNotAllowed}, + {name: "receipts_list", method: http.MethodPost, path: "/v1/receipts", wantStatus: http.StatusMethodNotAllowed}, } for _, tc := range tests { @@ -939,6 +1026,9 @@ func TestEventsAdapterFilterCanonicalized(t *testing.T) { if err != nil { t.Fatalf("new store: %v", err) } + t.Cleanup(func() { + _ = s.Close() + }) h := New(testPolicy(), s, nil, "", 0) h.SetEventReader(fakeEventReader{events: []shim.DecisionEvent{ {Tool: "claude", Decision: string(models.DecisionAllow)}, diff --git a/internal/shim/events.go b/internal/shim/events.go index 3d921d3..e94a3a9 100644 --- a/internal/shim/events.go +++ b/internal/shim/events.go @@ -3,7 +3,6 @@ package shim import ( "bufio" "encoding/json" - "fmt" "os" "path/filepath" "strings" @@ -53,38 +52,152 @@ type noopEventSink struct{} func (s noopEventSink) Record(_ DecisionEvent) {} type NDJSONDecisionEventSink struct { - path string - mu sync.Mutex + path string + writes chan DecisionEvent + closed chan struct{} + closeOnce sync.Once + writer *bufio.Writer + file *os.File + flushTimeout time.Duration + mu sync.RWMutex + isClosed bool + writeMu sync.Mutex } func NewNDJSONDecisionEventSink(path string) *NDJSONDecisionEventSink { - return &NDJSONDecisionEventSink{path: path} + sink := &NDJSONDecisionEventSink{ + path: path, + writes: make(chan DecisionEvent, 512), + closed: make(chan struct{}), + flushTimeout: 500 * time.Millisecond, + } + if path == "" { + return sink + } + if err := sink.openWriter(); err != nil { + // Will keep trying when events arrive. 
+	}
+	go sink.loop()
+	return sink
 }
 
-func (s *NDJSONDecisionEventSink) Record(event DecisionEvent) {
+func (s *NDJSONDecisionEventSink) Close() error {
 	if s == nil || s.path == "" {
-		return
+		return nil
 	}
+	s.closeOnce.Do(func() {
+		s.mu.Lock()
+		s.isClosed = true
+		// Close under the write lock: Record sends under the read lock.
+		close(s.writes)
+		s.mu.Unlock()
+		<-s.closed
+	})
+	return nil
+}
+
+func (s *NDJSONDecisionEventSink) openWriter() error {
+	if s.path == "" {
+		return nil
+	}
 	if err := os.MkdirAll(filepath.Dir(s.path), 0o750); err != nil {
-		return
+		return err
 	}
-	payload, err := json.Marshal(event)
+	file, err := os.OpenFile(s.path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o600)
 	if err != nil {
-		return
+		return err
 	}
+	s.file = file
+	s.writer = bufio.NewWriter(file)
+	return nil
+}
 
-	s.mu.Lock()
-	defer s.mu.Unlock()
+func (s *NDJSONDecisionEventSink) closeWriter() error {
+	s.writeMu.Lock()
+	defer s.writeMu.Unlock()
+	if s.writer != nil {
+		if err := s.writer.Flush(); err != nil {
+			return err
+		}
+		s.writer = nil
+	}
+	if s.file != nil {
+		if err := s.file.Close(); err != nil {
+			s.file = nil
+			return err
+		}
+		s.file = nil
+	}
+	return nil
+}
 
-	file, err := os.OpenFile(s.path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o600)
-	if err != nil {
+func (s *NDJSONDecisionEventSink) flushWriter() error {
+	s.writeMu.Lock()
+	defer s.writeMu.Unlock()
+	if s.writer == nil {
+		return nil
+	}
+	if err := s.writer.Flush(); err != nil {
+		return err
+	}
+	return s.file.Sync()
+}
+
+func (s *NDJSONDecisionEventSink) loop() {
+	defer close(s.closed)
+	defer func() {
+		_ = s.closeWriter()
+	}()
+
+	ticker := time.NewTicker(s.flushTimeout)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case event, ok := <-s.writes:
+			if !ok {
+				_ = s.flushWriter()
+				return
+			}
+			s.writeMu.Lock()
+			if s.writer == nil {
+				if err := s.openWriter(); err != nil {
+					s.writeMu.Unlock()
+					continue
+				}
+			}
+			payload, err := json.Marshal(event)
+			if err != nil {
+				s.writeMu.Unlock()
+				continue
+			}
+			_, _ = s.writer.Write(append(payload, '\n'))
+			s.writeMu.Unlock()
+		case <-ticker.C:
+			_ = s.flushWriter()
+		}
+	}
+}
+
+func (s *NDJSONDecisionEventSink) Record(event DecisionEvent) {
+	if s == nil || s.path == "" {
 		return
 	}
-	defer file.Close()
 
-	_, _ = fmt.Fprintln(file, string(payload))
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+	// Hold the read lock across the send: Close closes the channel
+	// under the write lock, so a send can never race the close.
+	if s.isClosed {
+		return
+	}
+	select {
+	case s.writes <- event:
+	default:
+		// Drop events when the buffer is full.
+	}
 }
 
 // Query reads events from the NDJSON file and applies filters.
@@ -93,8 +206,8 @@ func (s *NDJSONDecisionEventSink) Query(q EventQuery) ([]DecisionEvent, error) {
 		return nil, nil
 	}
 
-	s.mu.Lock()
-	defer s.mu.Unlock()
+	s.mu.RLock()
+	defer s.mu.RUnlock()
 
 	f, err := os.Open(s.path)
 	if err != nil {
@@ -137,6 +250,9 @@ func (s *NDJSONDecisionEventSink) Query(q EventQuery) ([]DecisionEvent, error) {
 			break
 		}
 	}
+	if err := scanner.Err(); err != nil {
+		return nil, err
+	}
 
-	return events, scanner.Err()
+	return events, nil
 }
diff --git a/internal/shim/events_test.go b/internal/shim/events_test.go
new file mode 100644
index 0000000..66b8106
--- /dev/null
+++ b/internal/shim/events_test.go
@@ -0,0 +1,75 @@
+package shim
+
+import (
+	"testing"
+	"time"
+)
+
+func TestNDJSONDecisionEventSinkWritesAndQueriesAsync(t *testing.T) {
+	path := t.TempDir() + "/events.ndjson"
+	sink := NewNDJSONDecisionEventSink(path)
+	if sink == nil {
+		t.Fatalf("expected sink")
+	}
+	t.Cleanup(func() {
+		_ = sink.Close()
+	})
+
+	expected := DecisionEvent{
+		Timestamp:  time.Now().UTC(),
+		Mode:       "observe",
+		ActionType: "shell.exec",
+		Tool:       "claude",
+		Decision:   "allow",
+		Allowed:    true,
+	}
+	sink.Record(expected)
+
+	deadline := time.Now().Add(2 * time.Second)
+	var got []DecisionEvent
+	for {
+		events, err := sink.Query(EventQuery{Limit: 10})
+		if err != nil {
+			t.Fatalf("query failed: %v", err)
+		}
+		if len(events) == 1 {
+			got = events
+			break
+		}
+		if time.Now().After(deadline) {
+			t.Fatalf("expected async event to be persisted and queryable, got %d events", len(events))
} + time.Sleep(10 * time.Millisecond) + } + if got[0].Tool != expected.Tool { + t.Fatalf("expected tool %q, got %q", expected.Tool, got[0].Tool) + } +} + +func TestNDJSONDecisionEventSinkCloseFlushesPendingEvents(t *testing.T) { + path := t.TempDir() + "/events.ndjson" + sink := NewNDJSONDecisionEventSink(path) + if sink == nil { + t.Fatalf("expected sink") + } + + sink.Record(DecisionEvent{Timestamp: time.Now().UTC(), Tool: "vcs", Decision: "deny"}) + sink.Record(DecisionEvent{Timestamp: time.Now().UTC(), Tool: "claude", Decision: "allow"}) + sink.Record(DecisionEvent{Timestamp: time.Now().UTC(), Tool: "editor", Decision: "allow"}) + + if err := sink.Close(); err != nil { + t.Fatalf("close failed: %v", err) + } + + events, err := sink.Query(EventQuery{Limit: 100}) + if err != nil { + t.Fatalf("query failed: %v", err) + } + if len(events) != 3 { + t.Fatalf("expected 3 events after close, got %d", len(events)) + } + + if err := sink.Close(); err != nil { + t.Fatalf("idempotent close failed: %v", err) + } +} diff --git a/internal/types/ttlcache/ttlcache.go b/internal/types/ttlcache/ttlcache.go new file mode 100644 index 0000000..f753d9c --- /dev/null +++ b/internal/types/ttlcache/ttlcache.go @@ -0,0 +1,202 @@ +package ttlcache + +import ( + "container/list" + "sync" + "time" +) + +const defaultCleanupInterval = 30 * time.Second + +// Cache is a concurrency-safe TTL cache with LRU eviction. +// It is intentionally small and purpose-built for in-process request-caching. 
+type Cache[V any] struct {
+	mu              sync.Mutex
+	maxSize         int
+	ttl             time.Duration
+	items           map[string]*list.Element
+	order           *list.List // front = most recently used
+	clock           func() time.Time
+	cleanupInterval time.Duration
+	stop            chan struct{}
+	stopped         chan struct{}
+}
+
+type item[V any] struct {
+	key      string
+	value    V
+	expireAt time.Time
+}
+
+func New[V any](maxSize int, ttl time.Duration, options ...Option[V]) *Cache[V] {
+	cfg := &Config[V]{}
+	for _, option := range options {
+		option(cfg)
+	}
+
+	if cfg.cleanupInterval <= 0 {
+		cfg.cleanupInterval = defaultCleanupInterval
+	}
+	if cfg.clock == nil {
+		cfg.clock = time.Now
+	}
+
+	c := &Cache[V]{
+		maxSize:         maxSize,
+		ttl:             ttl,
+		items:           make(map[string]*list.Element),
+		order:           list.New(),
+		clock:           cfg.clock,
+		cleanupInterval: cfg.cleanupInterval,
+		stop:            make(chan struct{}),
+		stopped:         make(chan struct{}),
+	}
+
+	if c.ttl > 0 && c.cleanupInterval > 0 {
+		go c.cleanupLoop()
+	} else {
+		// No cleanup goroutine: mark stopped so Close never blocks.
+		close(c.stopped)
+	}
+	return c
+}
+
+type Config[V any] struct {
+	clock           func() time.Time
+	cleanupInterval time.Duration
+}
+
+type Option[V any] func(*Config[V])
+
+func WithClock[V any](clock func() time.Time) Option[V] {
+	return func(cfg *Config[V]) {
+		cfg.clock = clock
+	}
+}
+
+func WithCleanupInterval[V any](interval time.Duration) Option[V] {
+	return func(cfg *Config[V]) {
+		cfg.cleanupInterval = interval
+	}
+}
+
+func (c *Cache[V]) Get(key string) (V, bool) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	var zero V
+	e := c.items[key]
+	if e == nil {
+		return zero, false
+	}
+	entry := e.Value.(*item[V])
+	if c.isExpired(entry) {
+		c.remove(e)
+		return zero, false
+	}
+	c.order.MoveToFront(e)
+	return entry.value, true
+}
+
+func (c *Cache[V]) Set(key string, value V) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if existing := c.items[key]; existing != nil {
+		existing.Value.(*item[V]).value = value
+		existing.Value.(*item[V]).expireAt = c.now().Add(c.ttl)
+		c.order.MoveToFront(existing)
+		return
+	}
+
+	entry := &item[V]{
+		key:      key,
+		value:    value,
+		expireAt: c.now().Add(c.ttl),
+	}
+	e := c.order.PushFront(entry)
+	c.items[key] = e
+	c.enforceLimitLocked()
+}
+
+func (c *Cache[V]) Delete(key string) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if e := c.items[key]; e != nil {
+		c.remove(e)
+	}
+}
+
+func (c *Cache[V]) Len() int {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.evictExpiredLocked()
+	return len(c.items)
+}
+
+func (c *Cache[V]) Close() {
+	// mu serializes concurrent Close calls (avoids double close of stop).
+	c.mu.Lock()
+	select {
+	case <-c.stop:
+	default:
+		close(c.stop)
+	}
+	c.mu.Unlock()
+	<-c.stopped
+}
+
+func (c *Cache[V]) enforceLimitLocked() {
+	if c.maxSize <= 0 {
+		return
+	}
+	for len(c.items) > c.maxSize {
+		back := c.order.Back()
+		if back == nil {
+			return
+		}
+		c.remove(back)
+	}
+}
+
+func (c *Cache[V]) remove(e *list.Element) {
+	ent := e.Value.(*item[V])
+	c.order.Remove(e)
+	delete(c.items, ent.key)
+}
+
+func (c *Cache[V]) cleanupLoop() {
+	defer close(c.stopped)
+	ticker := time.NewTicker(c.cleanupInterval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+			c.mu.Lock()
+			c.evictExpiredLocked()
+			c.mu.Unlock()
+		case <-c.stop:
+			return
+		}
+	}
+}
+
+func (c *Cache[V]) evictExpiredLocked() {
+	for e := c.order.Back(); e != nil; {
+		prev := e.Prev()
+		if c.isExpired(e.Value.(*item[V])) {
+			c.remove(e)
+		}
+		e = prev
+	}
+}
+
+func (c *Cache[V]) isExpired(it *item[V]) bool {
+	return c.ttl > 0 && !it.expireAt.IsZero() && !it.expireAt.After(c.now())
+}
+
+func (c *Cache[V]) now() time.Time {
+	return c.clock()
+}