Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 38 additions & 23 deletions tests/k8s/k8s_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ var (
clientset *kubernetes.Clientset
namespace string
kubeconfig string
pgPort int
portFwdCmd *exec.Cmd
portForward *portForwardState
testEnv k8sTestEnvironment
)

Expand Down Expand Up @@ -66,6 +65,19 @@ func TestMain(m *testing.M) {
if err != nil {
log.Fatalf("Failed to create k8s client: %v", err)
}
portForward = newPortForwardState(
func() (int, *exec.Cmd, error) {
return startPortForward(namespace, duckgresServiceTarget, duckgresServicePort)
},
waitForPort,
func(cmd *exec.Cmd) {
if cmd == nil || cmd.Process == nil {
return
}
_ = cmd.Process.Kill()
_ = cmd.Wait()
},
)

if _, err := waitForSingleReadyPod(namespace, "app=duckgres-control-plane", 90*time.Second); err != nil {
log.Fatalf("Control-plane pod not ready: %v", err)
Expand Down Expand Up @@ -197,6 +209,7 @@ func TestK8sWorkerCrashRecovery(t *testing.T) {

func TestK8sMultipleConcurrentConnections(t *testing.T) {
const n = 5
const timeout = 75 * time.Second
var wg sync.WaitGroup
errs := make(chan error, n)

Expand All @@ -205,7 +218,7 @@ func TestK8sMultipleConcurrentConnections(t *testing.T) {
go func(id int) {
defer wg.Done()
query := fmt.Sprintf("SELECT %d", id)
if err := retryDBOperationWithReconnect(30*time.Second, fmt.Sprintf("concurrent query %q", query), func(ctx context.Context, db *sql.DB) error {
if err := retryDBOperationWithReconnect(timeout, fmt.Sprintf("concurrent query %q", query), func(ctx context.Context, db *sql.DB) error {
var result int
if err := db.QueryRowContext(ctx, query).Scan(&result); err != nil {
return err
Expand Down Expand Up @@ -439,33 +452,24 @@ func startPortForward(ns, target string, remotePort int) (int, *exec.Cmd, error)
}

func closePortForward() {
if portFwdCmd == nil || portFwdCmd.Process == nil {
portFwdCmd = nil
if portForward == nil {
return
}

_ = portFwdCmd.Process.Kill()
_ = portFwdCmd.Wait()
portFwdCmd = nil
portForward.closeCurrent()
}

func restartPortForward() error {
closePortForward()

localPort, cmd, err := startPortForward(namespace, duckgresServiceTarget, duckgresServicePort)
if err != nil {
return err
if portForward == nil {
return fmt.Errorf("port-forward state is not initialized")
}
return portForward.restart(30 * time.Second)
}

pgPort = localPort
portFwdCmd = cmd

if err := waitForPort(pgPort, 30*time.Second); err != nil {
closePortForward()
return err
func restartPortForwardIfStale(stalePort int) error {
if portForward == nil {
return fmt.Errorf("port-forward state is not initialized")
}

return nil
return portForward.restartIfStale(stalePort, 30*time.Second)
}

func waitForPort(port int, timeout time.Duration) error {
Expand Down Expand Up @@ -602,10 +606,17 @@ func openDBConn() (*sql.DB, error) {
}

func openDBConnAs(username, password string) (*sql.DB, error) {
if portForward == nil {
return nil, fmt.Errorf("port-forward state is not initialized")
}
databaseName := username
if username == "postgres" {
databaseName = "duckgres"
}
pgPort := portForward.currentPort()
if pgPort == 0 {
return nil, fmt.Errorf("port-forward port is not initialized")
}

// kubectl port-forward passes raw TCP bytes, so the client still needs
// SSL. lib/pq sslmode=require skips server cert verification by default,
Expand Down Expand Up @@ -692,6 +703,10 @@ func retryDBOperationWithReconnectAs(username, password string, timeout time.Dur
deadline := time.Now().Add(timeout)
var lastErr error
for time.Now().Before(deadline) {
stalePort := 0
if portForward != nil {
stalePort = portForward.currentPort()
}
db, err := openDBConnAs(username, password)
if err == nil {
attemptCtx, cancel := context.WithTimeout(context.Background(), dbAttemptTimeout)
Expand All @@ -705,7 +720,7 @@ func retryDBOperationWithReconnectAs(username, password string, timeout time.Dur

lastErr = err
if isTransientDBError(err) {
if restartErr := restartPortForward(); restartErr != nil {
if restartErr := restartPortForwardIfStale(stalePort); restartErr != nil {
lastErr = fmt.Errorf("%w; restart port-forward: %v", err, restartErr)
}
}
Expand Down
183 changes: 183 additions & 0 deletions tests/k8s/port_forward_helper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package k8s_test

import (
"fmt"
"os/exec"
"sync"
"testing"
"time"
)

type portForwardState struct {
mu sync.Mutex
port int
cmd *exec.Cmd
start func() (int, *exec.Cmd, error)
wait func(int, time.Duration) error
close func(*exec.Cmd)
}

func newPortForwardState(
start func() (int, *exec.Cmd, error),
wait func(int, time.Duration) error,
closeFn func(*exec.Cmd),
) *portForwardState {
return &portForwardState{
start: start,
wait: wait,
close: closeFn,
}
}

func (p *portForwardState) currentPort() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.port
}

func (p *portForwardState) closeCurrent() {
p.mu.Lock()
defer p.mu.Unlock()
p.closeLocked()
}

func (p *portForwardState) restart(timeout time.Duration) error {
return p.restartIfStale(0, timeout)
}

func (p *portForwardState) restartIfStale(stalePort int, timeout time.Duration) error {
p.mu.Lock()
defer p.mu.Unlock()

if stalePort != 0 && p.port != 0 && p.port != stalePort {
if err := p.wait(p.port, 2*time.Second); err == nil {
return nil
}
}

p.closeLocked()

nextPort, nextCmd, err := p.start()
if err != nil {
return err
}
if err := p.wait(nextPort, timeout); err != nil {
p.close(nextCmd)
return err
}

p.port = nextPort
p.cmd = nextCmd
return nil
}

func (p *portForwardState) closeLocked() {
if p.cmd == nil {
return
}
p.close(p.cmd)
p.cmd = nil
}

func TestPortForwardStateRestartIfStaleSkipsHealthyReplacement(t *testing.T) {
startCalls := 0
state := newPortForwardState(
func() (int, *exec.Cmd, error) {
startCalls++
return 3333, &exec.Cmd{}, nil
},
func(port int, timeout time.Duration) error {
if port == 2222 {
return nil
}
return fmt.Errorf("port %d unreachable", port)
},
func(*exec.Cmd) {},
)
state.port = 2222

if err := state.restartIfStale(1111, 30*time.Second); err != nil {
t.Fatalf("restartIfStale returned error: %v", err)
}
if startCalls != 0 {
t.Fatalf("start called %d times, want 0", startCalls)
}
if got := state.currentPort(); got != 2222 {
t.Fatalf("currentPort() = %d, want 2222", got)
}
}

func TestPortForwardStateRestartIfStaleReplacesUnhealthyPort(t *testing.T) {
startCalls := 0
state := newPortForwardState(
func() (int, *exec.Cmd, error) {
startCalls++
return 3333, &exec.Cmd{}, nil
},
func(port int, timeout time.Duration) error {
if port == 3333 {
return nil
}
return fmt.Errorf("port %d unreachable", port)
},
func(*exec.Cmd) {},
)
state.port = 1111
state.cmd = &exec.Cmd{}

if err := state.restartIfStale(1111, 30*time.Second); err != nil {
t.Fatalf("restartIfStale returned error: %v", err)
}
if startCalls != 1 {
t.Fatalf("start called %d times, want 1", startCalls)
}
if got := state.currentPort(); got != 3333 {
t.Fatalf("currentPort() = %d, want 3333", got)
}
}

func TestPortForwardStateRestartUsesDefaultStalePort(t *testing.T) {
startCalls := 0
state := newPortForwardState(
func() (int, *exec.Cmd, error) {
startCalls++
return 4444, &exec.Cmd{}, nil
},
func(port int, timeout time.Duration) error {
if port == 4444 {
return nil
}
return fmt.Errorf("port %d unreachable", port)
},
func(*exec.Cmd) {},
)

if err := state.restart(30 * time.Second); err != nil {
t.Fatalf("restart returned error: %v", err)
}
if startCalls != 1 {
t.Fatalf("start called %d times, want 1", startCalls)
}
if got := state.currentPort(); got != 4444 {
t.Fatalf("currentPort() = %d, want 4444", got)
}
}

func TestPortForwardStateCloseCurrentClearsCommand(t *testing.T) {
closed := 0
state := newPortForwardState(
func() (int, *exec.Cmd, error) { return 0, nil, nil },
func(int, time.Duration) error { return nil },
func(*exec.Cmd) { closed++ },
)
state.cmd = &exec.Cmd{}

state.closeCurrent()

if closed != 1 {
t.Fatalf("close called %d times, want 1", closed)
}
if state.cmd != nil {
t.Fatal("expected cmd to be cleared")
}
}
45 changes: 44 additions & 1 deletion tests/k8s/tenant_isolation_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,11 @@ func psqlLiteral(value string) string {
}

func minioPrefixFileCount(prefix string) (int, error) {
trimmedPrefix := strings.Trim(prefix, "/")
cmd := exec.Command(
"docker", "exec", "duckgres-local-minio",
"sh", "-lc",
fmt.Sprintf("find /data/duckgres-local/%s -type f 2>/dev/null | wc -l", prefix),
fmt.Sprintf("mc ls --recursive local/duckgres-local/%s 2>/dev/null | wc -l", trimmedPrefix),
)
out, err := cmd.Output()
if err != nil {
Expand All @@ -317,3 +318,45 @@ func waitForMinioPrefixFileCountAtLeast(prefix string, minimum int, timeout time
}
return fmt.Errorf("prefix %s did not reach %d files within %s", prefix, minimum, timeout)
}

func waitForMinioPrefixFileCountToStayAtMost(prefix string, maximum int, duration time.Duration) error {
deadline := time.Now().Add(duration)
for time.Now().Before(deadline) {
count, err := minioPrefixFileCount(prefix)
if err != nil {
return err
}
if count > maximum {
return fmt.Errorf("prefix %s exceeded %d files during stability window: got %d", prefix, maximum, count)
}
time.Sleep(2 * time.Second)
}
return nil
}

func ensureWorkerPodLacksServiceAccountToken(podName string) error {
pod, err := clientset.CoreV1().Pods(namespace).Get(context.Background(), podName, metav1.GetOptions{})
if err != nil {
return fmt.Errorf("get worker pod %s: %w", podName, err)
}
if len(pod.Spec.Containers) == 0 {
return fmt.Errorf("worker pod %s has no containers", podName)
}
containerName := pod.Spec.Containers[0].Name

cmd := exec.Command(
"kubectl", "-n", namespace, "exec", podName, "-c", containerName, "--",
"sh", "-lc",
"if [ -e /var/run/secrets/kubernetes.io/serviceaccount/token ]; then " +
"echo 'service account token present'; " +
"ls -la /var/run/secrets/kubernetes.io/serviceaccount || true; " +
"exit 1; " +
"fi",
)
cmd.Env = commandEnv()
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("%w: %s", err, strings.TrimSpace(string(out)))
}
return nil
}
Loading
Loading