diff --git a/README.md b/README.md
index e8156ac..281974d 100644
--- a/README.md
+++ b/README.md
@@ -103,6 +103,8 @@ If you set `DATAFOG_API_TOKEN`, send it on every request using:
| `DATAFOG_FGPROF` | `false` | Add `/debug/fgprof` endpoint to the profiling server |
| `DATAFOG_ENABLE_DEMO` | *(unset)* | Enable `/demo*` endpoints |
| `DATAFOG_DEMO_HTML` | `docs/demo.html` | Path to demo HTML |
+| `DATAFOG_ENABLE_ADMIN_UI` | *(unset)* | Enable read-only `GET /admin` |
+| `DATAFOG_ADMIN_HTML` | `docs/admin.html` | Path to static admin dashboard HTML |
Duration values use Go duration syntax, for example `1s`, `500ms`, `2m`.
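+
+For example, to serve the admin dashboard (assuming the server binary is built as `datafog-api`; adjust to however you normally start the service):
+
+```sh
+DATAFOG_ENABLE_ADMIN_UI=true DATAFOG_ADMIN_HTML=docs/admin.html ./datafog-api
+```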
@@ -118,8 +120,10 @@ Base URL defaults to `http://localhost:8080`.
| `POST` | `/v1/decide` | Evaluate an action + findings and get a decision |
| `POST` | `/v1/transform` | Apply requested transform mode(s) |
| `POST` | `/v1/anonymize` | Apply irreversible anonymization |
+| `GET` | `/v1/receipts` | List recent decision receipts |
| `GET` | `/v1/receipts/{id}` | Read a decision receipt |
| `GET` | `/v1/events` | List recent decision events |
+| `GET` | `/admin` | Read-only operational dashboard (requires `DATAFOG_ENABLE_ADMIN_UI`) |
| `GET` | `/metrics` | In-process metrics counters |
Optional demo routes (only when demo mode is enabled):
@@ -192,6 +196,12 @@ curl -X POST http://localhost:8080/v1/transform \
}'
```
+### List receipts (admin/read-only)
+
+```sh
+curl 'http://localhost:8080/v1/receipts?limit=20&decision=allow'
+```
+
### Fetch a receipt
```sh
diff --git a/cmd/datafog-api/main.go b/cmd/datafog-api/main.go
index a441030..3c2ffdf 100644
--- a/cmd/datafog-api/main.go
+++ b/cmd/datafog-api/main.go
@@ -33,6 +33,7 @@ func main() {
rateLimitRPS := getenvInt("DATAFOG_RATE_LIMIT_RPS", 0)
shutdownTimeout := getenvDuration("DATAFOG_SHUTDOWN_TIMEOUT", 10*time.Second)
enableDemo := getenv("DATAFOG_ENABLE_DEMO", "") != "" || hasFlag("--enable-demo")
+ enableAdmin := getenvBool("DATAFOG_ENABLE_ADMIN_UI", false) || hasFlag("--enable-admin-ui")
eventsPath := getenv("DATAFOG_EVENTS_PATH", "")
pprofAddr := getenv("DATAFOG_PPROF_ADDR", "")
fgprofEnabled := getenvBool("DATAFOG_FGPROF", false)
@@ -60,25 +61,48 @@ func main() {
h.SetEventReader(eventReader)
}
- var handler http.Handler
+ var demo *server.DemoHandler
if enableDemo {
// Create a shim gate backed by a local HTTP decision client
client := shim.NewHTTPDecisionClient("http://127.0.0.1"+addr, apiToken)
gate := shim.NewGate(client, shim.WithEventSink(eventSink))
demoHTMLPath := getenv("DATAFOG_DEMO_HTML", "docs/demo.html")
- demo, err := server.NewDemoHandler(gate, h, demoHTMLPath)
+ demo, err = server.NewDemoHandler(gate, h, demoHTMLPath)
if err != nil {
log.Fatalf("init demo: %v", err)
}
defer demo.Cleanup()
+ }
+ var admin *server.AdminHandler
+ if enableAdmin {
+ adminHTMLPath := getenv("DATAFOG_ADMIN_HTML", "docs/admin.html")
+ admin, err = server.NewAdminHandler(adminHTMLPath)
+ if err != nil {
+ log.Fatalf("init admin UI: %v", err)
+ }
+ }
+
+ var handler http.Handler
+ switch {
+ case demo != nil && admin != nil:
+ handler = h.HandlerWithDemoAndAdmin(demo, admin)
+ case demo != nil:
handler = h.HandlerWithDemo(demo)
- log.Printf("demo mode enabled — /demo/exec, /demo/write-file, /demo/read-file available")
- } else {
+ case admin != nil:
+ handler = h.HandlerWithAdmin(admin)
+ default:
handler = h.Handler()
}
+ if enableDemo {
+ log.Printf("demo mode enabled — /demo/exec, /demo/write-file, /demo/read-file available")
+ }
+ if enableAdmin {
+ log.Printf("admin UI enabled — /admin available")
+ }
+
var pprofSrv *http.Server
if pprofAddr != "" {
pprofSrv = startProfilingServer(pprofAddr, fgprofEnabled, log.Default())
@@ -129,6 +153,9 @@ func main() {
log.Printf("forced close failed: %v", closeErr)
}
}
+ if err := h.Shutdown(ctx); err != nil && !errors.Is(err, context.Canceled) {
+ log.Printf("application shutdown helper failed: %v", err)
+ }
if err := <-done; err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Printf("server stopped with error: %v", err)
}
diff --git a/docs/FRONTEND.md b/docs/FRONTEND.md
index 7a649b7..b47564f 100644
--- a/docs/FRONTEND.md
+++ b/docs/FRONTEND.md
@@ -9,6 +9,7 @@ DataFog API is a backend-first project. There is no React/Vue/Next.js applicatio
- Primary user-facing UI is API-first: clients interact through HTTP endpoints.
- Optional demo assets are static HTML in `docs/demo.html` and rendered by `GET /demo` when demo mode is enabled.
+- Optional admin dashboard is static HTML in `docs/admin.html` and rendered by `GET /admin` when admin UI mode is enabled.
## Conventions
diff --git a/docs/admin.html b/docs/admin.html
new file mode 100644
index 0000000..309f8bd
--- /dev/null
+++ b/docs/admin.html
@@ -0,0 +1,354 @@
+<!doctype html>
+<html lang="en">
+<head>
+  <meta charset="utf-8">
+  <title>DataFog Admin</title>
+</head>
+<body>
+  <h1>DataFog Admin</h1>
+  <p>Read-only operational dashboard for inspecting decisions and running policy checks.</p>
+
+  <section>
+    <h2>Connection</h2>
+    <p>Token stored only in browser localStorage for this page.</p>
+  </section>
+
+  <section>
+    <h2>Service Metrics Snapshot</h2>
+    <pre>Loading...</pre>
+  </section>
+
+  <section>
+    <h2>Scan / Decide</h2>
+    <pre>No payload yet.</pre>
+  </section>
+
+  <section>
+    <h2>Receipt lookup</h2>
+    <pre>No payload yet.</pre>
+  </section>
+
+  <section>
+    <h2>Recent Events</h2>
+    <pre>No payload yet.</pre>
+  </section>
+
+  <section>
+    <h2>Recent Receipts</h2>
+    <pre>No payload yet.</pre>
+  </section>
+</body>
+</html>
diff --git a/docs/contracts/datafog-api-contract.md b/docs/contracts/datafog-api-contract.md
index 1543e09..461cf1b 100644
--- a/docs/contracts/datafog-api-contract.md
+++ b/docs/contracts/datafog-api-contract.md
@@ -337,6 +337,36 @@ Returns persisted decision receipts.
}
```
+### `GET /v1/receipts`
+
+Returns a bounded list of recent receipts.
+
+Query params:
+
+- `limit` (`1..1000`, default `100`)
+- `after` (RFC3339 timestamp, strictly greater-than)
+- `before` (RFC3339 timestamp, strictly less-than)
+- `decision` (`allow|allow_with_redaction|transform|deny`)
+- `action_type` (match against `action.type`)
+
+Results are sorted newest-first by `timestamp`.
+
+```json
+{
+ "receipts": [
+ {
+ "receipt_id": "string",
+ "timestamp": "RFC3339 timestamp",
+ "decision": "allow|allow_with_redaction|transform|deny",
+ "action": {
+ "type": "string"
+ }
+ }
+ ],
+ "total": 1
+}
+```
+
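+As a non-normative example, a filtered query over a time window might look like this (the host and port assume the default local server):
+
+```sh
+curl 'http://localhost:8080/v1/receipts?after=2026-01-01T00:00:00Z&before=2026-02-01T00:00:00Z&action_type=file.write&limit=50'
+```
+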
### `GET /v1/events`
Returns decision events when `DATAFOG_EVENTS_PATH` is configured (or another reader is set).
@@ -377,6 +407,19 @@ Query params:
If no events are configured or none match, return `{"events":[],"total":0}`.
+## Optional admin dashboard (v1)
+
+The following route is available when `DATAFOG_ENABLE_ADMIN_UI` or `--enable-admin-ui` is set:
+
+- `GET /admin`
+
+It serves a read-only static dashboard for operations and diagnostics. The lightweight HTML asset is served from:
+
+- `docs/admin.html` by default,
+- or the path specified by `DATAFOG_ADMIN_HTML`.
+
+Authentication follows the normal API middleware behavior: clients use the same token and request-ID model as all other routes.
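+
+For example, assuming the default local address and the admin UI enabled (if `DATAFOG_API_TOKEN` is set, include the token as you would for any other endpoint):
+
+```sh
+curl http://localhost:8080/admin
+```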
+
## Idempotency
- Supported endpoints: `POST /v1/scan`, `POST /v1/decide`, `POST /v1/transform`, `POST /v1/anonymize`.
diff --git a/go.mod b/go.mod
index 2028ce1..fc00857 100644
--- a/go.mod
+++ b/go.mod
@@ -1,11 +1,12 @@
module github.com/datafog/datafog-api
-go 1.22
+go 1.24.0
require (
github.com/felixge/fgprof v0.9.5
go.uber.org/automaxprocs v1.6.0
go.uber.org/goleak v1.3.0
+ golang.org/x/sync v0.19.0
)
require github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7 // indirect
diff --git a/go.sum b/go.sum
index e80acc3..6bf2fac 100644
--- a/go.sum
+++ b/go.sum
@@ -32,6 +32,8 @@ go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
+golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
+golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
diff --git a/internal/policy/policy.go b/internal/policy/policy.go
index c50f4c6..9da5cf4 100644
--- a/internal/policy/policy.go
+++ b/internal/policy/policy.go
@@ -131,7 +131,7 @@ func ValidatePolicy(policy models.Policy) error {
if len(errors) == 0 {
return nil
}
- return fmt.Errorf(strings.Join(errors, "; "))
+ return fmt.Errorf("%s", strings.Join(errors, "; "))
}
var allowedModes = map[models.TransformMode]struct{}{
@@ -164,6 +164,58 @@ func Evaluate(policy models.Policy, ctx DecisionContext) DecisionResult {
// EvaluateSorted evaluates policy decisions assuming rules are already sorted
// by priority descending.
+type PolicyIndex struct {
+ byAction map[string][]models.Rule
+}
+
+// BuildPolicyIndex precomputes policy lookup tables for faster action-based filtering.
+// Policy rules are expected to already be in evaluation order.
+func BuildPolicyIndex(policy models.Policy) *PolicyIndex {
+ index := &PolicyIndex{byAction: make(map[string][]models.Rule, len(policy.Rules))}
+ for _, rule := range policy.Rules {
+ if len(rule.Match.ActionTypes) == 0 {
+ index.byAction["*"] = append(index.byAction["*"], rule)
+ continue
+ }
+ for _, actionType := range rule.Match.ActionTypes {
+ normalized := strings.ToLower(strings.TrimSpace(actionType))
+ if normalized == "" {
+ continue
+ }
+ index.byAction[normalized] = append(index.byAction[normalized], rule)
+ }
+ }
+ return index
+}
+
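+// rulesForAction returns the rules applicable to an action type: exact
+// (case-insensitive) matches first, then wildcard rules, de-duplicated by
+// rule ID while preserving order.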
+func (idx *PolicyIndex) rulesForAction(actionType string) []models.Rule {
+ if idx == nil {
+ return nil
+ }
+
+ normalizedActionType := strings.ToLower(strings.TrimSpace(actionType))
+ actionRules := idx.byAction[normalizedActionType]
+ wildcard := idx.byAction["*"]
+ matched := make([]models.Rule, 0, len(actionRules)+len(wildcard))
+ matched = append(matched, actionRules...)
+ matched = append(matched, wildcard...)
+
+ if len(matched) == 0 {
+ return nil
+ }
+
+ seen := make(map[string]struct{}, len(matched))
+ ordered := make([]models.Rule, 0, len(matched))
+ for _, rule := range matched {
+ if _, ok := seen[rule.ID]; ok {
+ continue
+ }
+ seen[rule.ID] = struct{}{}
+ ordered = append(ordered, rule)
+ }
+ return ordered
+}
+
type DecisionContext struct {
Action models.ActionMeta
Findings []models.ScanFinding
@@ -177,6 +229,11 @@ type DecisionResult struct {
}
func EvaluateSorted(policy models.Policy, ctx DecisionContext) DecisionResult {
+ return EvaluateWithIndex(policy, nil, ctx)
+}
+
+// EvaluateWithIndex is the hot-path implementation that can use a precomputed policy index.
+func EvaluateWithIndex(policy models.Policy, index *PolicyIndex, ctx DecisionContext) DecisionResult {
if ctx.Action.Type == "" {
return DecisionResult{
Decision: models.DecisionDeny,
@@ -192,6 +249,9 @@ func EvaluateSorted(policy models.Policy, ctx DecisionContext) DecisionResult {
}
rules := policy.Rules
+ if index != nil {
+ rules = index.rulesForAction(ctx.Action.Type)
+ }
hasFindings := map[string]struct{}{}
for _, f := range ctx.Findings {
diff --git a/internal/receipts/store.go b/internal/receipts/store.go
index d8e453b..962ed00 100644
--- a/internal/receipts/store.go
+++ b/internal/receipts/store.go
@@ -5,9 +5,11 @@ import (
"crypto/rand"
"encoding/hex"
"encoding/json"
+ "errors"
"fmt"
"os"
"path/filepath"
+ "sort"
"strings"
"sync"
"time"
@@ -16,18 +18,54 @@ import (
"github.com/datafog/datafog-api/internal/policy"
)
-const maxReceiptLineBytes = 1024 * 1024
-const defaultReceiptFileMode = 0o600
-const defaultReceiptDirMode = 0o750
+const (
+ maxReceiptLineBytes = 1024 * 1024
+
+ defaultReceiptFileMode = 0o600
+ defaultReceiptDirMode = 0o750
+
+ defaultReceiptWriteQueueSize = 1024
+ defaultReceiptQueueTimeout = 5 * time.Second
+ defaultReceiptFlushInterval = 250 * time.Millisecond
+ defaultReceiptBatchWrites = 32
+)
type ReceiptStore struct {
- mu sync.RWMutex
- filePath string
- receipts map[string]models.Receipt
- maxEntries int
- entryCount int
+ mu sync.RWMutex
+ filePath string
+ receipts map[string]models.Receipt
+ maxEntries int
+ entryCount int
+ writeQueue chan queuedReceipt
+ flushInterval time.Duration
+ batchWrites int
+ queueTimeout time.Duration
+ writeDelay time.Duration
+ closed bool
+
+ closeCh chan struct{}
+ closedCh chan struct{}
+ workerWg sync.WaitGroup
+ writer *bufio.Writer
+ file *os.File
+}
+
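+// queuedReceipt is a pending write task for the background writer; rotate
+// signals that the receipts file should be archived before this write.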
+type queuedReceipt struct {
+ receipt models.Receipt
+ rotate bool
+}
+
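+// ListQuery holds optional List filters. Limit is clamped to 1..1000
+// (default 100); After/Before bound the receipt timestamp exclusively;
+// Decision and ActionType match case-insensitively.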
+type ListQuery struct {
+ Limit int
+ Decision string
+ ActionType string
+ After *time.Time
+ Before *time.Time
}
+var errStoreClosed = errors.New("receipt store is closed")
+var errReceiptWriteQueueFull = errors.New("receipt write queue is full")
+
// MaxEntries sets the maximum number of receipts before rotation.
// 0 means no limit (default).
func MaxEntries(n int) func(*ReceiptStore) {
@@ -36,6 +74,31 @@ func MaxEntries(n int) func(*ReceiptStore) {
}
}
+// MaxWriteQueueSize sets the write queue buffer size.
+func MaxWriteQueueSize(n int) func(*ReceiptStore) {
+ return func(s *ReceiptStore) {
+ if n < 0 {
+ n = 0
+ }
+ s.writeQueue = make(chan queuedReceipt, n)
+ }
+}
+
+// WriteQueueTimeout sets how long Save waits for queue capacity before returning a backpressure error.
+func WriteQueueTimeout(timeout time.Duration) func(*ReceiptStore) {
+ return func(s *ReceiptStore) {
+ s.queueTimeout = timeout
+ }
+}
+
+// WriteDelay introduces an artificial delay before each persisted write.
+// It is intended for deterministically testing backpressure behavior.
+func WriteDelay(delay time.Duration) func(*ReceiptStore) {
+ return func(s *ReceiptStore) {
+ s.writeDelay = delay
+ }
+}
+
func NewReceiptStore(filePath string, opts ...func(*ReceiptStore)) (*ReceiptStore, error) {
if filePath == "" {
filePath = "datafog_receipts.jsonl"
@@ -52,20 +115,49 @@ func NewReceiptStore(filePath string, opts ...func(*ReceiptStore)) (*ReceiptStor
}
store := &ReceiptStore{
- filePath: filePath,
- receipts: map[string]models.Receipt{},
+ filePath: filePath,
+ receipts: map[string]models.Receipt{},
+ flushInterval: defaultReceiptFlushInterval,
+ batchWrites: defaultReceiptBatchWrites,
+ queueTimeout: defaultReceiptQueueTimeout,
+ writeQueue: make(chan queuedReceipt, defaultReceiptWriteQueueSize),
+ closeCh: make(chan struct{}),
+ closedCh: make(chan struct{}),
}
+
for _, opt := range opts {
opt(store)
}
- f, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDONLY, defaultReceiptFileMode) // #nosec G304 -- receipt path is validated from startup configuration.
- if err != nil {
- return nil, err
+
+ if store.flushInterval <= 0 {
+ store.flushInterval = defaultReceiptFlushInterval
+ }
+ if store.batchWrites <= 0 {
+ store.batchWrites = defaultReceiptBatchWrites
+ }
+ if store.queueTimeout <= 0 {
+ store.queueTimeout = defaultReceiptQueueTimeout
+ }
+ if store.writeQueue == nil {
+ store.writeQueue = make(chan queuedReceipt, defaultReceiptWriteQueueSize)
}
- f.Close()
+ if store.closeCh == nil {
+ store.closeCh = make(chan struct{})
+ }
+ if store.closedCh == nil {
+ store.closedCh = make(chan struct{})
+ }
+
if err := store.loadExistingReceipts(); err != nil {
return nil, err
}
+
+ if err := store.openWriter(); err != nil {
+ return nil, err
+ }
+
+ store.workerWg.Add(1)
+ go store.writeLoop()
return store, nil
}
@@ -91,51 +183,171 @@ func (s *ReceiptStore) NewReceipt(req models.DecideRequest, decision models.Deci
func (s *ReceiptStore) Save(receipt models.Receipt) (models.Receipt, error) {
s.mu.Lock()
- defer s.mu.Unlock()
-
+ if s.closed {
+ s.mu.Unlock()
+ return receipt, errStoreClosed
+ }
if receipt.ReceiptID == "" {
receipt.ReceiptID = newID()
}
- // Rotate if we've hit the max
+ rotate := false
if s.maxEntries > 0 && s.entryCount >= s.maxEntries {
- if err := s.rotateLocked(); err != nil {
- return models.Receipt{}, fmt.Errorf("receipt rotation failed: %w", err)
+ s.receipts = map[string]models.Receipt{}
+ s.entryCount = 0
+ rotate = true
+ }
+
+ s.receipts[receipt.ReceiptID] = receipt
+ s.entryCount++
+ s.mu.Unlock()
+
+ writeTask := queuedReceipt{receipt: receipt, rotate: rotate}
+ select {
+ case s.writeQueue <- writeTask:
+ return receipt, nil
+ case <-s.closeCh:
+ return receipt, errStoreClosed
+ case <-time.After(s.queueTimeout):
+ return receipt, errReceiptWriteQueueFull
+ }
+}
+
+func (s *ReceiptStore) Close() error {
+ s.mu.Lock()
+ if s.closed {
+ s.mu.Unlock()
+ <-s.closedCh
+ return nil
+ }
+ s.closed = true
+ close(s.closeCh)
+ s.mu.Unlock()
+
+ s.workerWg.Wait()
+ <-s.closedCh
+ return nil
+}
+
+// rotateAndOpen archives the current receipts file and starts fresh.
+func (s *ReceiptStore) rotateAndOpen() error {
+ if err := s.flushWriter(); err != nil {
+ return err
+ }
+ if err := s.closeWriter(); err != nil {
+ return err
+ }
+
+ archivePath := s.filePath + "." + time.Now().UTC().Format("20060102T150405Z")
+ if err := os.Rename(s.filePath, archivePath); err != nil && !os.IsNotExist(err) {
+ return err
+ }
+
+ return s.openWriter()
+}
+
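+// writeLoop is the single background goroutine that persists queued
+// receipts, flushing the buffered writer on a timer or after batchWrites
+// writes and draining the queue on close.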
+func (s *ReceiptStore) writeLoop() {
+ defer s.workerWg.Done()
+ defer close(s.closedCh)
+
+ ticker := time.NewTicker(s.flushInterval)
+ defer ticker.Stop()
+
+ pendingWrites := 0
+ for {
+ select {
+ case <-ticker.C:
+ if pendingWrites > 0 {
+ _ = s.flushWriter()
+ pendingWrites = 0
+ }
+ case writeTask := <-s.writeQueue:
+ // Drop task-specific persistence errors for now; in-memory state remains.
+ // Callers can observe filesystem health by checking write throughput externally.
+ _ = s.applyWriteTask(writeTask)
+ pendingWrites++
+ if pendingWrites >= s.batchWrites {
+ _ = s.flushWriter()
+ pendingWrites = 0
+ }
+ case <-s.closeCh:
+ for {
+ select {
+ case writeTask := <-s.writeQueue:
+ // Errors are dropped here as well; the goal is only to drain the queue.
+ _ = s.applyWriteTask(writeTask)
+ pendingWrites++
+ default:
+ goto drained
+ }
+ }
+ drained:
+ if pendingWrites > 0 {
+ _ = s.flushWriter()
+ }
+ _ = s.closeWriter()
+ return
}
}
+}
- data, err := json.Marshal(receipt)
- if err != nil {
- return models.Receipt{}, err
+func (s *ReceiptStore) applyWriteTask(writeTask queuedReceipt) error {
+ if writeTask.rotate {
+ if err := s.rotateAndOpen(); err != nil {
+ return err
+ }
}
- f, err := os.OpenFile(s.filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, defaultReceiptFileMode) // #nosec G304 -- receipt path is validated from startup configuration.
+ if s.writeDelay > 0 {
+ time.Sleep(s.writeDelay)
+ }
+
+ payload, err := json.Marshal(writeTask.receipt)
if err != nil {
- return models.Receipt{}, err
+ return err
}
- defer f.Close()
- if _, err := f.Write(appendWithLine(data)); err != nil {
- return models.Receipt{}, err
+ next := appendWithLine(payload)
+ _, err = s.writer.Write(next)
+ return err
+}
+
+func (s *ReceiptStore) flushWriter() error {
+ if s.writer == nil || s.file == nil {
+ return nil
}
- if err := f.Sync(); err != nil {
- return models.Receipt{}, err
+ if err := s.writer.Flush(); err != nil {
+ return err
}
-
- s.receipts[receipt.ReceiptID] = receipt
- s.entryCount++
- return receipt, nil
+ return s.file.Sync()
}
-// rotateLocked archives the current receipts file and starts fresh.
-// Must be called with s.mu held.
-func (s *ReceiptStore) rotateLocked() error {
- archivePath := s.filePath + "." + time.Now().UTC().Format("20060102T150405Z")
- if err := os.Rename(s.filePath, archivePath); err != nil && !os.IsNotExist(err) {
+func (s *ReceiptStore) openWriter() error {
+ file, err := os.OpenFile(s.filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, defaultReceiptFileMode) // #nosec G304 -- receipt path is validated from startup configuration.
+ if err != nil {
return err
}
- s.receipts = map[string]models.Receipt{}
- s.entryCount = 0
+
+ s.file = file
+ s.writer = bufio.NewWriter(file)
+ return nil
+}
+
+func (s *ReceiptStore) closeWriter() error {
+ if s.writer != nil {
+ if err := s.writer.Flush(); err != nil {
+ return err
+ }
+ s.writer = nil
+ }
+ if s.file != nil {
+ if err := s.file.Close(); err != nil {
+ s.file = nil
+ return err
+ }
+ s.file = nil
+ }
return nil
}
@@ -146,9 +358,56 @@ func (s *ReceiptStore) Count() int {
return len(s.receipts)
}
+// List returns receipts sorted newest-first, applying the optional filters in q and clamping the limit to 1..1000 (default 100).
+func (s *ReceiptStore) List(q ListQuery) ([]models.Receipt, int) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ limit := q.Limit
+ if limit <= 0 {
+ limit = 100
+ }
+ if limit > 1000 {
+ limit = 1000
+ }
+
+ results := make([]models.Receipt, 0, len(s.receipts))
+ for _, receipt := range s.receipts {
+ if q.Decision != "" && !strings.EqualFold(string(receipt.Decision), q.Decision) {
+ continue
+ }
+ if q.ActionType != "" && !strings.EqualFold(receipt.Action.Type, q.ActionType) {
+ continue
+ }
+ if q.After != nil && !receipt.Timestamp.After(*q.After) {
+ continue
+ }
+ if q.Before != nil && !receipt.Timestamp.Before(*q.Before) {
+ continue
+ }
+ results = append(results, receipt)
+ }
+
+ sort.Slice(results, func(i, j int) bool {
+ if results[i].Timestamp.Equal(results[j].Timestamp) {
+ return results[i].ReceiptID < results[j].ReceiptID
+ }
+ return results[i].Timestamp.After(results[j].Timestamp)
+ })
+
+ total := len(results)
+ if len(results) > limit {
+ results = results[:limit]
+ }
+ return results, total
+}
+
func (s *ReceiptStore) loadExistingReceipts() error {
f, err := os.OpenFile(s.filePath, os.O_RDONLY, defaultReceiptFileMode) // #nosec G304 -- receipt path is validated from startup configuration.
if err != nil {
+ if os.IsNotExist(err) {
+ return nil
+ }
return err
}
defer f.Close()
@@ -157,6 +416,8 @@ func (s *ReceiptStore) loadExistingReceipts() error {
scanner.Buffer(make([]byte, 64*1024), maxReceiptLineBytes)
badLines := 0
+ s.mu.Lock()
+ defer s.mu.Unlock()
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
diff --git a/internal/receipts/store_test.go b/internal/receipts/store_test.go
index 4fc0dda..97765c9 100644
--- a/internal/receipts/store_test.go
+++ b/internal/receipts/store_test.go
@@ -2,9 +2,11 @@ package receipts
import (
"encoding/json"
+ "errors"
"os"
"strings"
"testing"
+ "time"
"github.com/datafog/datafog-api/internal/models"
"github.com/datafog/datafog-api/internal/policy"
@@ -16,6 +18,9 @@ func TestReceiptStoreSaveAndGet(t *testing.T) {
if err != nil {
t.Fatalf("new store failed: %v", err)
}
+ t.Cleanup(func() {
+ _ = store.Close()
+ })
req := models.DecideRequest{RequestID: "r1", Action: models.ActionMeta{Type: "file.read", Resource: "x"}}
result := policy.DecisionResult{Decision: models.DecisionAllow, MatchedRules: []string{"allow-1"}}
@@ -63,6 +68,9 @@ func TestReceiptStoreLoadsExistingReceipts(t *testing.T) {
if err != nil {
t.Fatalf("new store failed: %v", err)
}
+ t.Cleanup(func() {
+ _ = store.Close()
+ })
got, ok := store.Get("receipt-seeded")
if !ok {
t.Fatalf("expected to load existing receipt")
@@ -92,6 +100,9 @@ func TestReceiptStoreSkipsCorruptReceiptLines(t *testing.T) {
if err != nil {
t.Fatalf("expected corrupt+good receipts to load, got %v", err)
}
+ t.Cleanup(func() {
+ _ = store.Close()
+ })
if got, ok := store.Get("receipt-good"); !ok {
t.Fatalf("expected good receipt loaded")
} else if got.ReceiptID != "receipt-good" {
@@ -141,6 +152,9 @@ func TestReceiptStoreLoadsLargeReceiptLine(t *testing.T) {
if err != nil {
t.Fatalf("new store failed: %v", err)
}
+ t.Cleanup(func() {
+ _ = store.Close()
+ })
got, ok := store.Get("receipt-large")
if !ok {
t.Fatalf("expected to load large receipt")
@@ -149,3 +163,137 @@ func TestReceiptStoreLoadsLargeReceiptLine(t *testing.T) {
t.Fatalf("expected loaded receipt id receipt-large, got %q", got.ReceiptID)
}
}
+
+func TestReceiptStoreSaveQueueSaturationReturnsBackpressure(t *testing.T) {
+ path := t.TempDir() + "/receipts.jsonl"
+
+ store, err := NewReceiptStore(
+ path,
+ MaxWriteQueueSize(0),
+ WriteQueueTimeout(5*time.Millisecond),
+ WriteDelay(20*time.Millisecond),
+ )
+ if err != nil {
+ t.Fatalf("new store failed: %v", err)
+ }
+ t.Cleanup(func() {
+ _ = store.Close()
+ })
+
+ if _, err := store.Save(models.Receipt{ReceiptID: "full-1"}); err != nil {
+ t.Fatalf("expected first save to enqueue, got %v", err)
+ }
+ if _, err := store.Save(models.Receipt{ReceiptID: "full-2"}); !errors.Is(err, errReceiptWriteQueueFull) {
+ t.Fatalf("expected queue saturation error, got %v", err)
+ }
+ if got, ok := store.Get("full-1"); !ok || got.ReceiptID != "full-1" {
+ t.Fatalf("expected first receipt to remain in-memory, got ok=%v id=%q", ok, got.ReceiptID)
+ }
+ if got, ok := store.Get("full-2"); !ok || got.ReceiptID != "full-2" {
+ t.Fatalf("expected second receipt to remain in-memory despite write backpressure, got ok=%v id=%q", ok, got.ReceiptID)
+ }
+}
+
+func TestReceiptStoreListFiltersAndLimit(t *testing.T) {
+ path := t.TempDir() + "/receipts.jsonl"
+ store, err := NewReceiptStore(path)
+ if err != nil {
+ t.Fatalf("new store failed: %v", err)
+ }
+ t.Cleanup(func() {
+ _ = store.Close()
+ })
+
+ base := time.Date(2026, 1, 2, 3, 4, 5, 0, time.UTC)
+ if _, err := store.Save(models.Receipt{ReceiptID: "a-old", Decision: models.DecisionAllow, Action: models.ActionMeta{Type: "file.write"}, Timestamp: base.Add(-2 * time.Hour)}); err != nil {
+ t.Fatalf("save failed: %v", err)
+ }
+ if _, err := store.Save(models.Receipt{ReceiptID: "b-middle", Decision: models.DecisionAllow, Action: models.ActionMeta{Type: "shell.exec"}, Timestamp: base.Add(-time.Minute)}); err != nil {
+ t.Fatalf("save failed: %v", err)
+ }
+ if _, err := store.Save(models.Receipt{ReceiptID: "c-new", Decision: models.DecisionDeny, Action: models.ActionMeta{Type: "file.write"}, Timestamp: base}); err != nil {
+ t.Fatalf("save failed: %v", err)
+ }
+
+ entries, total := store.List(ListQuery{Limit: 2, Decision: "allow"})
+ if total != 2 {
+ t.Fatalf("expected 2 total allow matches, got %d", total)
+ }
+ if len(entries) != 2 {
+ t.Fatalf("expected 2 returned entries with default limit 2, got %d", len(entries))
+ }
+ if entries[0].ReceiptID != "b-middle" || entries[1].ReceiptID != "a-old" {
+ t.Fatalf("expected ordered newest-to-oldest allow receipts, got %s, %s", entries[0].ReceiptID, entries[1].ReceiptID)
+ }
+
+ limitedEntries, limitedTotal := store.List(ListQuery{Decision: "allow", Limit: 1})
+ if limitedTotal != 2 {
+ t.Fatalf("expected 2 total allow matches for limited query, got %d", limitedTotal)
+ }
+ if len(limitedEntries) != 1 {
+ t.Fatalf("expected 1 returned entry with limit 1, got %d", len(limitedEntries))
+ }
+ if limitedEntries[0].ReceiptID != "b-middle" {
+ t.Fatalf("expected newest allow receipt first, got %s", limitedEntries[0].ReceiptID)
+ }
+
+ before := base
+ actionEntries, total := store.List(ListQuery{ActionType: "file.write", Before: &before, Limit: 10})
+ if total != 1 {
+ t.Fatalf("expected 1 file.write receipt before %s, got %d", before, total)
+ }
+ if len(actionEntries) != 1 {
+ t.Fatalf("expected 1 returned entry, got %d", len(actionEntries))
+ }
+}
+
+func TestReceiptStoreCloseDrainsQueuedWrites(t *testing.T) {
+ path := t.TempDir() + "/receipts.jsonl"
+ store, err := NewReceiptStore(path)
+ if err != nil {
+ t.Fatalf("new store failed: %v", err)
+ }
+
+ receipts := []models.Receipt{
+ {ReceiptID: "flush-1", Decision: models.DecisionAllow},
+ {ReceiptID: "flush-2", Decision: models.DecisionAllow},
+ {ReceiptID: "flush-3", Decision: models.DecisionAllow},
+ }
+
+ for i, receipt := range receipts {
+ if _, err := store.Save(receipt); err != nil {
+ t.Fatalf("save %d failed: %v", i, err)
+ }
+ }
+ if err := store.Close(); err != nil {
+ t.Fatalf("close failed: %v", err)
+ }
+
+ reopened, err := NewReceiptStore(path)
+ if err != nil {
+ t.Fatalf("reopen after close failed: %v", err)
+ }
+ t.Cleanup(func() {
+ _ = reopened.Close()
+ })
+
+ for _, id := range []string{"flush-1", "flush-2", "flush-3"} {
+ if _, ok := reopened.Get(id); !ok {
+ t.Fatalf("expected receipt %q to be flushed to disk and reloaded", id)
+ }
+ }
+}
+
+func TestReceiptStoreSaveAfterCloseReturnsError(t *testing.T) {
+ path := t.TempDir() + "/receipts.jsonl"
+ store, err := NewReceiptStore(path)
+ if err != nil {
+ t.Fatalf("new store failed: %v", err)
+ }
+ if err := store.Close(); err != nil {
+ t.Fatalf("close failed: %v", err)
+ }
+ if _, err := store.Save(models.Receipt{ReceiptID: "closed"}); !errors.Is(err, errStoreClosed) {
+ t.Fatalf("expected closed store error, got %v", err)
+ }
+}
diff --git a/internal/scan/detector.go b/internal/scan/detector.go
index 93dfc3c..213764b 100644
--- a/internal/scan/detector.go
+++ b/internal/scan/detector.go
@@ -1,7 +1,11 @@
package scan
import (
+ "context"
"regexp"
+ "runtime"
+
+ "golang.org/x/sync/errgroup"
"github.com/datafog/datafog-api/internal/models"
)
@@ -77,43 +81,93 @@ var defaultScanEntityTypes = []string{
func ScanText(text string, entityFilter []string) []models.ScanFinding {
requested := requestedEntitySet(entityFilter)
+ if text == "" {
+ return nil
+ }
- findings := make([]models.ScanFinding, 0)
-
- // Phase 1: Regex engine (fast, always available)
+ targets := make([]string, 0, len(defaultScanEntityTypes))
for _, entityType := range defaultScanEntityTypes {
if shouldRunEntityType(requested, entityType) {
- pattern := DefaultEntityPatterns[entityType]
- idxs := pattern.Re.FindAllStringIndex(text, -1)
- for _, idx := range idxs {
- if len(idx) != 2 || idx[0] < 0 || idx[1] < idx[0] {
- continue
- }
- value := text[idx[0]:idx[1]]
- if pattern.Validate != nil && !pattern.Validate(value) {
- continue
- }
- findings = append(findings, models.ScanFinding{
- EntityType: entityType,
- Value: value,
- Start: idx[0],
- End: idx[1],
- Confidence: DefaultEntityConfidences[entityType],
- })
+ targets = append(targets, entityType)
+ }
+ }
+
+ chunked := make([][]models.ScanFinding, len(targets))
+ if len(targets) > 0 {
+ g, _ := errgroup.WithContext(context.Background())
+ limit := scanWorkers()
+ if limit > 1 {
+ g.SetLimit(limit)
+ }
+ for i, entityType := range targets {
+ i := i
+ entityType := entityType
+ pattern, ok := DefaultEntityPatterns[entityType]
+ if !ok || pattern.Re == nil {
+ continue
}
+ g.Go(func() error {
+ local := findMatchesForPattern(text, entityType, pattern)
+ chunked[i] = local
+ return nil
+ })
}
+ _ = g.Wait()
+ }
+
+ findings := make([]models.ScanFinding, 0)
+ for _, chunk := range chunked {
+ if len(chunk) == 0 {
+ continue
+ }
+ findings = append(findings, chunk...)
}
if !shouldRunNERForFilter(requested) {
return findings
}
-
- // Phase 2: NER engine (heuristic, when enabled)
findings = append(findings, scanNERWithFilter(text, requested)...)
-
return findings
}
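+
+// findMatchesForPattern runs one compiled entity pattern over the text and
+// returns validated findings with byte offsets and the default confidence
+// for that entity type.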
+func findMatchesForPattern(text string, entityType string, pattern EntityPattern) []models.ScanFinding {
+ if text == "" || pattern.Re == nil {
+ return nil
+ }
+
+ idxs := pattern.Re.FindAllStringIndex(text, -1)
+ if len(idxs) == 0 {
+ return nil
+ }
+
+ results := make([]models.ScanFinding, 0, len(idxs))
+ for _, idx := range idxs {
+ if len(idx) != 2 || idx[0] < 0 || idx[1] < idx[0] {
+ continue
+ }
+ value := text[idx[0]:idx[1]]
+ if pattern.Validate != nil && !pattern.Validate(value) {
+ continue
+ }
+ results = append(results, models.ScanFinding{
+ EntityType: entityType,
+ Value: value,
+ Start: idx[0],
+ End: idx[1],
+ Confidence: DefaultEntityConfidences[entityType],
+ })
+ }
+ return results
+}
+
+func scanWorkers() int {
+ workers := runtime.GOMAXPROCS(0)
+ if workers < 1 {
+ return 1
+ }
+ return workers
+}
+
// luhnValid implements the Luhn algorithm to validate credit card numbers.
// It strips spaces and dashes before checking.
func luhnValid(s string) bool {
diff --git a/internal/scan/ner.go b/internal/scan/ner.go
index 673a4a7..3d43ab4 100644
--- a/internal/scan/ner.go
+++ b/internal/scan/ner.go
@@ -1,9 +1,14 @@
package scan
import (
+ "context"
+ "runtime"
"strings"
+ "sync"
"unicode"
+ "golang.org/x/sync/semaphore"
+
"github.com/datafog/datafog-api/internal/models"
)
@@ -116,9 +121,67 @@ var wellKnownLocations = map[string]bool{
"japan": true, "china": true, "india": true, "brazil": true,
}
-// NEREnabled controls whether the NER engine runs. Can be toggled via env var.
+// NEREnabled controls whether the heuristic NER engine runs.
+// It can be toggled by caller-level configuration (for example via an environment variable).
var NEREnabled = true
+var (
+ nerOnce sync.Once
+ nerLimiterMu sync.Mutex
+ nerLimiter *semaphore.Weighted
+)
+
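+// ensureNERState lazily normalizes the trigger/suffix token maps and sizes
+// the NER concurrency limiter to GOMAXPROCS; it runs once per process.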
+func ensureNERState() {
+ nerOnce.Do(func() {
+ personTriggers = normalizeTokenMap(personTriggers)
+ orgSuffixes = normalizeTokenMap(orgSuffixes)
+ locationTriggers = normalizeTokenMap(locationTriggers)
+ wellKnownLocations = normalizeTokenMap(wellKnownLocations)
+ commonFirstNames = normalizeTokenMap(commonFirstNames)
+
+ workers := runtime.GOMAXPROCS(0)
+ if workers < 1 {
+ workers = 1
+ }
+ numWorkers := int64(workers)
+ nerLimiterMu.Lock()
+ nerLimiter = semaphore.NewWeighted(numWorkers)
+ nerLimiterMu.Unlock()
+ })
+}
+
+func normalizeTokenMap(src map[string]bool) map[string]bool {
+ normalized := make(map[string]bool, len(src))
+ for token := range src {
+ normalized[strings.ToLower(strings.TrimSpace(token))] = true
+ }
+ return normalized
+}
+
+func acquireNERSlot() error {
+ ensureNERState()
+ nerLimiterMu.Lock()
+ lim := nerLimiter
+ nerLimiterMu.Unlock()
+ if lim == nil {
+ return nil
+ }
+ return lim.Acquire(context.Background(), 1)
+}
+
+func releaseNERSlot() {
+ nerLimiterMu.Lock()
+ lim := nerLimiter
+ nerLimiterMu.Unlock()
+ if lim != nil {
+ lim.Release(1)
+ }
+}
+
// ScanNER runs the heuristic NER engine over text and returns findings
// for person, organization, and location entities.
func ScanNER(text string, entityFilter []string) []models.ScanFinding {
@@ -127,6 +190,9 @@ func ScanNER(text string, entityFilter []string) []models.ScanFinding {
}
requested := requestedEntitySet(entityFilter)
+ if len(requested) > 0 && !shouldRunNERForFilter(requested) {
+ return nil
+ }
return scanNERWithFilter(text, requested)
}
@@ -134,8 +200,13 @@ func scanNERWithFilter(text string, requested map[string]struct{}) []models.Scan
if !NEREnabled {
return nil
}
+ ensureNERState()
+ if err := acquireNERSlot(); err != nil {
+ return nil
+ }
+ defer releaseNERSlot()
- if len(requested) > 0 && !shouldRunNERForFilter(requested) {
+ if !hasRequestedNERFilter(requested) {
return nil
}
@@ -198,6 +269,10 @@ func scanNERWithFilter(text string, requested map[string]struct{}) []models.Scan
orgSpan = span[1:]
}
}
+ if len(orgSpan) == 0 {
+ i += len(span) - 1
+ continue
+ }
orgText := buildSpanText(text, orgSpan)
findings = append(findings, models.ScanFinding{
EntityType: "organization",
@@ -268,6 +343,22 @@ func scanNERWithFilter(text string, requested map[string]struct{}) []models.Scan
return findings
}
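+
+// hasRequestedNERFilter reports whether the requested entity set is empty
+// (run everything) or includes at least one NER entity type.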
+func hasRequestedNERFilter(requested map[string]struct{}) bool {
+ if len(requested) == 0 {
+ return true
+ }
+ if _, ok := requested["person"]; ok {
+ return true
+ }
+ if _, ok := requested["organization"]; ok {
+ return true
+ }
+ if _, ok := requested["location"]; ok {
+ return true
+ }
+ return false
+}
+
type tokenInfo struct {
text string
start int
diff --git a/internal/server/admin.go b/internal/server/admin.go
new file mode 100644
index 0000000..6cb365a
--- /dev/null
+++ b/internal/server/admin.go
@@ -0,0 +1,35 @@
+package server
+
+import (
+ "net/http"
+ "os"
+)
+
+// AdminHandler serves a lightweight read-only admin dashboard page.
+type AdminHandler struct {
+ adminHTML []byte
+}
+
+// NewAdminHandler creates an admin dashboard handler backed by the given HTML asset.
+func NewAdminHandler(htmlPath string) (*AdminHandler, error) {
+ if htmlPath == "" {
+ htmlPath = "docs/admin.html"
+ }
+
+ html, err := os.ReadFile(htmlPath)
+ if err != nil {
+ return nil, err
+ }
+
+ return &AdminHandler{adminHTML: html}, nil
+}
+
+// Register adds the admin endpoint to the given mux.
+func (d *AdminHandler) Register(mux *http.ServeMux) {
+ mux.HandleFunc("/admin", d.handleAdminPage)
+}
+
+func (d *AdminHandler) handleAdminPage(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ _, _ = w.Write(d.adminHTML)
+}
diff --git a/internal/server/metrics.go b/internal/server/metrics.go
new file mode 100644
index 0000000..94fd750
--- /dev/null
+++ b/internal/server/metrics.go
@@ -0,0 +1,166 @@
+package server
+
+import (
+ "fmt"
+ "hash/fnv"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+const metricShardCount = 16
+
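+// metricCounterMap is a sharded counter map: keys are hashed (FNV-1a) to one
+// of metricShardCount shards to reduce lock contention on hot request paths.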
+type metricCounterMap struct {
+ shards [metricShardCount]metricCounterShard
+}
+
+type metricCounterShard struct {
+ mu sync.RWMutex
+ counters map[string]*atomic.Int64
+}
+
+func newMetricCounterMap() *metricCounterMap {
+ m := &metricCounterMap{}
+ for i := 0; i < metricShardCount; i++ {
+ m.shards[i].counters = make(map[string]*atomic.Int64)
+ }
+ return m
+}
+
+func (m *metricCounterMap) add(key string, delta int64) {
+ sh := m.shard(key)
+ counter := sh.loadCounter(key)
+ if counter == nil {
+ counter = sh.initCounter(key)
+ }
+ counter.Add(delta)
+}
+
+func (m *metricCounterMap) load(key string) int64 {
+ sh := m.shard(key)
+ sh.mu.RLock()
+ counter := sh.counters[key]
+ sh.mu.RUnlock()
+ if counter == nil {
+ return 0
+ }
+ return counter.Load()
+}
+
+func (m *metricCounterMap) snapshot() map[string]int64 {
+ out := make(map[string]int64)
+ for i := 0; i < metricShardCount; i++ {
+ sh := &m.shards[i]
+ sh.mu.RLock()
+ for key, counter := range sh.counters {
+ out[key] = counter.Load()
+ }
+ sh.mu.RUnlock()
+ }
+ return out
+}
+
+func (m *metricCounterMap) shard(key string) *metricCounterShard {
+ h := fnv.New32a()
+ _, _ = h.Write([]byte(key))
+ return &m.shards[h.Sum32()%metricShardCount]
+}
+
+func (s *metricCounterShard) loadCounter(key string) *atomic.Int64 {
+ s.mu.RLock()
+ counter := s.counters[key]
+ s.mu.RUnlock()
+ return counter
+}
+
+func (s *metricCounterShard) initCounter(key string) *atomic.Int64 {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if existing := s.counters[key]; existing != nil {
+ return existing
+ }
+ counter := &atomic.Int64{}
+ s.counters[key] = counter
+ return counter
+}
+
+func (m *metricCounterMap) len() int {
+ count := 0
+ for i := 0; i < metricShardCount; i++ {
+ sh := &m.shards[i]
+ sh.mu.RLock()
+ count += len(sh.counters)
+ sh.mu.RUnlock()
+ }
+ return count
+}
+
+func intKey(value int) string {
+ return fmt.Sprintf("%d", value)
+}
+
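+// requestMetrics aggregates request totals, error counts and latency sums,
+// plus per-method, per-path and per-status counters; snapshot derives
+// average latency in milliseconds overall and per path.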
+type requestMetrics struct {
+ totalCount atomic.Int64
+ errorCount atomic.Int64
+ totalLatencyNs atomic.Int64
+ pathHits *metricCounterMap
+ methodHits *metricCounterMap
+ statusHits *metricCounterMap
+ pathLatencyNs *metricCounterMap
+ pathLatencyHits *metricCounterMap
+}
+
+func newRequestMetrics() *requestMetrics {
+ return &requestMetrics{
+ pathHits: newMetricCounterMap(),
+ methodHits: newMetricCounterMap(),
+ statusHits: newMetricCounterMap(),
+ pathLatencyNs: newMetricCounterMap(),
+ pathLatencyHits: newMetricCounterMap(),
+ }
+}
+
+func (m *requestMetrics) record(method string, route string, status int, latencyNs int64) {
+ m.totalCount.Add(1)
+ m.totalLatencyNs.Add(latencyNs)
+ if status >= 400 {
+ m.errorCount.Add(1)
+ }
+ m.methodHits.add(method, 1)
+ m.pathHits.add(route, 1)
+ m.statusHits.add(intKey(status), 1)
+ m.pathLatencyNs.add(route, latencyNs)
+ m.pathLatencyHits.add(route, 1)
+}
+
+func (m *requestMetrics) snapshot() (total int64, errorCount int64, avgLatencyMs float64, byMethod, byPath, byStatus map[string]int64, byPathLatency map[string]float64) {
+ count := m.totalCount.Load()
+ errorCount = m.errorCount.Load()
+ latencyNs := m.totalLatencyNs.Load()
+ if count > 0 {
+ avgLatencyMs = float64(latencyNs) / float64(count) / float64(1e6)
+ if avgLatencyMs <= 0 {
+ avgLatencyMs = float64(time.Nanosecond) / float64(time.Millisecond)
+ }
+ }
+
+ byMethod = m.methodHits.snapshot()
+ byPath = m.pathHits.snapshot()
+ byStatus = m.statusHits.snapshot()
+
+ pathNs := m.pathLatencyNs.snapshot()
+ pathHits := m.pathLatencyHits.snapshot()
+ byPathLatency = make(map[string]float64)
+ for path, ns := range pathNs {
+ hits := pathHits[path]
+ if hits == 0 {
+ continue
+ }
+ v := float64(ns) / float64(hits) / float64(1e6)
+ if v <= 0 {
+ v = float64(time.Nanosecond) / float64(time.Millisecond)
+ }
+ byPathLatency[path] = v
+ }
+ return count, errorCount, avgLatencyMs, byMethod, byPath, byStatus, byPathLatency
+}
diff --git a/internal/server/server.go b/internal/server/server.go
index fb1e64c..f34c246 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -1,6 +1,7 @@
package server
import (
+ "compress/gzip"
"context"
"crypto/rand"
"crypto/sha256"
@@ -25,30 +26,25 @@ import (
"github.com/datafog/datafog-api/internal/scan"
"github.com/datafog/datafog-api/internal/shim"
"github.com/datafog/datafog-api/internal/transform"
+ "github.com/datafog/datafog-api/internal/types/ttlcache"
)
type Server struct {
- policy models.Policy
- store *receipts.ReceiptStore
- eventReader shim.EventReader
- apiToken string
- rateLimiter *tokenBucket
- startedAt time.Time
- logger *log.Logger
- mu sync.Mutex
- statsMu sync.Mutex
- decisions map[string]idempotentDecision
- scans map[string]idempotentCachedResponse
- transforms map[string]idempotentCachedResponse
- anonymizes map[string]idempotentCachedResponse
- totalCount int64
- errorCount int64
- statusHits map[int]int64
- pathHits map[string]int64
- methodHits map[string]int64
- totalLatencyNs int64
- pathLatencyNs map[string]int64
- pathLatencyCounts map[string]int64
+ policy models.Policy
+ policyIndex *policy.PolicyIndex
+ store *receipts.ReceiptStore
+ eventReader shim.EventReader
+ eventSink shim.DecisionEventSink
+ apiToken string
+ rateLimiter *tokenBucket
+ startedAt time.Time
+ logger *log.Logger
+
+ decisions *ttlcache.Cache[idempotentDecisionResponse]
+ scans *ttlcache.Cache[idempotentCachedResponse]
+ transforms *ttlcache.Cache[idempotentCachedResponse]
+ anonymizes *ttlcache.Cache[idempotentCachedResponse]
+ metrics *requestMetrics
}
type requestIDContextKey struct{}
@@ -70,11 +66,67 @@ func (w *responseStatusWriter) Write(body []byte) (int, error) {
return w.ResponseWriter.Write(body)
}
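+
+// gzipResponseWriter wraps a ResponseWriter and gzip-compresses the body when
+// the client accepts gzip. The gzip.Writer is created lazily and must be
+// closed by the middleware to flush the trailer.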
+type gzipResponseWriter struct {
+ ResponseWriter http.ResponseWriter
+ writer *gzip.Writer
+ wroteHeader bool
+ status int
+}
+
+func (w *gzipResponseWriter) Header() http.Header {
+ return w.ResponseWriter.Header()
+}
+
+func (w *gzipResponseWriter) WriteHeader(statusCode int) {
+ if w.wroteHeader {
+ return
+ }
+ w.wroteHeader = true
+ w.status = statusCode
+
+ w.ResponseWriter.Header().Del("Content-Length")
+ w.ResponseWriter.Header().Set("Content-Encoding", "gzip")
+ w.ResponseWriter.Header().Set("Vary", "Accept-Encoding")
+ if w.writer == nil {
+ writer := gzip.NewWriter(w.ResponseWriter)
+ w.writer = writer
+ }
+ w.ResponseWriter.WriteHeader(statusCode)
+}
+
+func (w *gzipResponseWriter) Write(body []byte) (int, error) {
+ if w.status == 0 {
+ w.WriteHeader(http.StatusOK)
+ }
+ if w.writer == nil {
+ w.writer = gzip.NewWriter(w.ResponseWriter)
+ }
+ return w.writer.Write(body)
+}
+
+func (w *gzipResponseWriter) Close() error {
+ if w.writer == nil {
+ return nil
+ }
+ return w.writer.Close()
+}
+
+func shouldCompressResponse(r *http.Request) bool {
+ enc := strings.ToLower(r.Header.Get("Accept-Encoding"))
+ if enc == "" {
+ return false
+ }
+ return strings.Contains(enc, "gzip")
+}
+
const (
maxRequestBodyBytes int64 = 1024 * 1024 // 1 MiB
+
+ defaultIdempotencyCacheSize = 4096
+ defaultIdempotencyTTL = 30 * time.Minute
)
-type idempotentDecision struct {
+type idempotentDecisionResponse struct {
requestHash string
response models.DecideResponse
}
@@ -103,31 +155,77 @@ func New(policyData models.Policy, store *receipts.ReceiptStore, logger *log.Log
}
policyData = policy.NormalizeForEvaluation(policyData)
return &Server{
- policy: policyData,
- store: store,
- apiToken: apiToken,
- rateLimiter: newTokenBucket(rateLimitRPS),
- startedAt: time.Now().UTC(),
- logger: logger,
- decisions: map[string]idempotentDecision{},
- scans: map[string]idempotentCachedResponse{},
- transforms: map[string]idempotentCachedResponse{},
- anonymizes: map[string]idempotentCachedResponse{},
- statusHits: map[int]int64{},
- pathHits: map[string]int64{},
- methodHits: map[string]int64{},
- totalLatencyNs: 0,
- pathLatencyNs: map[string]int64{},
- pathLatencyCounts: map[string]int64{},
+ policy: policyData,
+ policyIndex: policy.BuildPolicyIndex(policyData),
+ store: store,
+ apiToken: apiToken,
+ rateLimiter: newTokenBucket(rateLimitRPS),
+ startedAt: time.Now().UTC(),
+ logger: logger,
+ decisions: ttlcache.New[idempotentDecisionResponse](defaultIdempotencyCacheSize, defaultIdempotencyTTL),
+ scans: ttlcache.New[idempotentCachedResponse](defaultIdempotencyCacheSize, defaultIdempotencyTTL),
+ transforms: ttlcache.New[idempotentCachedResponse](defaultIdempotencyCacheSize, defaultIdempotencyTTL),
+ anonymizes: ttlcache.New[idempotentCachedResponse](defaultIdempotencyCacheSize, defaultIdempotencyTTL),
+ metrics: newRequestMetrics(),
}
}
func (s *Server) SetEventReader(reader shim.EventReader) {
s.eventReader = reader
+ if sink, ok := reader.(shim.DecisionEventSink); ok {
+ s.eventSink = sink
+ }
+}
+
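+// Shutdown releases server-owned resources: idempotency caches, the decision
+// event sink (when it implements Close), and the receipt store. It returns
+// early if ctx is cancelled.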
+func (s *Server) Shutdown(ctx context.Context) error {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ done := make(chan error, 1)
+ go func() {
+ if s.decisions != nil {
+ s.decisions.Close()
+ }
+ if s.scans != nil {
+ s.scans.Close()
+ }
+ if s.transforms != nil {
+ s.transforms.Close()
+ }
+ if s.anonymizes != nil {
+ s.anonymizes.Close()
+ }
+ if s.eventSink != nil {
+ if closer, ok := s.eventSink.(interface{ Close() error }); ok {
+ _ = closer.Close()
+ }
+ }
+ if s.store != nil {
+ _ = s.store.Close()
+ }
+ done <- nil
+ }()
+
+ select {
+ case err := <-done:
+ return err
+ case <-ctx.Done():
+ return ctx.Err()
+ }
}
// HandlerWithDemo returns the HTTP handler with optional demo endpoints registered.
func (s *Server) HandlerWithDemo(demo *DemoHandler) http.Handler {
+ return s.HandlerWithDemoAndAdmin(demo, nil)
+}
+
+// HandlerWithAdmin returns the HTTP handler with optional admin endpoints registered.
+func (s *Server) HandlerWithAdmin(admin *AdminHandler) http.Handler {
+ return s.HandlerWithDemoAndAdmin(nil, admin)
+}
+
+// HandlerWithDemoAndAdmin returns the HTTP handler with optional demo and admin endpoints.
+func (s *Server) HandlerWithDemoAndAdmin(demo *DemoHandler, admin *AdminHandler) http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/health", s.handleHealth)
mux.HandleFunc("/v1/policy/version", s.handlePolicyVersion)
@@ -135,17 +233,21 @@ func (s *Server) HandlerWithDemo(demo *DemoHandler) http.Handler {
mux.HandleFunc("/v1/decide", s.handleDecide)
mux.HandleFunc("/v1/transform", s.handleTransform)
mux.HandleFunc("/v1/anonymize", s.handleAnonymize)
+ mux.HandleFunc("/v1/receipts", s.handleReceipts)
mux.HandleFunc("/v1/receipts/", s.handleReceipt)
mux.HandleFunc("/v1/events", s.handleEvents)
mux.HandleFunc("/metrics", s.handleMetrics)
if demo != nil {
demo.Register(mux)
}
+ if admin != nil {
+ admin.Register(mux)
+ }
return s.wrapMiddleware(mux)
}
func (s *Server) Handler() http.Handler {
- return s.HandlerWithDemo(nil)
+ return s.HandlerWithDemoAndAdmin(nil, nil)
}
func (s *Server) wrapMiddleware(mux *http.ServeMux) http.Handler {
@@ -169,25 +271,34 @@ func (s *Server) wrapMiddleware(mux *http.ServeMux) http.Handler {
r = r.WithContext(context.WithValue(r.Context(), requestIDContextKey{}, reqID))
w.Header().Set("X-Request-ID", reqID)
- responseWriter := &responseStatusWriter{ResponseWriter: w}
+ statusWriter := &responseStatusWriter{ResponseWriter: w}
+ responseWriter := http.ResponseWriter(statusWriter)
+ if shouldCompressResponse(r) {
+ responseWriter = &gzipResponseWriter{ResponseWriter: statusWriter}
+ }
startedAt := time.Now()
handler, pattern := mux.Handler(r)
defer func() {
latency := time.Since(startedAt)
if rec := recover(); rec != nil {
- responseWriter.status = http.StatusInternalServerError
+ statusWriter.status = http.StatusInternalServerError
s.logger.Printf("request panic request_id=%s method=%s path=%s err=%v", reqID, r.Method, r.URL.Path, rec)
- s.respondError(responseWriter, http.StatusInternalServerError, models.APIError{Code: "internal_error", Message: "internal server error", RequestID: reqID})
+ s.respondError(statusWriter, http.StatusInternalServerError, models.APIError{Code: "internal_error", Message: "internal server error", RequestID: reqID})
}
- if responseWriter.status == 0 {
- responseWriter.status = http.StatusOK
+ if statusWriter.status == 0 {
+ statusWriter.status = http.StatusOK
}
if pattern == "" {
- s.recordRequestMetrics(r.Method, "/_not_found", responseWriter.status, latency)
+ s.recordRequestMetrics(r.Method, "/_not_found", statusWriter.status, latency)
} else {
- s.recordRequestMetrics(r.Method, canonicalizedRoute(pattern, r.URL.Path), responseWriter.status, latency)
+ s.recordRequestMetrics(r.Method, canonicalizedRoute(pattern, r.URL.Path), statusWriter.status, latency)
+ }
+ s.logger.Printf("request complete request_id=%s method=%s path=%s status=%d latency_ms=%d", reqID, r.Method, r.URL.Path, statusWriter.status, latency.Milliseconds())
+ if gzipWriter, ok := responseWriter.(*gzipResponseWriter); ok {
+ if err := gzipWriter.Close(); err != nil {
+ s.logger.Printf("gzip close failed request_id=%s: %v", reqID, err)
+ }
}
- s.logger.Printf("request complete request_id=%s method=%s path=%s status=%d latency_ms=%d", reqID, r.Method, r.URL.Path, responseWriter.status, latency.Milliseconds())
}()
if !s.authorized(r) {
@@ -286,20 +397,7 @@ func canonicalizedRoute(pattern string, path string) string {
}
func (s *Server) recordRequestMetrics(method string, route string, status int, latency time.Duration) {
- s.statsMu.Lock()
- defer s.statsMu.Unlock()
- s.totalCount++
- s.methodHits[method]++
- s.pathHits[route]++
- s.statusHits[status]++
- if status >= 400 {
- s.errorCount++
- }
-
- latencyNs := latency.Nanoseconds()
- s.totalLatencyNs += latencyNs
- s.pathLatencyNs[route] += latencyNs
- s.pathLatencyCounts[route]++
+ s.metrics.record(method, route, status, latency.Nanoseconds())
}
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
@@ -353,15 +451,12 @@ func (s *Server) handleScan(w http.ResponseWriter, r *http.Request) {
s.respondError(w, http.StatusBadRequest, models.APIError{Code: "invalid_request", Message: "unable to hash request payload", Details: err.Error(), RequestID: requestID(r)})
return
}
- s.mu.Lock()
- existing, ok := s.scans[req.IdempotencyKey]
- s.mu.Unlock()
- if ok {
- if existing.requestHash != reqHash {
+ if cached, ok := s.scans.Get(req.IdempotencyKey); ok {
+ if cached.requestHash != reqHash {
s.respondError(w, http.StatusConflict, models.APIError{Code: "idempotency_conflict", Message: "different request payload for same idempotency_key", RequestID: requestID(r)})
return
}
- s.respondRaw(w, existing.status, existing.body)
+ s.respondRaw(w, cached.status, cached.body)
return
}
}
@@ -381,13 +476,11 @@ func (s *Server) handleScan(w http.ResponseWriter, r *http.Request) {
return
}
hash, _ := hashScanRequest(req)
- s.mu.Lock()
- s.scans[req.IdempotencyKey] = idempotentCachedResponse{
+ s.scans.Set(req.IdempotencyKey, idempotentCachedResponse{
requestHash: hash,
body: body,
status: http.StatusOK,
- }
- s.mu.Unlock()
+ })
s.respondRaw(w, http.StatusOK, body)
return
}
@@ -420,15 +513,12 @@ func (s *Server) handleDecide(w http.ResponseWriter, r *http.Request) {
s.respondError(w, http.StatusBadRequest, models.APIError{Code: "invalid_request", Message: "unable to hash request payload", Details: err.Error(), RequestID: requestID(r)})
return
}
- s.mu.Lock()
- existing, ok := s.decisions[req.IdempotencyKey]
- s.mu.Unlock()
- if ok {
- if existing.requestHash != reqHash {
+ if cached, ok := s.decisions.Get(req.IdempotencyKey); ok {
+ if cached.requestHash != reqHash {
s.respondError(w, http.StatusConflict, models.APIError{Code: "idempotency_conflict", Message: "different request payload for same idempotency_key", RequestID: requestID(r)})
return
}
- s.respond(w, http.StatusOK, existing.response)
+ s.respond(w, http.StatusOK, cached.response)
return
}
}
@@ -437,7 +527,7 @@ func (s *Server) handleDecide(w http.ResponseWriter, r *http.Request) {
if len(findings) == 0 && req.Text != "" {
findings = scan.ScanText(req.Text, nil)
}
- result := policy.EvaluateSorted(s.policy, policy.DecisionContext{Action: req.Action, Findings: findings})
+ result := policy.EvaluateWithIndex(s.policy, s.policyIndex, policy.DecisionContext{Action: req.Action, Findings: findings})
actionHash, err := hashDecideAction(req.Action)
if err != nil {
s.respondError(w, http.StatusInternalServerError, models.APIError{Code: "hash_error", Message: "unable to hash action", Details: err.Error(), RequestID: requestID(r)})
@@ -480,12 +570,10 @@ func (s *Server) handleDecide(w http.ResponseWriter, r *http.Request) {
}
if req.IdempotencyKey != "" {
hash, _ := hashDecideRequest(req)
- s.mu.Lock()
- s.decisions[req.IdempotencyKey] = idempotentDecision{
+ s.decisions.Set(req.IdempotencyKey, idempotentDecisionResponse{
requestHash: hash,
response: res,
- }
- s.mu.Unlock()
+ })
}
s.respond(w, http.StatusOK, res)
}
@@ -540,15 +628,12 @@ func (s *Server) handleTransform(w http.ResponseWriter, r *http.Request) {
s.respondError(w, http.StatusBadRequest, models.APIError{Code: "invalid_request", Message: "unable to hash request payload", Details: err.Error(), RequestID: requestID(r)})
return
}
- s.mu.Lock()
- existing, ok := s.transforms[req.IdempotencyKey]
- s.mu.Unlock()
- if ok {
- if existing.requestHash != reqHash {
+ if cached, ok := s.transforms.Get(req.IdempotencyKey); ok {
+ if cached.requestHash != reqHash {
s.respondError(w, http.StatusConflict, models.APIError{Code: "idempotency_conflict", Message: "different request payload for same idempotency_key", RequestID: requestID(r)})
return
}
- s.respondRaw(w, existing.status, existing.body)
+ s.respondRaw(w, cached.status, cached.body)
return
}
}
@@ -588,13 +673,11 @@ func (s *Server) handleTransform(w http.ResponseWriter, r *http.Request) {
return
}
hash, _ := hashTransformRequest(req)
- s.mu.Lock()
- s.transforms[req.IdempotencyKey] = idempotentCachedResponse{
+ s.transforms.Set(req.IdempotencyKey, idempotentCachedResponse{
requestHash: hash,
body: body,
status: http.StatusOK,
- }
- s.mu.Unlock()
+ })
s.respondRaw(w, http.StatusOK, body)
return
}
@@ -627,15 +710,12 @@ func (s *Server) handleAnonymize(w http.ResponseWriter, r *http.Request) {
s.respondError(w, http.StatusBadRequest, models.APIError{Code: "invalid_request", Message: "unable to hash request payload", Details: err.Error(), RequestID: requestID(r)})
return
}
- s.mu.Lock()
- existing, ok := s.anonymizes[req.IdempotencyKey]
- s.mu.Unlock()
- if ok {
- if existing.requestHash != reqHash {
+ if cached, ok := s.anonymizes.Get(req.IdempotencyKey); ok {
+ if cached.requestHash != reqHash {
s.respondError(w, http.StatusConflict, models.APIError{Code: "idempotency_conflict", Message: "different request payload for same idempotency_key", RequestID: requestID(r)})
return
}
- s.respondRaw(w, existing.status, existing.body)
+ s.respondRaw(w, cached.status, cached.body)
return
}
}
@@ -671,13 +751,11 @@ func (s *Server) handleAnonymize(w http.ResponseWriter, r *http.Request) {
return
}
hash, _ := hashAnonymizeRequest(req)
- s.mu.Lock()
- s.anonymizes[req.IdempotencyKey] = idempotentCachedResponse{
+ s.anonymizes.Set(req.IdempotencyKey, idempotentCachedResponse{
requestHash: hash,
body: body,
status: http.StatusOK,
- }
- s.mu.Unlock()
+ })
s.respondRaw(w, http.StatusOK, body)
return
}
@@ -742,47 +820,51 @@ func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) snapshotMetrics() metricsResponse {
- s.statsMu.Lock()
- defer s.statsMu.Unlock()
- byStatus := map[string]int64{}
- for status, count := range s.statusHits {
- byStatus[strconv.Itoa(status)] = count
+ total, errorCount, avgLatency, byMethod, byPath, byStatus, byPathLatency := s.metrics.snapshot()
+ return metricsResponse{
+ TotalRequests: total,
+ ErrorRequests: errorCount,
+ ByStatus: byStatus,
+ ByPath: byPath,
+ ByMethod: byMethod,
+ ByPathLatency: byPathLatency,
+ AvgLatencyMs: avgLatency,
+ StartedAt: s.startedAt.Format(time.RFC3339),
+ UptimeSeconds: time.Since(s.startedAt).Seconds(),
}
+}
- byPath := map[string]int64{}
- for path, count := range s.pathHits {
- byPath[path] = count
+func (s *Server) handleReceipts(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet {
+ s.respondError(w, http.StatusMethodNotAllowed, models.APIError{Code: "method_not_allowed", Message: "method must be GET", RequestID: requestID(r)})
+ return
}
- byMethod := map[string]int64{}
- for method, count := range s.methodHits {
- byMethod[method] = count
+ q := receipts.ListQuery{Limit: 100}
+ if limit := strings.TrimSpace(r.URL.Query().Get("limit")); limit != "" {
+ if n, err := strconv.Atoi(limit); err == nil && n > 0 && n <= 1000 {
+ q.Limit = n
+ }
}
-
- byPathLatency := map[string]float64{}
- for path, count := range s.pathLatencyCounts {
- if count == 0 {
- continue
+ if after := strings.TrimSpace(r.URL.Query().Get("after")); after != "" {
+ if t, err := time.Parse(time.RFC3339, after); err == nil {
+ q.After = &t
}
- byPathLatency[path] = float64(s.pathLatencyNs[path]) / float64(count) / float64(time.Millisecond)
}
-
- avgLatencyMs := 0.0
- if s.totalCount > 0 {
- avgLatencyMs = float64(s.totalLatencyNs) / float64(s.totalCount) / float64(time.Millisecond)
+ if before := strings.TrimSpace(r.URL.Query().Get("before")); before != "" {
+ if t, err := time.Parse(time.RFC3339, before); err == nil {
+ q.Before = &t
+ }
}
-
- return metricsResponse{
- TotalRequests: s.totalCount,
- ErrorRequests: s.errorCount,
- ByStatus: byStatus,
- ByPath: byPath,
- ByMethod: byMethod,
- ByPathLatency: byPathLatency,
- AvgLatencyMs: avgLatencyMs,
- StartedAt: s.startedAt.Format(time.RFC3339),
- UptimeSeconds: time.Since(s.startedAt).Seconds(),
+ if decision := strings.TrimSpace(r.URL.Query().Get("decision")); decision != "" {
+ q.Decision = decision
}
+ if actionType := strings.TrimSpace(r.URL.Query().Get("action_type")); actionType != "" {
+ q.ActionType = actionType
+ }
+
+ entries, total := s.store.List(q)
+ s.respond(w, http.StatusOK, map[string]interface{}{"receipts": entries, "total": total})
}
func (s *Server) handleReceipt(w http.ResponseWriter, r *http.Request) {
@@ -793,7 +875,7 @@ func (s *Server) handleReceipt(w http.ResponseWriter, r *http.Request) {
id := strings.TrimPrefix(r.URL.Path, "/v1/receipts/")
if id == "" || strings.Contains(id, "/") {
- s.respondError(w, http.StatusNotFound, models.APIError{Code: "not_found", Message: "receipt id missing"})
+ s.respondError(w, http.StatusNotFound, models.APIError{Code: "not_found", Message: "receipt id missing", RequestID: requestID(r)})
return
}
receipt, ok := s.store.Get(id)
diff --git a/internal/server/server_benchmark_test.go b/internal/server/server_benchmark_test.go
index 318359f..25eb97e 100644
--- a/internal/server/server_benchmark_test.go
+++ b/internal/server/server_benchmark_test.go
@@ -54,6 +54,9 @@ func benchmarkServer(b *testing.B) *http.Server {
if err != nil {
b.Fatalf("new store: %v", err)
}
+ b.Cleanup(func() {
+ _ = store.Close()
+ })
h := New(testPolicy(), store, log.New(io.Discard, "", 0), "", 0)
return &http.Server{Handler: h.Handler()}
}
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index 5181c6f..ef6f59b 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -5,6 +5,7 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
+ "os"
"strings"
"testing"
"time"
@@ -41,6 +42,9 @@ func makeServerWithTokenAndRateLimit(t *testing.T, apiToken string, rateLimitRPS
if err != nil {
t.Fatalf("new store: %v", err)
}
+ t.Cleanup(func() {
+ _ = store.Close()
+ })
h := New(testPolicy(), store, nil, apiToken, rateLimitRPS)
return &http.Server{Handler: h.Handler()}
}
@@ -137,6 +141,88 @@ func TestTokenAuth(t *testing.T) {
})
}
+func TestAdminAndReceiptListEndpoints(t *testing.T) {
+ tmp := t.TempDir()
+ storePath := tmp + "/receipts.jsonl"
+ store, err := receipts.NewReceiptStore(storePath)
+ if err != nil {
+ t.Fatalf("new store: %v", err)
+ }
+ t.Cleanup(func() {
+ _ = store.Close()
+ })
+
+ if _, err := store.Save(models.Receipt{ReceiptID: "r1", Decision: models.DecisionAllow, Action: models.ActionMeta{Type: "file.write"}, Timestamp: time.Date(2026, 1, 2, 3, 4, 5, 0, time.UTC)}); err != nil {
+ t.Fatalf("save receipt failed: %v", err)
+ }
+ if _, err := store.Save(models.Receipt{ReceiptID: "r2", Decision: models.DecisionDeny, Action: models.ActionMeta{Type: "shell.exec"}, Timestamp: time.Date(2026, 1, 2, 3, 5, 0, 0, time.UTC)}); err != nil {
+ t.Fatalf("save receipt failed: %v", err)
+ }
+
+ h := New(testPolicy(), store, nil, "", 0)
+ srv := &http.Server{Handler: h.Handler()}
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/receipts", nil)
+ resp := httptest.NewRecorder()
+ srv.Handler.ServeHTTP(resp, req)
+ assertJSONError(t, resp, http.StatusMethodNotAllowed, "method_not_allowed")
+
+ listReq := httptest.NewRequest(http.MethodGet, "/v1/receipts?limit=10", nil)
+ listResp := httptest.NewRecorder()
+ srv.Handler.ServeHTTP(listResp, listReq)
+ if listResp.Code != http.StatusOK {
+ t.Fatalf("expected 200 for receipt list, got %d", listResp.Code)
+ }
+ var got struct {
+ Receipts []models.Receipt `json:"receipts"`
+ Total int `json:"total"`
+ }
+ if err := json.NewDecoder(listResp.Body).Decode(&got); err != nil {
+ t.Fatalf("decode failed: %v", err)
+ }
+ if got.Total != 2 {
+ t.Fatalf("expected 2 total receipts, got %d", got.Total)
+ }
+ if len(got.Receipts) != 2 {
+ t.Fatalf("expected 2 receipts in payload, got %d", len(got.Receipts))
+ }
+ if got.Receipts[0].ReceiptID != "r2" {
+ t.Fatalf("expected newest receipt first, got %s", got.Receipts[0].ReceiptID)
+ }
+
+ filteredReq := httptest.NewRequest(http.MethodGet, "/v1/receipts?decision=allow&limit=10", nil)
+ filteredResp := httptest.NewRecorder()
+ h.Handler().ServeHTTP(filteredResp, filteredReq)
+ if filteredResp.Code != http.StatusOK {
+ t.Fatalf("expected 200 for filtered receipt list, got %d", filteredResp.Code)
+ }
+ if err := json.NewDecoder(filteredResp.Body).Decode(&got); err != nil {
+ t.Fatalf("decode filtered payload failed: %v", err)
+ }
+ if got.Total != 1 {
+ t.Fatalf("expected 1 allow receipt, got %d", got.Total)
+ }
+
+ adminHTMLPath := tmp + "/admin.html"
+ if err := os.WriteFile(adminHTMLPath, []byte("admin-ui-ok"), 0o644); err != nil {
+ t.Fatalf("write admin html: %v", err)
+ }
+ admin, err := NewAdminHandler(adminHTMLPath)
+ if err != nil {
+ t.Fatalf("new admin handler: %v", err)
+ }
+ adminServer := &http.Server{Handler: h.HandlerWithAdmin(admin)}
+ adminReq := httptest.NewRequest(http.MethodGet, "/admin", nil)
+ adminResp := httptest.NewRecorder()
+ adminServer.Handler.ServeHTTP(adminResp, adminReq)
+ if adminResp.Code != http.StatusOK {
+ t.Fatalf("expected admin endpoint 200, got %d", adminResp.Code)
+ }
+ if strings.TrimSpace(adminResp.Body.String()) != "admin-ui-ok" {
+ t.Fatalf("expected custom admin html body")
+ }
+}
+
func TestRateLimit(t *testing.T) {
server := makeServerWithTokenAndRateLimit(t, "", 2)
@@ -719,6 +805,7 @@ func TestValidateMethodAndBadInputs(t *testing.T) {
{name: "transform", method: http.MethodGet, path: "/v1/transform", wantStatus: http.StatusMethodNotAllowed},
{name: "anonymize", method: http.MethodGet, path: "/v1/anonymize", wantStatus: http.StatusMethodNotAllowed},
{name: "receipts", method: http.MethodPost, path: "/v1/receipts/abc", wantStatus: http.StatusMethodNotAllowed},
+ {name: "receipts_list", method: http.MethodPost, path: "/v1/receipts", wantStatus: http.StatusMethodNotAllowed},
}
for _, tc := range tests {
@@ -939,6 +1026,9 @@ func TestEventsAdapterFilterCanonicalized(t *testing.T) {
if err != nil {
t.Fatalf("new store: %v", err)
}
+ t.Cleanup(func() {
+ _ = s.Close()
+ })
h := New(testPolicy(), s, nil, "", 0)
h.SetEventReader(fakeEventReader{events: []shim.DecisionEvent{
{Tool: "claude", Decision: string(models.DecisionAllow)},
diff --git a/internal/shim/events.go b/internal/shim/events.go
index 3d921d3..e94a3a9 100644
--- a/internal/shim/events.go
+++ b/internal/shim/events.go
@@ -3,7 +3,6 @@ package shim
import (
"bufio"
"encoding/json"
- "fmt"
"os"
"path/filepath"
"strings"
@@ -53,38 +52,152 @@ type noopEventSink struct{}
func (s noopEventSink) Record(_ DecisionEvent) {}
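+// NDJSONDecisionEventSink appends decision events to an NDJSON file. Records
+// are buffered on a channel and written by a background goroutine that flushes
+// periodically and on Close.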
type NDJSONDecisionEventSink struct {
- path string
- mu sync.Mutex
+ path string
+ writes chan DecisionEvent
+ closed chan struct{}
+ closeOnce sync.Once
+ writer *bufio.Writer
+ file *os.File
+ flushTimeout time.Duration
+ mu sync.RWMutex
+ isClosed bool
+ writeMu sync.Mutex
}
func NewNDJSONDecisionEventSink(path string) *NDJSONDecisionEventSink {
- return &NDJSONDecisionEventSink{path: path}
+ sink := &NDJSONDecisionEventSink{
+ path: path,
+ writes: make(chan DecisionEvent, 512),
+ closed: make(chan struct{}),
+ flushTimeout: 500 * time.Millisecond,
+ }
+ if path == "" {
+ return sink
+ }
+ // Best effort: if the file cannot be opened now, the write loop retries
+ // when the first event arrives.
+ _ = sink.openWriter()
+ go sink.loop()
+ return sink
}
-func (s *NDJSONDecisionEventSink) Record(event DecisionEvent) {
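+// Close stops the background writer and flushes any buffered events to disk.
+// It is idempotent and a no-op for sinks created without a path.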
+func (s *NDJSONDecisionEventSink) Close() error {
if s == nil || s.path == "" {
- return
+ return nil
}
+ s.closeOnce.Do(func() {
+ // Mark the sink closed and close the channel while holding the write
+ // lock so Record, which sends under the read lock, can never send on a
+ // closed channel.
+ s.mu.Lock()
+ s.isClosed = true
+ close(s.writes)
+ s.mu.Unlock()
+
+ <-s.closed
+ })
+ return nil
+}
+
+func (s *NDJSONDecisionEventSink) openWriter() error {
+ if s.path == "" {
+ return nil
+ }
if err := os.MkdirAll(filepath.Dir(s.path), 0o750); err != nil {
- return
+ return err
}
- payload, err := json.Marshal(event)
+ file, err := os.OpenFile(s.path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o600)
if err != nil {
- return
+ return err
}
+ s.file = file
+ s.writer = bufio.NewWriter(file)
+ return nil
+}
- s.mu.Lock()
- defer s.mu.Unlock()
+func (s *NDJSONDecisionEventSink) closeWriter() error {
+ s.writeMu.Lock()
+ defer s.writeMu.Unlock()
+ if s.writer != nil {
+ if err := s.writer.Flush(); err != nil {
+ return err
+ }
+ s.writer = nil
+ }
+ if s.file != nil {
+ if err := s.file.Close(); err != nil {
+ s.file = nil
+ return err
+ }
+ s.file = nil
+ }
+ return nil
+}
- file, err := os.OpenFile(s.path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o600)
- if err != nil {
+func (s *NDJSONDecisionEventSink) flushWriter() error {
+ s.writeMu.Lock()
+ defer s.writeMu.Unlock()
+ if s.writer == nil {
+ return nil
+ }
+ if err := s.writer.Flush(); err != nil {
+ return err
+ }
+ return s.file.Sync()
+}
+
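+// loop drains the write channel, appending each event as one NDJSON line and
+// flushing on a ticker and at shutdown.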
+func (s *NDJSONDecisionEventSink) loop() {
+ defer close(s.closed)
+ defer func() {
+ _ = s.closeWriter()
+ }()
+
+ ticker := time.NewTicker(s.flushTimeout)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case event, ok := <-s.writes:
+ if !ok {
+ _ = s.flushWriter()
+ return
+ }
+ s.writeMu.Lock()
+ if s.writer == nil {
+ if err := s.openWriter(); err != nil {
+ s.writeMu.Unlock()
+ continue
+ }
+ }
+ payload, err := json.Marshal(event)
+ if err != nil {
+ s.writeMu.Unlock()
+ continue
+ }
+ _, _ = s.writer.Write(append(payload, '\n'))
+ s.writeMu.Unlock()
+ case <-ticker.C:
+ _ = s.flushWriter()
+ }
+ }
+}
+
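+// Record enqueues the event for asynchronous persistence. It never blocks:
+// the event is dropped if the buffer is full or the sink is closed.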
+func (s *NDJSONDecisionEventSink) Record(event DecisionEvent) {
+ if s == nil || s.path == "" {
return
}
- defer file.Close()
- _, _ = fmt.Fprintln(file, string(payload))
+ // Hold the read lock across the send so Close cannot close the channel
+ // between the closed check and the send.
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ if s.isClosed {
+ return
+ }
+
+ select {
+ case s.writes <- event:
+ default:
+ // Drop the event rather than block the caller when the buffer is full.
+ }
}
// Query reads events from the NDJSON file and applies filters.
@@ -93,8 +206,8 @@ func (s *NDJSONDecisionEventSink) Query(q EventQuery) ([]DecisionEvent, error) {
return nil, nil
}
- s.mu.Lock()
- defer s.mu.Unlock()
+ s.mu.RLock()
+ defer s.mu.RUnlock()
f, err := os.Open(s.path)
if err != nil {
@@ -137,6 +250,9 @@ func (s *NDJSONDecisionEventSink) Query(q EventQuery) ([]DecisionEvent, error) {
break
}
}
+ if err := scanner.Err(); err != nil {
+ return nil, err
+ }
- return events, scanner.Err()
+ return events, nil
}
diff --git a/internal/shim/events_test.go b/internal/shim/events_test.go
new file mode 100644
index 0000000..66b8106
--- /dev/null
+++ b/internal/shim/events_test.go
@@ -0,0 +1,75 @@
+package shim
+
+import (
+ "testing"
+ "time"
+)
+
+func TestNDJSONDecisionEventSinkWritesAndQueriesAsync(t *testing.T) {
+ path := t.TempDir() + "/events.ndjson"
+ sink := NewNDJSONDecisionEventSink(path)
+ if sink == nil {
+ t.Fatalf("expected sink")
+ }
+ t.Cleanup(func() {
+ _ = sink.Close()
+ })
+
+ expected := DecisionEvent{
+ Timestamp: time.Now().UTC(),
+ Mode: "observe",
+ ActionType: "shell.exec",
+ Tool: "claude",
+ Decision: "allow",
+ Allowed: true,
+ }
+ sink.Record(expected)
+
+ deadline := time.Now().Add(2 * time.Second)
+ var got []DecisionEvent
+ for {
+ events, err := sink.Query(EventQuery{Limit: 10})
+ if err != nil {
+ t.Fatalf("query failed: %v", err)
+ }
+ if len(events) == 1 {
+ got = events
+ break
+ }
+ if time.Now().After(deadline) {
+ t.Fatalf("expected async event to be persisted and queryable, got %d events", len(events))
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+ if got[0].Tool != expected.Tool {
+ t.Fatalf("expected tool %q, got %q", expected.Tool, got[0].Tool)
+ }
+}
+
+func TestNDJSONDecisionEventSinkCloseFlushesPendingEvents(t *testing.T) {
+ path := t.TempDir() + "/events.ndjson"
+ sink := NewNDJSONDecisionEventSink(path)
+ if sink == nil {
+ t.Fatalf("expected sink")
+ }
+
+ sink.Record(DecisionEvent{Timestamp: time.Now().UTC(), Tool: "vcs", Decision: "deny"})
+ sink.Record(DecisionEvent{Timestamp: time.Now().UTC(), Tool: "claude", Decision: "allow"})
+ sink.Record(DecisionEvent{Timestamp: time.Now().UTC(), Tool: "editor", Decision: "allow"})
+
+ if err := sink.Close(); err != nil {
+ t.Fatalf("close failed: %v", err)
+ }
+
+ events, err := sink.Query(EventQuery{Limit: 100})
+ if err != nil {
+ t.Fatalf("query failed: %v", err)
+ }
+ if len(events) != 3 {
+ t.Fatalf("expected 3 events after close, got %d", len(events))
+ }
+
+ if err := sink.Close(); err != nil {
+ t.Fatalf("idempotent close failed: %v", err)
+ }
+}
diff --git a/internal/types/ttlcache/ttlcache.go b/internal/types/ttlcache/ttlcache.go
new file mode 100644
index 0000000..f753d9c
--- /dev/null
+++ b/internal/types/ttlcache/ttlcache.go
@@ -0,0 +1,202 @@
+package ttlcache
+
+import (
+ "container/list"
+ "sync"
+ "time"
+)
+
+const defaultCleanupInterval = 30 * time.Second
+
+// Cache is a concurrency-safe TTL cache with LRU eviction.
+// It is intentionally small and purpose-built for in-process request caching.
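+//
+// A minimal usage sketch (illustrative; the size and TTL are arbitrary):
+//
+//	c := ttlcache.New[string](256, time.Minute)
+//	defer c.Close()
+//	c.Set("k", "v")
+//	if v, ok := c.Get("k"); ok {
+//		_ = v
+//	}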
+type Cache[V any] struct {
+ mu sync.Mutex
+ maxSize int
+ ttl time.Duration
+ items map[string]*list.Element
+ order *list.List // front = most recently used
+ clock func() time.Time
+ cleanupInterval time.Duration
+ stop chan struct{}
+ stopped chan struct{}
+}
+
+type item[V any] struct {
+ key string
+ value V
+ expireAt time.Time
+}
+
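+// New returns a cache that holds at most maxSize entries, each expiring ttl
+// after it was last set. A non-positive ttl disables expiration (and the
+// background cleanup goroutine); a non-positive maxSize disables the size cap.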
+func New[V any](maxSize int, ttl time.Duration, options ...Option[V]) *Cache[V] {
+ cfg := &Config[V]{
+ cleanupInterval: 0,
+ }
+ for _, option := range options {
+ option(cfg)
+ }
+
+ if cfg.cleanupInterval < 0 {
+ cfg.cleanupInterval = 0
+ }
+ if cfg.clock == nil {
+ cfg.clock = time.Now
+ }
+
+ c := &Cache[V]{
+ maxSize: maxSize,
+ ttl: ttl,
+ items: make(map[string]*list.Element),
+ order: list.New(),
+ clock: cfg.clock,
+ cleanupInterval: cfg.cleanupInterval,
+ stop: make(chan struct{}),
+ stopped: make(chan struct{}),
+ }
+
+ if c.cleanupInterval == 0 {
+ c.cleanupInterval = defaultCleanupInterval
+ }
+ if c.ttl > 0 && c.cleanupInterval > 0 {
+ go c.cleanupLoop()
+ } else {
+ // No cleanup goroutine runs for non-expiring caches; mark it stopped so
+ // Close does not block waiting on a loop that never started.
+ close(c.stopped)
+ }
+ return c
+}
+
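+// Config holds the optional settings applied by New.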
+type Config[V any] struct {
+ clock func() time.Time
+ cleanupInterval time.Duration
+}
+
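+// Option customizes the behavior of a Cache created by New.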
+type Option[V any] func(*Config[V])
+
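+// WithClock replaces the cache's time source, which is useful in tests.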
+func WithClock[V any](clock func() time.Time) Option[V] {
+ return func(cfg *Config[V]) {
+ cfg.clock = clock
+ }
+}
+
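+// WithCleanupInterval sets how often expired entries are swept in the
+// background. Non-positive values fall back to the default interval.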
+func WithCleanupInterval[V any](interval time.Duration) Option[V] {
+ return func(cfg *Config[V]) {
+ cfg.cleanupInterval = interval
+ }
+}
+
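+// Get returns the value stored under key and marks it most recently used.
+// Expired entries are removed and reported as absent.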
+func (c *Cache[V]) Get(key string) (V, bool) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ var zero V
+ e := c.items[key]
+ if e == nil {
+ return zero, false
+ }
+ entry := e.Value.(*item[V])
+ if c.isExpired(entry) {
+ c.remove(e)
+ return zero, false
+ }
+ c.order.MoveToFront(e)
+ return entry.value, true
+}
+
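+// Set stores value under key, resetting its TTL, and evicts least recently
+// used entries while the cache is over maxSize.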
+func (c *Cache[V]) Set(key string, value V) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if existing := c.items[key]; existing != nil {
+ existing.Value.(*item[V]).value = value
+ existing.Value.(*item[V]).expireAt = c.now().Add(c.ttl)
+ c.order.MoveToFront(existing)
+ return
+ }
+
+ entry := &item[V]{
+ key: key,
+ value: value,
+ expireAt: c.now().Add(c.ttl),
+ }
+ e := c.order.PushFront(entry)
+ c.items[key] = e
+ c.enforceLimitLocked()
+}
+
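+// Delete removes key from the cache, if present.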
+func (c *Cache[V]) Delete(key string) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if e := c.items[key]; e != nil {
+ c.remove(e)
+ }
+}
+
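+// Len reports the number of live entries after evicting any expired ones.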
+func (c *Cache[V]) Len() int {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.evictExpiredLocked()
+ return len(c.items)
+}
+
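+// Close stops the background cleanup goroutine. Calling Close more than once
+// is safe.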
+func (c *Cache[V]) Close() {
+ select {
+ case <-c.stop:
+ return
+ default:
+ close(c.stop)
+ }
+ <-c.stopped
+}
+
+func (c *Cache[V]) enforceLimitLocked() {
+ if c.maxSize <= 0 {
+ return
+ }
+ for len(c.items) > c.maxSize {
+ back := c.order.Back()
+ if back == nil {
+ return
+ }
+ c.remove(back)
+ }
+}
+
+func (c *Cache[V]) remove(e *list.Element) {
+ ent := e.Value.(*item[V])
+ c.order.Remove(e)
+ delete(c.items, ent.key)
+}
+
+func (c *Cache[V]) cleanupLoop() {
+ defer close(c.stopped)
+ ticker := time.NewTicker(c.cleanupInterval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ c.mu.Lock()
+ c.evictExpiredLocked()
+ c.mu.Unlock()
+ case <-c.stop:
+ return
+ }
+ }
+}
+
+func (c *Cache[V]) evictExpiredLocked() {
+ for e := c.order.Back(); e != nil; {
+ prev := e.Prev()
+ if c.isExpired(e.Value.(*item[V])) {
+ c.remove(e)
+ }
+ e = prev
+ }
+}
+
+func (c *Cache[V]) isExpired(it *item[V]) bool {
+ return c.ttl > 0 && !it.expireAt.IsZero() && !it.expireAt.After(c.now())
+}
+
+func (c *Cache[V]) now() time.Time {
+ return c.clock()
+}