From f3a3cab96bae43ea0d7ea642e971eae47326b5d8 Mon Sep 17 00:00:00 2001 From: Bob Bass Date: Mon, 9 Mar 2026 23:53:02 -0500 Subject: [PATCH 1/2] feat: add opt-in file persistence for DuckDB MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a `file_persistence` config option (YAML, env var, CLI flag) that stores DuckDB data in `<data_dir>/<username>.duckdb` instead of using per-connection in-memory databases. When enabled, DuckDB memory-maps the file and serves from RAM — giving Redis-like performance with disk durability. Data persists across client disconnects and duckgres restarts. The default behavior is unchanged: `file_persistence` defaults to false, so every connection still gets an isolated ephemeral in-memory DuckDB. Config resolution follows the existing YAML → env → CLI precedence: - YAML: `file_persistence: true` - Env: `DUCKGRES_FILE_PERSISTENCE=true` - CLI: `--file-persistence` --- config_resolution.go | 12 ++++++ main.go | 4 ++ main_test.go | 54 ++++++++++++++++++++++++ server/server.go | 21 ++++++++-- server/server_test.go | 97 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 185 insertions(+), 3 deletions(-) diff --git a/config_resolution.go b/config_resolution.go index 0baa73b4..fa80600a 100644 --- a/config_resolution.go +++ b/config_resolution.go @@ -21,6 +21,7 @@ type configCLIInputs struct { DataDir string CertFile string KeyFile string + FilePersistence bool ProcessIsolation bool IdleTimeout string MemoryLimit string @@ -225,6 +226,7 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun cfg.DuckLake.S3Profile = fileCfg.DuckLake.S3Profile } + cfg.FilePersistence = fileCfg.FilePersistence cfg.ProcessIsolation = fileCfg.ProcessIsolation if fileCfg.IdleTimeout != "" { if d, err := time.ParseDuration(fileCfg.IdleTimeout); err == nil { @@ -435,6 +437,13 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun if v := getenv("DUCKGRES_DUCKLAKE_S3_PROFILE"); v 
!= "" { cfg.DuckLake.S3Profile = v } + if v := getenv("DUCKGRES_FILE_PERSISTENCE"); v != "" { + if b, err := strconv.ParseBool(v); err == nil { + cfg.FilePersistence = b + } else { + warn("Invalid DUCKGRES_FILE_PERSISTENCE: " + err.Error()) + } + } if v := getenv("DUCKGRES_PROCESS_ISOLATION"); v != "" { if b, err := strconv.ParseBool(v); err == nil { cfg.ProcessIsolation = b @@ -638,6 +647,9 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun if cli.Set["key"] { cfg.TLSKeyFile = cli.KeyFile } + if cli.Set["file-persistence"] { + cfg.FilePersistence = cli.FilePersistence + } if cli.Set["process-isolation"] { cfg.ProcessIsolation = cli.ProcessIsolation } diff --git a/main.go b/main.go index 65ff577f..bd9f5d01 100644 --- a/main.go +++ b/main.go @@ -35,6 +35,7 @@ type FileConfig struct { RateLimit RateLimitFileConfig `yaml:"rate_limit"` Extensions []string `yaml:"extensions"` DuckLake DuckLakeFileConfig `yaml:"ducklake"` + FilePersistence bool `yaml:"file_persistence"` // Persist DuckDB to /.duckdb instead of :memory: ProcessIsolation bool `yaml:"process_isolation"` // Enable process isolation per connection IdleTimeout string `yaml:"idle_timeout"` // e.g., "24h", "1h", "-1" to disable MemoryLimit string `yaml:"memory_limit"` // DuckDB memory_limit per session (e.g., "4GB") @@ -194,6 +195,7 @@ func main() { dataDir := flag.String("data-dir", "", "Directory for DuckDB files (env: DUCKGRES_DATA_DIR)") certFile := flag.String("cert", "", "TLS certificate file (env: DUCKGRES_CERT)") keyFile := flag.String("key", "", "TLS private key file (env: DUCKGRES_KEY)") + filePersistence := flag.Bool("file-persistence", false, "Persist DuckDB to /.duckdb instead of in-memory (env: DUCKGRES_FILE_PERSISTENCE)") processIsolation := flag.Bool("process-isolation", false, "Enable process isolation (spawn child process per connection)") idleTimeout := flag.String("idle-timeout", "", "Connection idle timeout (e.g., '30m', '1h', '-1' to disable) (env: 
DUCKGRES_IDLE_TIMEOUT)") memoryLimit := flag.String("memory-limit", "", "DuckDB memory_limit per session (e.g., '4GB') (env: DUCKGRES_MEMORY_LIMIT)") @@ -259,6 +261,7 @@ func main() { fmt.Fprintf(os.Stderr, " DUCKGRES_DATA_DIR Directory for DuckDB files (default: ./data)\n") fmt.Fprintf(os.Stderr, " DUCKGRES_CERT TLS certificate file (default: ./certs/server.crt)\n") fmt.Fprintf(os.Stderr, " DUCKGRES_KEY TLS private key file (default: ./certs/server.key)\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_FILE_PERSISTENCE Persist DuckDB to /.duckdb (1 or true)\n") fmt.Fprintf(os.Stderr, " DUCKGRES_PROCESS_ISOLATION Enable process isolation (1 or true)\n") fmt.Fprintf(os.Stderr, " DUCKGRES_IDLE_TIMEOUT Connection idle timeout (e.g., 30m, 1h, -1 to disable)\n") fmt.Fprintf(os.Stderr, " DUCKGRES_MEMORY_LIMIT DuckDB memory_limit per session (e.g., 4GB)\n") @@ -353,6 +356,7 @@ func main() { DataDir: *dataDir, CertFile: *certFile, KeyFile: *keyFile, + FilePersistence: *filePersistence, ProcessIsolation: *processIsolation, IdleTimeout: *idleTimeout, MemoryLimit: *memoryLimit, diff --git a/main_test.go b/main_test.go index 8999c595..ba1daa2a 100644 --- a/main_test.go +++ b/main_test.go @@ -684,6 +684,60 @@ func TestResolveEffectiveConfigACMEDNSProviderValidation(t *testing.T) { } } +func TestResolveEffectiveConfigFilePersistenceFromFile(t *testing.T) { + fileCfg := &FileConfig{ + FilePersistence: true, + DataDir: "/tmp/data", + } + resolved := resolveEffectiveConfig(fileCfg, configCLIInputs{}, envFromMap(nil), nil) + if !resolved.Server.FilePersistence { + t.Fatal("expected file_persistence from YAML to be true") + } +} + +func TestResolveEffectiveConfigFilePersistenceFromEnv(t *testing.T) { + env := map[string]string{ + "DUCKGRES_FILE_PERSISTENCE": "true", + } + resolved := resolveEffectiveConfig(nil, configCLIInputs{}, envFromMap(env), nil) + if !resolved.Server.FilePersistence { + t.Fatal("expected file_persistence from env to be true") + } +} + +func 
TestResolveEffectiveConfigFilePersistenceEnvOverridesFile(t *testing.T) { + fileCfg := &FileConfig{ + FilePersistence: true, + } + env := map[string]string{ + "DUCKGRES_FILE_PERSISTENCE": "false", + } + resolved := resolveEffectiveConfig(fileCfg, configCLIInputs{}, envFromMap(env), nil) + if resolved.Server.FilePersistence { + t.Fatal("expected env false to override file true") + } +} + +func TestResolveEffectiveConfigFilePersistenceCLIOverridesEnv(t *testing.T) { + env := map[string]string{ + "DUCKGRES_FILE_PERSISTENCE": "false", + } + resolved := resolveEffectiveConfig(nil, configCLIInputs{ + Set: map[string]bool{"file-persistence": true}, + FilePersistence: true, + }, envFromMap(env), nil) + if !resolved.Server.FilePersistence { + t.Fatal("expected CLI true to override env false") + } +} + +func TestResolveEffectiveConfigFilePersistenceDefaultFalse(t *testing.T) { + resolved := resolveEffectiveConfig(nil, configCLIInputs{}, envFromMap(nil), nil) + if resolved.Server.FilePersistence { + t.Fatal("expected file_persistence to default to false") + } +} + func TestResolveEffectiveConfigACMEDNSRequiresDomain(t *testing.T) { fileCfg := &FileConfig{ TLS: TLSConfig{ diff --git a/server/server.go b/server/server.go index 2a4bc317..607af0ba 100644 --- a/server/server.go +++ b/server/server.go @@ -154,6 +154,11 @@ type Config struct { // uncleanly. Default: 24 hours. Set to a negative value (e.g., -1) to disable. IdleTimeout time.Duration + // FilePersistence stores DuckDB data in <DataDir>/<username>.duckdb instead of :memory:. + // DuckDB memory-maps the file and serves queries from RAM, so performance is similar + // to in-memory mode while data persists across connections and restarts. + FilePersistence bool + // ProcessIsolation enables spawning each client connection in a separate OS process. // This prevents DuckDB C++ crashes from taking down the entire server. 
// When enabled, rate limiting and cancel requests are handled by the parent process, @@ -625,11 +630,21 @@ func (s *Server) createDBConnection(username string) (*sql.DB, error) { return CreateDBConnection(s.cfg, s.duckLakeSem, username, processStartTime, processVersion) } -// openBaseDB creates and configures a bare DuckDB in-memory connection with -// threads, memory limit, temp directory, extensions, and cache_httpfs settings. +// openBaseDB creates and configures a DuckDB connection with threads, memory +// limit, temp directory, extensions, and cache_httpfs settings. // This shared setup is used by both regular and passthrough connections. +// +// When FilePersistence is enabled and both DataDir and username are non-empty, the database is file-backed at <DataDir>/<username>.duckdb. +// DuckDB memory-maps the file and serves queries from RAM (like Redis with AOF), +// so performance is equivalent to in-memory while data persists across restarts. +// Otherwise it falls back to a pure in-memory database. func openBaseDB(cfg Config, username string) (*sql.DB, error) { - db, err := sql.Open("duckdb", ":memory:") + dsn := ":memory:" + if cfg.FilePersistence && cfg.DataDir != "" && username != "" { + dsn = filepath.Join(cfg.DataDir, username+".duckdb") + slog.Info("Opening file-backed DuckDB.", "path", dsn) + } + db, err := sql.Open("duckdb", dsn) if err != nil { return nil, fmt.Errorf("failed to open duckdb: %w", err) } diff --git a/server/server_test.go b/server/server_test.go index d688f53a..40f47d7b 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -309,6 +309,103 @@ func TestStartCredentialRefresh_RollbackAndRetryWhenNoActiveTransaction(t *testi } } +func TestOpenBaseDBInMemoryByDefault(t *testing.T) { + cfg := Config{} + db, err := openBaseDB(cfg, "testuser") + if err != nil { + t.Fatalf("openBaseDB failed: %v", err) + } + defer func() { _ = db.Close() }() + + var dbName string + err = db.QueryRow("SELECT current_database()").Scan(&dbName) + if err != nil { + t.Fatalf("failed to query current_database(): %v", 
err) + } + if dbName != "memory" { + t.Fatalf("expected in-memory database (current_database()='memory'), got %q", dbName) + } +} + +func TestOpenBaseDBFilePersistence(t *testing.T) { + dataDir := t.TempDir() + cfg := Config{ + FilePersistence: true, + DataDir: dataDir, + } + db, err := openBaseDB(cfg, "alice") + if err != nil { + t.Fatalf("openBaseDB failed: %v", err) + } + + // Write data + if _, err := db.Exec("CREATE TABLE test_persist (id INTEGER)"); err != nil { + t.Fatalf("failed to create table: %v", err) + } + if _, err := db.Exec("INSERT INTO test_persist VALUES (42)"); err != nil { + t.Fatalf("failed to insert: %v", err) + } + _ = db.Close() + + // Reopen the same file and verify data survives + db2, err := openBaseDB(cfg, "alice") + if err != nil { + t.Fatalf("openBaseDB (reopen) failed: %v", err) + } + defer func() { _ = db2.Close() }() + + var val int + err = db2.QueryRow("SELECT id FROM test_persist").Scan(&val) + if err != nil { + t.Fatalf("failed to read persisted data: %v", err) + } + if val != 42 { + t.Fatalf("expected persisted value 42, got %d", val) + } +} + +func TestOpenBaseDBFilePersistenceFallsBackWithoutDataDir(t *testing.T) { + cfg := Config{ + FilePersistence: true, + // DataDir intentionally empty + } + db, err := openBaseDB(cfg, "testuser") + if err != nil { + t.Fatalf("openBaseDB failed: %v", err) + } + defer func() { _ = db.Close() }() + + var dbName string + err = db.QueryRow("SELECT current_database()").Scan(&dbName) + if err != nil { + t.Fatalf("failed to query current_database(): %v", err) + } + if dbName != "memory" { + t.Fatalf("expected fallback to in-memory when DataDir is empty, got %q", dbName) + } +} + +func TestOpenBaseDBFilePersistenceFallsBackWithoutUsername(t *testing.T) { + cfg := Config{ + FilePersistence: true, + DataDir: t.TempDir(), + } + db, err := openBaseDB(cfg, "") + if err != nil { + t.Fatalf("openBaseDB failed: %v", err) + } + defer func() { _ = db.Close() }() + + var dbName string + err = 
db.QueryRow("SELECT current_database()").Scan(&dbName) + if err != nil { + t.Fatalf("failed to query current_database(): %v", err) + } + if dbName != "memory" { + t.Fatalf("expected fallback to in-memory when username is empty, got %q", dbName) + } +} + func TestHasCacheHTTPFS(t *testing.T) { tests := []struct { name string From 45aa7729358e04b119cc2b5d325fdc5663da0444 Mon Sep 17 00:00:00 2001 From: Bob Bass Date: Wed, 11 Mar 2026 22:38:28 -0500 Subject: [PATCH 2/2] fix: address review feedback for file persistence (#300) - Reject usernames with path traversal characters (/, \, ..) in file-persistence mode - Add per-user shared DB pool so concurrent connections reuse the same DuckDB file handle - Warn and disable file_persistence when data_dir is empty instead of silent fallback - Create data directory with os.MkdirAll before opening file-backed DBs --- config_resolution.go | 5 ++ main_test.go | 31 +++++++++++ server/conn.go | 73 ++++++++++++++++++++------ server/executor.go | 53 +++++++++++++++++++ server/server.go | 76 +++++++++++++++++++++++++++ server/server_test.go | 117 ++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 340 insertions(+), 15 deletions(-) diff --git a/config_resolution.go b/config_resolution.go index fa80600a..d2d04dc0 100644 --- a/config_resolution.go +++ b/config_resolution.go @@ -745,6 +745,11 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun cfg.QueryLog.Enabled = cli.QueryLog } + if cfg.FilePersistence && cfg.DataDir == "" { + warn("file_persistence is enabled but data_dir is empty; disabling file persistence") + cfg.FilePersistence = false + } + if cfg.ACMEDNSProvider != "" { provider := strings.ToLower(cfg.ACMEDNSProvider) if provider != "route53" { diff --git a/main_test.go b/main_test.go index ba1daa2a..7c23d669 100644 --- a/main_test.go +++ b/main_test.go @@ -771,3 +771,34 @@ func TestResolveEffectiveConfigACMEDNSRequiresDomain(t *testing.T) { t.Fatalf("expected warning about missing ACME 
domain for DNS mode, warnings: %v", warns) } } + +func TestFilePersistenceRequiresDataDir(t *testing.T) { + var warns []string + // Use CLI to explicitly set data-dir to empty, overriding the default. + resolved := resolveEffectiveConfig( + &FileConfig{ + FilePersistence: true, + }, + configCLIInputs{ + Set: map[string]bool{"data-dir": true}, + DataDir: "", + }, + nil, + func(msg string) { warns = append(warns, msg) }, + ) + + if resolved.Server.FilePersistence { + t.Fatal("expected FilePersistence to be disabled when DataDir is empty") + } + + found := false + for _, w := range warns { + if strings.Contains(w, "file_persistence is enabled but data_dir is empty") { + found = true + break + } + } + if !found { + t.Fatalf("expected warning about empty data_dir, warnings: %v", warns) + } +} diff --git a/server/conn.go b/server/conn.go index 23f318e1..9412ed19 100644 --- a/server/conn.go +++ b/server/conn.go @@ -166,6 +166,10 @@ type clientConn struct { ctx context.Context // connection context, cancelled when connection is closed cancel context.CancelFunc // cancels the connection context + // sharedDB is true when this connection uses a shared file-persistence DB pool. + // Cleanup differs: we return the pinned conn to the pool instead of closing the DB. + sharedDB bool + // pg_stat_activity fields backendStart time.Time // when this connection started applicationName string // from startup params @@ -359,6 +363,28 @@ func (c *clientConn) safeCleanupDB() { }() cleanupTimeout := 5 * time.Second + + if c.sharedDB { + // Shared file-persistence pool: ROLLBACK any open transaction on the + // pinned connection, then return it to the pool. Skip DuckLake DETACH + // since the underlying DB is shared across connections. 
+ if c.txStatus == txStatusTransaction || c.txStatus == txStatusError { + ctx, cancel := context.WithTimeout(context.Background(), cleanupTimeout) + _, err := c.executor.ExecContext(ctx, "ROLLBACK") + cancel() + if err != nil { + slog.Warn("Failed to rollback transaction during cleanup.", + "user", c.username, "error", err) + } + } + // Close returns the pinned *sql.Conn to the pool (does not close the DB). + if err := c.executor.Close(); err != nil { + slog.Warn("Failed to return connection to pool.", "user", c.username, "error", err) + } + c.server.releaseFileDB(c.username) + return + } + connHealthy := true // Check connection health. For DuckLake, we need to actually run a query that @@ -562,23 +588,40 @@ func (c *clientConn) serve() error { // Create a DuckDB connection for this client session (unless pre-created by caller) var stopRefresh func() if c.executor == nil { - var db *sql.DB - var err error - if c.passthrough { - db, err = CreatePassthroughDBConnection(c.server.cfg, c.server.duckLakeSem, c.username, processStartTime, processVersion) + if c.server.cfg.FilePersistence { + db, err := c.server.acquireFileDB(c.username, c.passthrough) + if err != nil { + c.sendError("FATAL", "28000", fmt.Sprintf("failed to open database: %v", err)) + return err + } + conn, err := db.Conn(c.ctx) + if err != nil { + c.server.releaseFileDB(c.username) + c.sendError("FATAL", "28000", fmt.Sprintf("failed to get pooled connection: %v", err)) + return err + } + c.executor = NewPinnedExecutor(conn, db) + c.sharedDB = true + // Don't start per-connection credential refresh; the pool manages it. 
} else { - db, err = c.server.createDBConnection(c.username) - } - if err != nil { - c.sendError("FATAL", "28000", fmt.Sprintf("failed to open database: %v", err)) - return err - } - c.executor = NewLocalExecutor(db) + var db *sql.DB + var err error + if c.passthrough { + db, err = CreatePassthroughDBConnection(c.server.cfg, c.server.duckLakeSem, c.username, processStartTime, processVersion) + } else { + db, err = c.server.createDBConnection(c.username) + } + if err != nil { + c.sendError("FATAL", "28000", fmt.Sprintf("failed to open database: %v", err)) + return err + } + c.executor = NewLocalExecutor(db) - // Start background credential refresh for long-lived connections. - // Only needed when we create the DB here; the control plane manages - // refresh for pre-created connections via DBPool. - stopRefresh = StartCredentialRefresh(db, c.server.cfg.DuckLake) + // Start background credential refresh for long-lived connections. + // Only needed when we create the DB here; the control plane manages + // refresh for pre-created connections via DBPool. + stopRefresh = StartCredentialRefresh(db, c.server.cfg.DuckLake) + } } // Defers run LIFO: close cursors first (they hold open RowSets), then stop // credential refresh, then clean up the database connection. diff --git a/server/executor.go b/server/executor.go index 8dc83076..d9a11339 100644 --- a/server/executor.go +++ b/server/executor.go @@ -96,6 +96,59 @@ func (e *LocalExecutor) Close() error { return e.db.Close() } +// PinnedExecutor wraps a pinned *sql.Conn from a shared *sql.DB pool +// to implement QueryExecutor for file-persistence mode. +type PinnedExecutor struct { + conn *sql.Conn + db *sql.DB +} + +func NewPinnedExecutor(conn *sql.Conn, db *sql.DB) *PinnedExecutor { + return &PinnedExecutor{conn: conn, db: db} +} + +// DB returns the underlying *sql.DB (for credential refresh and other direct access). 
+func (e *PinnedExecutor) DB() *sql.DB { + return e.db +} + +func (e *PinnedExecutor) QueryContext(ctx context.Context, query string, args ...any) (RowSet, error) { + rows, err := e.conn.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + return &LocalRowSet{rows: rows}, nil +} + +func (e *PinnedExecutor) ExecContext(ctx context.Context, query string, args ...any) (ExecResult, error) { + return e.conn.ExecContext(ctx, query, args...) +} + +func (e *PinnedExecutor) Query(query string, args ...any) (RowSet, error) { + rows, err := e.conn.QueryContext(context.Background(), query, args...) + if err != nil { + return nil, err + } + return &LocalRowSet{rows: rows}, nil +} + +func (e *PinnedExecutor) Exec(query string, args ...any) (ExecResult, error) { + return e.conn.ExecContext(context.Background(), query, args...) +} + +func (e *PinnedExecutor) ConnContext(ctx context.Context) (RawConn, error) { + return e.db.Conn(ctx) +} + +func (e *PinnedExecutor) PingContext(ctx context.Context) error { + return e.conn.PingContext(ctx) +} + +// Close returns the pinned connection to the pool; it does not close the underlying DB. +func (e *PinnedExecutor) Close() error { + return e.conn.Close() +} + // LocalRowSet wraps *sql.Rows to implement RowSet. type LocalRowSet struct { rows *sql.Rows diff --git a/server/server.go b/server/server.go index 607af0ba..ddecd355 100644 --- a/server/server.go +++ b/server/server.go @@ -244,6 +244,14 @@ type DuckLakeConfig struct { S3Profile string // AWS profile name to use (for "config" chain) } +// fileDBEntry tracks a shared *sql.DB for file-persistence mode. +// One entry per user file; multiple PG connections share the pool via pinned *sql.Conn. 
+type fileDBEntry struct { + db *sql.DB + refs int + stopRefresh func() // credential refresh goroutine +} + type Server struct { cfg Config listener net.Listener @@ -280,6 +288,11 @@ type Server struct { // Query logger for DuckLake system.query_log queryLogger *QueryLogger + + // Per-user shared DB pool for file persistence mode. + // Each user gets one *sql.DB; PG connections share it via pinned *sql.Conn. + fileDBsMu sync.Mutex + fileDBs map[string]*fileDBEntry } func New(cfg Config) (*Server, error) { @@ -327,6 +340,7 @@ func New(cfg Config) (*Server, error) { activeQueries: make(map[BackendKey]context.CancelFunc), duckLakeSem: make(chan struct{}, 1), conns: make(map[int32]*clientConn), + fileDBs: make(map[string]*fileDBEntry), } // Configure TLS: ACME DNS-01, ACME HTTP-01, or static certificate files @@ -630,6 +644,62 @@ func (s *Server) createDBConnection(username string) (*sql.DB, error) { return CreateDBConnection(s.cfg, s.duckLakeSem, username, processStartTime, processVersion) } +// acquireFileDB returns a shared *sql.DB for the given user, creating one if needed. +// The caller must call releaseFileDB when the connection is no longer needed. +func (s *Server) acquireFileDB(username string, passthrough bool) (*sql.DB, error) { + s.fileDBsMu.Lock() + defer s.fileDBsMu.Unlock() + + if entry, ok := s.fileDBs[username]; ok { + entry.refs++ + return entry.db, nil + } + + var db *sql.DB + var err error + if passthrough { + db, err = CreatePassthroughDBConnection(s.cfg, s.duckLakeSem, username, processStartTime, processVersion) + } else { + db, err = CreateDBConnection(s.cfg, s.duckLakeSem, username, processStartTime, processVersion) + } + if err != nil { + return nil, err + } + + // openBaseDB sets MaxOpenConns(1) for single-session use; override for shared pool. 
+ db.SetMaxOpenConns(0) // unlimited + db.SetMaxIdleConns(4) + + stopRefresh := StartCredentialRefresh(db, s.cfg.DuckLake) + + s.fileDBs[username] = &fileDBEntry{ + db: db, + refs: 1, + stopRefresh: stopRefresh, + } + return db, nil +} + +// releaseFileDB decrements the ref count for a user's shared DB. +// When the last reference is released, the DB is closed and removed from the pool. +func (s *Server) releaseFileDB(username string) { + s.fileDBsMu.Lock() + defer s.fileDBsMu.Unlock() + + entry, ok := s.fileDBs[username] + if !ok { + return + } + entry.refs-- + if entry.refs <= 0 { + if entry.stopRefresh != nil { + entry.stopRefresh() + } + _ = entry.db.Close() + delete(s.fileDBs, username) + } +} + // openBaseDB creates and configures a DuckDB connection with threads, memory // limit, temp directory, extensions, and cache_httpfs settings. // This shared setup is used by both regular and passthrough connections. @@ -641,6 +711,12 @@ func (s *Server) createDBConnection(username string) (*sql.DB, error) { func openBaseDB(cfg Config, username string) (*sql.DB, error) { dsn := ":memory:" if cfg.FilePersistence && cfg.DataDir != "" && username != "" { + if strings.ContainsAny(username, "/\\") || strings.Contains(username, "..") { + return nil, fmt.Errorf("invalid username for file persistence: %q (contains path separator or ..)", username) + } + if err := os.MkdirAll(cfg.DataDir, 0750); err != nil { + return nil, fmt.Errorf("failed to create data directory %s: %w", cfg.DataDir, err) + } dsn = filepath.Join(cfg.DataDir, username+".duckdb") slog.Info("Opening file-backed DuckDB.", "path", dsn) } diff --git a/server/server_test.go b/server/server_test.go index 40f47d7b..127ae3f6 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "errors" + "os" + "path/filepath" "strings" "sync" "sync/atomic" @@ -430,3 +432,118 @@ func TestHasCacheHTTPFS(t *testing.T) { }) } } + +func 
TestOpenBaseDBFilePersistenceRejectsPathTraversal(t *testing.T) { + dataDir := t.TempDir() + cfg := Config{ + FilePersistence: true, + DataDir: dataDir, + } + + cases := []struct { + name string + username string + }{ + {"parent directory", "../etc/evil"}, + {"slash in name", "foo/bar"}, + {"backslash dot-dot", "..\\windows"}, + {"dot-dot between names", "alice/../bob"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + db, err := openBaseDB(cfg, tc.username) + if err == nil { + _ = db.Close() + t.Fatalf("expected error for username %q, got nil", tc.username) + } + if !strings.Contains(err.Error(), "invalid username for file persistence") { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestOpenBaseDBCreatesDataDir(t *testing.T) { + base := t.TempDir() + nested := filepath.Join(base, "deep", "nested", "dir") + + cfg := Config{ + FilePersistence: true, + DataDir: nested, + } + db, err := openBaseDB(cfg, "testuser") + if err != nil { + t.Fatalf("openBaseDB failed: %v", err) + } + defer func() { _ = db.Close() }() + + info, err := os.Stat(nested) + if err != nil { + t.Fatalf("data directory was not created: %v", err) + } + if !info.IsDir() { + t.Fatalf("expected directory, got file") + } +} + +func TestFileDBPoolRefCounting(t *testing.T) { + dataDir := t.TempDir() + s := &Server{ + cfg: Config{ + FilePersistence: true, + DataDir: dataDir, + }, + fileDBs: make(map[string]*fileDBEntry), + duckLakeSem: make(chan struct{}, 1), + } + + // First acquire + db1, err := s.acquireFileDB("alice", false) + if err != nil { + t.Fatalf("first acquireFileDB failed: %v", err) + } + + // Second acquire should return the same *sql.DB + db2, err := s.acquireFileDB("alice", false) + if err != nil { + t.Fatalf("second acquireFileDB failed: %v", err) + } + if db1 != db2 { + t.Fatal("expected same *sql.DB for same user, got different instances") + } + + // Check ref count is 2 + s.fileDBsMu.Lock() + refs := s.fileDBs["alice"].refs + 
s.fileDBsMu.Unlock() + if refs != 2 { + t.Fatalf("expected refs=2, got %d", refs) + } + + // Release one — DB should still be open + s.releaseFileDB("alice") + s.fileDBsMu.Lock() + entry, exists := s.fileDBs["alice"] + s.fileDBsMu.Unlock() + if !exists { + t.Fatal("entry removed too early (refs should be 1)") + } + if entry.refs != 1 { + t.Fatalf("expected refs=1, got %d", entry.refs) + } + + // DB should still work + if err := db1.Ping(); err != nil { + t.Fatalf("db should still be usable: %v", err) + } + + // Release last — DB should be closed and removed + s.releaseFileDB("alice") + s.fileDBsMu.Lock() + _, exists = s.fileDBs["alice"] + s.fileDBsMu.Unlock() + if exists { + t.Fatal("entry should be removed after last release") + } +}