diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 05babc4..083f0f4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,7 +8,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v6 with: - go-version: "1.24" + go-version: "1.25" - name: Check out code uses: actions/checkout@v5 diff --git a/go.mod b/go.mod index ffcc660..1fb6548 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/ostcar/topic -go 1.24 +go 1.25 diff --git a/topic.go b/topic.go index 2361343..dd47fb8 100644 --- a/topic.go +++ b/topic.go @@ -78,16 +78,17 @@ func (t *Topic[T]) Publish(value ...T) uint64 { // topic. It is not allowed to manipulate the values. func (t *Topic[T]) Receive(ctx context.Context, id uint64) (uint64, []T, error) { t.mu.RLock() + lastIDWhenStarted := t.lastID() // Request data, that is not in the topic yet. Block until the next // Publish() call. - if t.data == nil || id >= t.lastID() { + if t.data == nil || id >= lastIDWhenStarted { c := t.signal t.mu.RUnlock() select { case <-c: - return t.Receive(ctx, id) + return t.Receive(ctx, lastIDWhenStarted) case <-ctx.Done(): return 0, nil, ctx.Err() } diff --git a/topic_test.go b/topic_test.go index 2f1a96a..f086d68 100644 --- a/topic_test.go +++ b/topic_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sort" + "sync" "testing" "time" @@ -69,6 +70,54 @@ func TestPublishReceive(t *testing.T) { } } +func TestPublishCreatesIncreasingIDs(t *testing.T) { + top := topic.New[string]() + + id1 := top.Publish("v1") + id2 := top.Publish("v2", "v3") + id3 := top.Publish("v4") + + if !(id1 < id2 && id2 < id3) { + t.Errorf("Got ids %d %d %d, expected increasing", id1, id2, id3) + } +} + +func TestPublishWithoutValues(t *testing.T) { + top := topic.New[string]() + top.Publish("v1") + + id := top.Publish() + + if id != 1 { + t.Errorf("Publish() without values returned %d, expected 1", id) + } + + _, data, err := top.Receive(t.Context(), 0) + if err != nil { + t.Errorf("Receive() error: %v", err) + } + if len(data) != 1 { + t.Errorf("Expected 1 value, got %d", len(data)) + } +} + +func TestReceiveWithIDEqualLastID(t *testing.T) { + top := topic.New[string]() + top.Publish("v1") + lastID := top.Publish("v2") + + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond) + defer cancel() + + _, data, err := top.Receive(ctx, lastID) + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("Expected context deadline exceeded, got %v", err) + } + if data != nil { + t.Errorf("Expected nil data, got %v", data) + } +} + func TestPrune(t *testing.T) { for _, tt := range []struct { name string @@ -144,6 +193,25 @@ func TestPruneEmptyTopic(t *testing.T) { } } +func TestPruneAllElements(t *testing.T) { + top := topic.New[string]() + top.Publish("v1") + top.Publish("v2") + + top.Prune(time.Now()) + + if lastID := top.LastID(); lastID != 2 { + t.Errorf("LastID() = %d, expected 2", lastID) + } + + ctxCanceled, cancel := context.WithCancel(t.Context()) + cancel() + _, data, _ := top.Receive(ctxCanceled, 0) + if data != nil { + t.Errorf("Expected nil data after pruning all, got %v", data) + } +} + func TestPruneUsedValue(t *testing.T) { top := topic.New[string]() top.Publish("val1") @@ -162,6 +230,63 @@ func TestPruneUsedValue(t *testing.T) { } } +func TestPruneWithPastTime(t *testing.T) { + top := topic.New[string]() + + top.Publish("v1") + top.Publish("v2") + + pastTime := time.Now().Add(-1 * time.Hour) + top.Prune(pastTime) + + _, data, err := top.Receive(t.Context(), 0) + if err != nil { + t.Errorf("Receive() error: %v", err) + } + if len(data) != 2 { + t.Errorf("Expected 2 values after pruning past time, got %d", len(data)) + } +} + +func TestMultiplePrunes(t *testing.T) { + top := topic.New[string]() + top.Publish("v1") + top.Publish("v2") + t1 := time.Now() + top.Publish("v3") + t2 := time.Now() + top.Publish("v4") + + top.Prune(t1) + top.Prune(t2) + + _, data, err := top.Receive(t.Context(), 0) + if err != nil { + t.Errorf("Receive() error: %v", err) + } + if len(data) != 1 || data[0] != "v4" { + t.Errorf("After multiple prunes got %v, expected [v4]", data) + } +} + +func TestReceiveWithExactOffsetAfterPrune(t *testing.T) { + top := topic.New[string]() + top.Publish("v1") + top.Publish("v2") + ti := time.Now() + top.Publish("v3") + + top.Prune(ti) + + _, data, err := top.Receive(t.Context(), 2) + if err != nil { + t.Errorf("Receive() error: %v", err) + } + if len(data) != 1 || data[0] != "v3" { + t.Errorf("Receive(2) = %v, expected [v3]", data) + } +} + func TestErrUnknownID(t *testing.T) { top := topic.New[string]() top.Publish("v1") @@ -189,6 +314,58 @@ func TestErrUnknownID(t *testing.T) { } } +func TestReceiveWithFutureID(t *testing.T) { + top := topic.New[string]() + top.Publish("v1") + + done := make(chan struct{}) + go func() { + _, _, err := top.Receive(t.Context(), 100) + if err != nil { + t.Errorf("Receive() returned unexpected error: %v", err) + } + close(done) + }() + + timer := time.NewTimer(time.Millisecond) + defer timer.Stop() + select { + case <-done: + t.Error("Receive() should block when ID > lastID") + case <-timer.C: + top.Publish("v2") + } + + timer.Reset(100 * time.Millisecond) + select { + case <-done: + case <-timer.C: + t.Error("Receive() should unblock after Publish()") + } +} + +func TestReceiveReturnsCorrectID(t *testing.T) { + top := topic.New[string]() + top.Publish("v1") + top.Publish("v2") + + id, _, err := top.Receive(t.Context(), 0) + if err != nil { + t.Errorf("Receive() error: %v", err) + } + if id != 2 { + t.Errorf("Receive() returned id %d, expected 2", id) + } + + id, _, err = top.Receive(t.Context(), 1) + if err != nil { + t.Errorf("Receive() error: %v", err) + } + if id != 2 { + t.Errorf("Receive() returned id %d, expected 2", id) + } +} + func TestLastID(t *testing.T) { for _, tt := range []struct { name string @@ -226,11 +403,23 @@ func TestLastID(t *testing.T) { if got != tt.expect { t.Errorf("LastID() == %d, expected %d", got, tt.expect) } - }) } } +func TestLastIDAfterPrune(t *testing.T) { + top := topic.New[string]() + top.Publish("v1", "v2") + ti := time.Now() + top.Publish("v3") + + top.Prune(ti) + + if id := top.LastID(); id != 3 { + t.Errorf("LastID() after prune = %d, expected 3", id) + } +} + func TestReceiveBlocking(t *testing.T) { // Tests, that Receive() blocks until there is new data. top := topic.New[string]() @@ -339,6 +528,32 @@ func TestBlockOnHighestID(t *testing.T) { } } +func TestPruneDuringBlockedReceive(t *testing.T) { + top := topic.New[string]() + top.Publish("v1") + + done := make(chan struct{}) + go func() { + top.Receive(t.Context(), 1) // Blocking + close(done) + }() + + time.Sleep(time.Millisecond) + + // Prune while blocking + top.Prune(time.Now()) + + top.Publish("v2") + + timer := time.NewTimer(100 * time.Millisecond) + defer timer.Stop() + select { + case <-done: + case <-timer.C: + t.Error("Receive() should unblock after Publish()") + } +} + func TestReceiveOnCanceledChannel(t *testing.T) { top := topic.New[string]() top.Publish("v1") @@ -378,6 +593,64 @@ func TestReceiveOnCanceledChannel(t *testing.T) { } } +func TestConcurrentPublishes(t *testing.T) { + top := topic.New[int]() + + const numPublishers = 10 + const publishesPerPublisher = 100 + + var wg sync.WaitGroup + for i := range numPublishers { + wg.Go(func() { + for j := range publishesPerPublisher { + top.Publish(i*publishesPerPublisher + j) + } + }) + } + + wg.Wait() + + _, data, err := top.Receive(t.Context(), 0) + if err != nil { + t.Errorf("Receive() error: %v", err) + } + + expectedLen := numPublishers * publishesPerPublisher + if len(data) != expectedLen { + t.Errorf("Expected %d values, got %d", expectedLen, len(data)) + } +} + +func TestMultipleConcurrentReceives(t *testing.T) { + top := topic.New[string]() + + const numReceivers = 100 + var wg sync.WaitGroup + + results := make([][]string, numReceivers) + for i := range numReceivers { + wg.Go(func() { + _, data, err := top.Receive(t.Context(), 0) + if err != nil { + t.Errorf("Receive() error: %v", err) + return + } + results[i] = data + }) + } + + time.Sleep(time.Millisecond) + top.Publish("value") + + wg.Wait() + + for i, result := range results { + if len(result) != 1 || result[0] != "value" { + t.Errorf("Receiver %d got %v, expected [value]", i, result) + } + } +} + func TestTopicWithStruct(t *testing.T) { type myType struct { number int @@ -418,6 +691,19 @@ func TestTopicWithPointer(t *testing.T) { } } +func TestTopicWithNilPointers(t *testing.T) { + top := topic.New[*string]() + top.Publish(nil, nil) + + _, data, err := top.Receive(t.Context(), 0) + if err != nil { + t.Errorf("Receive() error: %v", err) + } + if len(data) != 2 || data[0] != nil || data[1] != nil { + t.Errorf("Expected [nil, nil], got %v", data) + } +} + func cmpSlice(one, two []string) bool { if len(one) != len(two) { return false