diff --git a/postgres/transport_test.go b/postgres/transport_test.go index 11549a1..ec5fde7 100644 --- a/postgres/transport_test.go +++ b/postgres/transport_test.go @@ -76,6 +76,56 @@ func TestTransport(t *testing.T) { speedupIsAboveMinRatio(t, cold, cached, 0.75) } +func TestTransportInvalidateAllResponses(t *testing.T) { + t.Setenv("TESTCONTAINERS_RYUK_DISABLED", "true") + container, fn := startContainer(t, "user", "password") + defer fn() + + connstr, err := container.ConnectionString(t.Context(), "sslmode=disable") + if err != nil { + t.Fatalf("generating postgres connection string: %v", err) + } + + delay := 3 * time.Second + srv := httptest.NewServer(delayedResponse(delay)) + defer srv.Close() + + db, err := pgx.Connect(t.Context(), connstr) + if err != nil { + t.Fatalf("connecting to postgres database: %v", err) + } + + tr, err := httpcache.NewTransport(t.Context(), Connection{db}, nil) + if err != nil { + t.Fatalf("couldn't initialize transport for test: %v", err) + } + + client := &http.Client{Transport: tr} + + _, err = measureDuration(get(client, srv.URL)) + if err != nil { + t.Fatal(err) + } + + _, err = measureDuration(get(client, srv.URL)) + if err != nil { + t.Fatal(err) + } + + if err := tr.InvalidateAllResponses(t.Context()); err != nil { + t.Fatalf("couldn't invalidate all responses in the cache: %v", err) + } + + cold, err := measureDuration(get(client, srv.URL)) + if err != nil { + t.Fatal(err) + } + + if cold <= (delay) { + t.Fatalf("expected cold latency to be lower than or equal to %d, got %d\n", delay, cold) + } +} + func speedupIsAboveMinRatio(t *testing.T, cold, cached time.Duration, minRatio float64) { t.Helper()