diff --git a/tests/k8s/k8s_test.go b/tests/k8s/k8s_test.go index bfee534..452abb6 100644 --- a/tests/k8s/k8s_test.go +++ b/tests/k8s/k8s_test.go @@ -27,8 +27,7 @@ var ( clientset *kubernetes.Clientset namespace string kubeconfig string - pgPort int - portFwdCmd *exec.Cmd + portForward *portForwardState testEnv k8sTestEnvironment ) @@ -66,6 +65,19 @@ func TestMain(m *testing.M) { if err != nil { log.Fatalf("Failed to create k8s client: %v", err) } + portForward = newPortForwardState( + func() (int, *exec.Cmd, error) { + return startPortForward(namespace, duckgresServiceTarget, duckgresServicePort) + }, + waitForPort, + func(cmd *exec.Cmd) { + if cmd == nil || cmd.Process == nil { + return + } + _ = cmd.Process.Kill() + _ = cmd.Wait() + }, + ) if _, err := waitForSingleReadyPod(namespace, "app=duckgres-control-plane", 90*time.Second); err != nil { log.Fatalf("Control-plane pod not ready: %v", err) @@ -197,6 +209,7 @@ func TestK8sWorkerCrashRecovery(t *testing.T) { func TestK8sMultipleConcurrentConnections(t *testing.T) { const n = 5 + const timeout = 75 * time.Second var wg sync.WaitGroup errs := make(chan error, n) @@ -205,7 +218,7 @@ func TestK8sMultipleConcurrentConnections(t *testing.T) { go func(id int) { defer wg.Done() query := fmt.Sprintf("SELECT %d", id) - if err := retryDBOperationWithReconnect(30*time.Second, fmt.Sprintf("concurrent query %q", query), func(ctx context.Context, db *sql.DB) error { + if err := retryDBOperationWithReconnect(timeout, fmt.Sprintf("concurrent query %q", query), func(ctx context.Context, db *sql.DB) error { var result int if err := db.QueryRowContext(ctx, query).Scan(&result); err != nil { return err @@ -439,33 +452,24 @@ func startPortForward(ns, target string, remotePort int) (int, *exec.Cmd, error) } func closePortForward() { - if portFwdCmd == nil || portFwdCmd.Process == nil { - portFwdCmd = nil + if portForward == nil { return } - - _ = portFwdCmd.Process.Kill() - _ = portFwdCmd.Wait() - portFwdCmd = nil + portForward.closeCurrent() } func restartPortForward() error { - closePortForward() - - localPort, cmd, err := startPortForward(namespace, duckgresServiceTarget, duckgresServicePort) - if err != nil { - return err + if portForward == nil { + return fmt.Errorf("port-forward state is not initialized") } + return portForward.restart(30 * time.Second) +} - pgPort = localPort - portFwdCmd = cmd - - if err := waitForPort(pgPort, 30*time.Second); err != nil { - closePortForward() - return err +func restartPortForwardIfStale(stalePort int) error { + if portForward == nil { + return fmt.Errorf("port-forward state is not initialized") } - - return nil + return portForward.restartIfStale(stalePort, 30*time.Second) } func waitForPort(port int, timeout time.Duration) error { @@ -602,10 +606,17 @@ func openDBConn() (*sql.DB, error) { } func openDBConnAs(username, password string) (*sql.DB, error) { + if portForward == nil { + return nil, fmt.Errorf("port-forward state is not initialized") + } databaseName := username if username == "postgres" { databaseName = "duckgres" } + pgPort := portForward.currentPort() + if pgPort == 0 { + return nil, fmt.Errorf("port-forward port is not initialized") + } // kubectl port-forward passes raw TCP bytes, so the client still needs // SSL. lib/pq sslmode=require skips server cert verification by default, @@ -692,6 +703,10 @@ func retryDBOperationWithReconnectAs(username, password string, timeout time.Dur deadline := time.Now().Add(timeout) var lastErr error for time.Now().Before(deadline) { + stalePort := 0 + if portForward != nil { + stalePort = portForward.currentPort() + } db, err := openDBConnAs(username, password) if err == nil { attemptCtx, cancel := context.WithTimeout(context.Background(), dbAttemptTimeout) @@ -705,7 +720,7 @@ func retryDBOperationWithReconnectAs(username, password string, timeout time.Dur lastErr = err if isTransientDBError(err) { - if restartErr := restartPortForward(); restartErr != nil { + if restartErr := restartPortForwardIfStale(stalePort); restartErr != nil { lastErr = fmt.Errorf("%w; restart port-forward: %v", err, restartErr) } } diff --git a/tests/k8s/port_forward_helper_test.go b/tests/k8s/port_forward_helper_test.go new file mode 100644 index 0000000..18ddf87 --- /dev/null +++ b/tests/k8s/port_forward_helper_test.go @@ -0,0 +1,183 @@ +package k8s_test + +import ( + "fmt" + "os/exec" + "sync" + "testing" + "time" +) + +type portForwardState struct { + mu sync.Mutex + port int + cmd *exec.Cmd + start func() (int, *exec.Cmd, error) + wait func(int, time.Duration) error + close func(*exec.Cmd) +} + +func newPortForwardState( + start func() (int, *exec.Cmd, error), + wait func(int, time.Duration) error, + closeFn func(*exec.Cmd), +) *portForwardState { + return &portForwardState{ + start: start, + wait: wait, + close: closeFn, + } +} + +func (p *portForwardState) currentPort() int { + p.mu.Lock() + defer p.mu.Unlock() + return p.port +} + +func (p *portForwardState) closeCurrent() { + p.mu.Lock() + defer p.mu.Unlock() + p.closeLocked() +} + +func (p *portForwardState) restart(timeout time.Duration) error { + return p.restartIfStale(0, timeout) +} + +func (p *portForwardState) restartIfStale(stalePort int, timeout time.Duration) error { + p.mu.Lock() + defer p.mu.Unlock() + + if stalePort != 0 && p.port != 0 && p.port != stalePort { + if err := p.wait(p.port, 2*time.Second); err == nil { + return nil + } + } + + p.closeLocked() + + nextPort, nextCmd, err := p.start() + if err != nil { + return err + } + if err := p.wait(nextPort, timeout); err != nil { + p.close(nextCmd) + return err + } + + p.port = nextPort + p.cmd = nextCmd + return nil +} + +func (p *portForwardState) closeLocked() { + if p.cmd == nil { + return + } + p.close(p.cmd) + p.cmd = nil +} + +func TestPortForwardStateRestartIfStaleSkipsHealthyReplacement(t *testing.T) { + startCalls := 0 + state := newPortForwardState( + func() (int, *exec.Cmd, error) { + startCalls++ + return 3333, &exec.Cmd{}, nil + }, + func(port int, timeout time.Duration) error { + if port == 2222 { + return nil + } + return fmt.Errorf("port %d unreachable", port) + }, + func(*exec.Cmd) {}, + ) + state.port = 2222 + + if err := state.restartIfStale(1111, 30*time.Second); err != nil { + t.Fatalf("restartIfStale returned error: %v", err) + } + if startCalls != 0 { + t.Fatalf("start called %d times, want 0", startCalls) + } + if got := state.currentPort(); got != 2222 { + t.Fatalf("currentPort() = %d, want 2222", got) + } +} + +func TestPortForwardStateRestartIfStaleReplacesUnhealthyPort(t *testing.T) { + startCalls := 0 + state := newPortForwardState( + func() (int, *exec.Cmd, error) { + startCalls++ + return 3333, &exec.Cmd{}, nil + }, + func(port int, timeout time.Duration) error { + if port == 3333 { + return nil + } + return fmt.Errorf("port %d unreachable", port) + }, + func(*exec.Cmd) {}, + ) + state.port = 1111 + state.cmd = &exec.Cmd{} + + if err := state.restartIfStale(1111, 30*time.Second); err != nil { + t.Fatalf("restartIfStale returned error: %v", err) + } + if startCalls != 1 { + t.Fatalf("start called %d times, want 1", startCalls) + } + if got := state.currentPort(); got != 3333 { + t.Fatalf("currentPort() = %d, want 3333", got) + } +} + +func TestPortForwardStateRestartUsesDefaultStalePort(t *testing.T) { + startCalls := 0 + state := newPortForwardState( + func() (int, *exec.Cmd, error) { + startCalls++ + return 4444, &exec.Cmd{}, nil + }, + func(port int, timeout time.Duration) error { + if port == 4444 { + return nil + } + return fmt.Errorf("port %d unreachable", port) + }, + func(*exec.Cmd) {}, + ) + + if err := state.restart(30 * time.Second); err != nil { + t.Fatalf("restart returned error: %v", err) + } + if startCalls != 1 { + t.Fatalf("start called %d times, want 1", startCalls) + } + if got := state.currentPort(); got != 4444 { + t.Fatalf("currentPort() = %d, want 4444", got) + } +} + +func TestPortForwardStateCloseCurrentClearsCommand(t *testing.T) { + closed := 0 + state := newPortForwardState( + func() (int, *exec.Cmd, error) { return 0, nil, nil }, + func(int, time.Duration) error { return nil }, + func(*exec.Cmd) { closed++ }, + ) + state.cmd = &exec.Cmd{} + + state.closeCurrent() + + if closed != 1 { + t.Fatalf("close called %d times, want 1", closed) + } + if state.cmd != nil { + t.Fatal("expected cmd to be cleared") + } +} diff --git a/tests/k8s/tenant_isolation_helper_test.go b/tests/k8s/tenant_isolation_helper_test.go index 294224c..8212237 100644 --- a/tests/k8s/tenant_isolation_helper_test.go +++ b/tests/k8s/tenant_isolation_helper_test.go @@ -290,10 +290,11 @@ func psqlLiteral(value string) string { } func minioPrefixFileCount(prefix string) (int, error) { + trimmedPrefix := strings.Trim(prefix, "/") cmd := exec.Command( "docker", "exec", "duckgres-local-minio", "sh", "-lc", - fmt.Sprintf("find /data/duckgres-local/%s -type f 2>/dev/null | wc -l", prefix), + fmt.Sprintf("mc ls --recursive local/duckgres-local/%s 2>/dev/null | wc -l", trimmedPrefix), ) out, err := cmd.Output() if err != nil { @@ -317,3 +318,45 @@ func waitForMinioPrefixFileCountAtLeast(prefix string, minimum int, timeout time } return fmt.Errorf("prefix %s did not reach %d files within %s", prefix, minimum, timeout) } + +func waitForMinioPrefixFileCountToStayAtMost(prefix string, maximum int, duration time.Duration) error { + deadline := time.Now().Add(duration) + for time.Now().Before(deadline) { + count, err := minioPrefixFileCount(prefix) + if err != nil { + return err + } + if count > maximum { + return fmt.Errorf("prefix %s exceeded %d files during stability window: got %d", prefix, maximum, count) + } + time.Sleep(2 * time.Second) + } + return nil +} + +func ensureWorkerPodLacksServiceAccountToken(podName string) error { + pod, err := clientset.CoreV1().Pods(namespace).Get(context.Background(), podName, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("get worker pod %s: %w", podName, err) + } + if len(pod.Spec.Containers) == 0 { + return fmt.Errorf("worker pod %s has no containers", podName) + } + containerName := pod.Spec.Containers[0].Name + + cmd := exec.Command( + "kubectl", "-n", namespace, "exec", podName, "-c", containerName, "--", + "sh", "-lc", + "if [ -e /var/run/secrets/kubernetes.io/serviceaccount/token ]; then " + + "echo 'service account token present'; " + + "ls -la /var/run/secrets/kubernetes.io/serviceaccount || true; " + + "exit 1; " + + "fi", + ) + cmd.Env = commandEnv() + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("%w: %s", err, strings.TrimSpace(string(out))) + } + return nil +} diff --git a/tests/k8s/tenant_isolation_test.go b/tests/k8s/tenant_isolation_test.go index a1a2c07..1b54cb3 100644 --- a/tests/k8s/tenant_isolation_test.go +++ b/tests/k8s/tenant_isolation_test.go @@ -85,6 +85,69 @@ func TestK8sTenantIsolation_DifferentTenantsSeeDistinctCatalogs(t *testing.T) { } } +func TestK8sTenantIsolation_WritesStayInOwnObjectStorePrefix(t *testing.T) { + analyticsTable := fmt.Sprintf("analytics_prefix_%d", time.Now().UnixNano()) + billingTable := fmt.Sprintf("billing_prefix_%d", time.Now().UnixNano()) + + analyticsPrefix := "orgs/analytics" + billingPrefix := "orgs/billing" + + analyticsBefore, err := minioPrefixFileCount(analyticsPrefix) + if err != nil { + t.Fatalf("count analytics prefix before write: %v", err) + } + billingBefore, err := minioPrefixFileCount(billingPrefix) + if err != nil { + t.Fatalf("count billing prefix before write: %v", err) + } + + analyticsDB, err := openDBConnAs("analytics", "postgres") + if err != nil { + t.Fatalf("open analytics DB: %v", err) + } + if _, err := execDBWithTimeout(analyticsDB, "CREATE OR REPLACE TABLE "+analyticsTable+" AS SELECT i AS value, repeat('x', 4096) AS payload FROM generate_series(1, 2048) AS t(i)"); err != nil { + _ = analyticsDB.Close() + t.Fatalf("create analytics table: %v", err) + } + if err := analyticsDB.Close(); err != nil { + t.Fatalf("close analytics DB: %v", err) + } + + if err := waitForMinioPrefixFileCountAtLeast(analyticsPrefix, analyticsBefore+1, 60*time.Second); err != nil { + t.Fatalf("wait for analytics prefix growth: %v", err) + } + if err := waitForMinioPrefixFileCountToStayAtMost(billingPrefix, billingBefore, 8*time.Second); err != nil { + t.Fatalf("billing prefix changed during analytics write: %v", err) + } + + billingDB, err := openDBConnAs("billing", "postgres") + if err != nil { + t.Fatalf("open billing DB: %v", err) + } + if _, err := execDBWithTimeout(billingDB, "CREATE OR REPLACE TABLE "+billingTable+" AS SELECT i AS value, repeat('x', 4096) AS payload FROM generate_series(1, 2048) AS t(i)"); err != nil { + _ = billingDB.Close() + t.Fatalf("create billing table: %v", err) + } + if err := billingDB.Close(); err != nil { + t.Fatalf("close billing DB: %v", err) + } + + if err := waitForMinioPrefixFileCountAtLeast(billingPrefix, billingBefore+1, 60*time.Second); err != nil { + t.Fatalf("wait for billing prefix growth: %v", err) + } +} + +func TestK8sWorkerPodsDoNotMountServiceAccountTokens(t *testing.T) { + if err := retryQueryWithReconnect("SELECT 1", 30*time.Second); err != nil { + t.Fatalf("query failed: %v", err) + } + + pod := latestWorkerPod(t) + if err := ensureWorkerPodLacksServiceAccountToken(pod.Name); err != nil { + t.Fatalf("worker pod %s has ambient service account token: %v", pod.Name, err) + } +} + func execDBWithTimeout(db *sql.DB, query string, args ...any) (sql.Result, error) { ctx, cancel := context.WithTimeout(context.Background(), dbAttemptTimeout) defer cancel()