From 9afead2daf3d0694d4a9dc0d6a3a1fa567125f1b Mon Sep 17 00:00:00 2001 From: Igor Serganov Date: Fri, 27 Mar 2026 16:57:36 -0700 Subject: [PATCH] Lock.Refresh method: use provided RetryStrategy At the moment Refresh method doesn't do any retries and returns immediately if Redis client encounters an error. I find it misleading since Refresh accepts Options which contain RetryStrategy. So why not to use it? --- go.mod | 11 +- go.sum | 32 +++- redislock.go | 50 +++++-- refresh_test.go | 387 ++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 460 insertions(+), 20 deletions(-) create mode 100644 refresh_test.go diff --git a/go.mod b/go.mod index 5c906cd..306cf7a 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,15 @@ module github.com/bsm/redislock -go 1.17 +go 1.23 -require github.com/redis/go-redis/v9 v9.0.3 +require ( + github.com/alicebob/miniredis/v2 v2.37.0 + github.com/redis/go-redis/v9 v9.18.0 +) require ( - github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect + go.uber.org/atomic v1.11.0 // indirect ) diff --git a/go.sum b/go.sum index 02c58f3..8f56d2d 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,26 @@ -github.com/bsm/ginkgo/v2 v2.7.0 h1:ItPMPH90RbmZJt5GtkcNvIRuGEdwlBItdNVoyzaNQao= -github.com/bsm/ginkgo/v2 v2.7.0/go.mod h1:AiKlXPm7ItEHNc/2+OkrNG4E0ITzojb9/xWzvQ9XZ9w= -github.com/bsm/gomega v1.26.0 h1:LhQm+AFcgV2M0WyKroMASzAzCAJVpAxQXv4SaI9a69Y= -github.com/bsm/gomega v1.26.0/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/alicebob/miniredis/v2 v2.37.0 h1:RheObYW32G1aiJIj81XVt78ZHJpHonHLHW7OLIshq68= +github.com/alicebob/miniredis/v2 v2.37.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= -github.com/redis/go-redis/v9 v9.0.3 h1:+7mmR26M0IvyLxGZUHxu4GiBkJkVDid0Un+j4ScYu4k= -github.com/redis/go-redis/v9 v9.0.3/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= +github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= diff --git a/redislock.go b/redislock.go index 3d620e3..96fb2e8 100644 --- a/redislock.go +++ b/redislock.go @@ -80,7 +80,7 @@ func (c *Client) ObtainMulti(ctx context.Context, keys []string, ttl time.Durati } value := token + opt.getMetadata() - ttlVal := strconv.FormatInt(int64(ttl/time.Millisecond), 10) + ttlVal := strconv.FormatInt(ttl.Milliseconds(), 10) retry := opt.getRetryStrategy() // make sure we don't retry forever @@ -91,6 +91,11 @@ func (c *Client) ObtainMulti(ctx context.Context, keys []string, ttl time.Durati } var ticker *time.Ticker + defer func() { + if ticker != nil { + ticker.Stop() + } + }() for { ok, err := c.obtain(ctx, keys, value, len(token), ttlVal) if err != nil { @@ -106,7 +111,6 @@ func (c *Client) ObtainMulti(ctx context.Context, keys []string, ttl time.Durati if ticker == nil { ticker = time.NewTicker(backoff) - defer ticker.Stop() } else { ticker.Reset(backoff) } @@ -205,20 +209,44 @@ func (l *Lock) TTL(ctx context.Context) (time.Duration, error) { } // Refresh extends the lock with a new TTL. -// May return ErrNotObtained if refresh is unsuccessful. +// May return ErrNotObtained if refresh is unsuccessful func (l *Lock) Refresh(ctx context.Context, ttl time.Duration, opt *Options) error { if l == nil { return ErrNotObtained } - ttlVal := strconv.FormatInt(int64(ttl/time.Millisecond), 10) - _, err := luaRefresh.Run(ctx, l.client, l.keys, l.value, ttlVal).Result() - if err != nil { - if errors.Is(err, redis.Nil) { + ttlVal := strconv.FormatInt(ttl.Milliseconds(), 10) + retry := opt.getRetryStrategy() + var ticker *time.Ticker + defer func() { + if ticker != nil { + ticker.Stop() + } + }() + for { + _, err := luaRefresh.Run(ctx, l.client, l.keys, l.value, ttlVal).Result() + if err == nil { + return nil + } + backoff := retry.NextBackoff() + // if the lock is not held, return ErrNotObtained without retrying + if errors.Is(err, redis.Nil) || backoff < 1 { return ErrNotObtained } - return err + + if !isRetryableRedisError(err) { + return err + } + if ticker == nil { + ticker = time.NewTicker(backoff) + } else { + ticker.Reset(backoff) + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } } - return nil } // Release manually releases the lock. @@ -345,3 +373,7 @@ func (r *exponentialBackoff) NextBackoff() time.Duration { return d } } + +func isRetryableRedisError(err error) bool { + return !redis.IsPermissionError(err) +} diff --git a/refresh_test.go b/refresh_test.go new file mode 100644 index 0000000..8003518 --- /dev/null +++ b/refresh_test.go @@ -0,0 +1,387 @@ +package redislock + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" +) + +func setupMiniRedis(t *testing.T) (*miniredis.Miniredis, *redis.Client) { + t.Helper() + mr, err := miniredis.Run() + if err != nil { + t.Fatalf("miniredis: %v", err) + } + rc := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + t.Cleanup(func() { rc.Close(); mr.Close() }) + return mr, rc +} + +func setLock(t *testing.T, rc *redis.Client, keys []string, value string, ttl time.Duration) *Lock { + t.Helper() + for _, k := range keys { + if err := rc.Set(context.Background(), k, value, ttl).Err(); err != nil { + t.Fatalf("SET %q: %v", k, err) + } + } + return &Lock{Client: &Client{client: rc}, keys: keys, value: value, tokenLen: len(value)} +} + +func assertTTLRange(t *testing.T, rc *redis.Client, key string, lo, hi time.Duration) { + t.Helper() + d, err := rc.PTTL(context.Background(), key).Result() + if err != nil { + t.Fatalf("PTTL: %v", err) + } + if d < lo || d > hi { + t.Fatalf("key %s: expected TTL in [%v, %v], got %v", key, lo, hi, d) + } +} + +func assertKeyGone(t *testing.T, rc *redis.Client, key string) { + t.Helper() + n, err := rc.Exists(context.Background(), key).Result() + if err != nil { + t.Fatal(err) + } + if n != 0 { + t.Fatalf("expected key %s to be gone", key) + } +} + +func TestRefresh(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T, mr *miniredis.Miniredis, rc *redis.Client) *Lock + ttl time.Duration + opt *Options + wantErr error + verify func(t *testing.T, mr *miniredis.Miniredis, rc *redis.Client) + }{ + { + name: "nil receiver", + setup: func(*testing.T, *miniredis.Miniredis, *redis.Client) *Lock { return nil }, + ttl: time.Minute, + wantErr: ErrNotObtained, + }, + { + name: "success with nil options sets new TTL", + setup: func(t *testing.T, _ *miniredis.Miniredis, rc *redis.Client) *Lock { + return setLock(t, rc, []string{"k"}, "v", time.Hour) + }, + ttl: time.Minute, + verify: func(t *testing.T, _ *miniredis.Miniredis, rc *redis.Client) { + assertTTLRange(t, rc, "k", 59*time.Second, time.Minute) + }, + }, + { + name: "extends TTL", + setup: func(t *testing.T, _ *miniredis.Miniredis, rc *redis.Client) *Lock { + return setLock(t, rc, []string{"k"}, "v", time.Minute) + }, + ttl: time.Hour, opt: &Options{}, + verify: func(t *testing.T, _ *miniredis.Miniredis, rc *redis.Client) { + assertTTLRange(t, rc, "k", 59*time.Minute, time.Hour) + }, + }, + { + name: "expired lock", + setup: func(t *testing.T, mr *miniredis.Miniredis, rc *redis.Client) *Lock { + l := setLock(t, rc, []string{"k"}, "v", time.Millisecond) + mr.FastForward(10 * time.Millisecond) + return l + }, + ttl: time.Minute, wantErr: ErrNotObtained, + }, + { + name: "value mismatch (taken by someone else)", + setup: func(t *testing.T, _ *miniredis.Miniredis, rc *redis.Client) *Lock { + l := setLock(t, rc, []string{"k"}, "v", time.Hour) + rc.Set(context.Background(), "k", "other", time.Hour) + return l + }, + ttl: time.Minute, wantErr: ErrNotObtained, + }, + { + name: "deleted key", + setup: func(t *testing.T, _ *miniredis.Miniredis, rc *redis.Client) *Lock { + l := setLock(t, rc, []string{"k"}, "v", time.Hour) + rc.Del(context.Background(), "k") + return l + }, + ttl: time.Minute, wantErr: ErrNotObtained, + }, + { + name: "expired lock with NoRetry", + setup: func(t *testing.T, mr *miniredis.Miniredis, rc *redis.Client) *Lock { + l := setLock(t, rc, []string{"k"}, "v", time.Millisecond) + mr.FastForward(10 * time.Millisecond) + return l + }, + ttl: time.Minute, opt: &Options{RetryStrategy: NoRetry()}, wantErr: ErrNotObtained, + }, + { + name: "expired lock with retry strategy still fails", + setup: func(t *testing.T, mr *miniredis.Miniredis, rc *redis.Client) *Lock { + l := setLock(t, rc, []string{"k"}, "v", time.Millisecond) + mr.FastForward(10 * time.Millisecond) + return l + }, + ttl: time.Minute, opt: &Options{RetryStrategy: LimitRetry(LinearBackoff(time.Millisecond), 3)}, wantErr: ErrNotObtained, + }, + { + name: "sequential refreshes", + setup: func(t *testing.T, _ *miniredis.Miniredis, rc *redis.Client) *Lock { + return setLock(t, rc, []string{"k"}, "v", time.Hour) + }, + ttl: 30 * time.Second, + verify: func(t *testing.T, _ *miniredis.Miniredis, rc *redis.Client) { + assertTTLRange(t, rc, "k", 29*time.Second, 30*time.Second) + l := &Lock{Client: &Client{client: rc}, keys: []string{"k"}, value: "v", tokenLen: 1} + if err := l.Refresh(context.Background(), 2*time.Minute, nil); err != nil { + t.Fatal(err) + } + assertTTLRange(t, rc, "k", 119*time.Second, 2*time.Minute) + }, + }, + { + name: "preserves value", + setup: func(t *testing.T, _ *miniredis.Miniredis, rc *redis.Client) *Lock { + return setLock(t, rc, []string{"k"}, "tok+meta", time.Hour) + }, + ttl: time.Minute, + verify: func(t *testing.T, _ *miniredis.Miniredis, rc *redis.Client) { + v, _ := rc.Get(context.Background(), "k").Result() + if v != "tok+meta" { + t.Fatalf("expected %q, got %q", "tok+meta", v) + } + }, + }, + { + name: "small TTL expires quickly", + setup: func(t *testing.T, _ *miniredis.Miniredis, rc *redis.Client) *Lock { + return setLock(t, rc, []string{"k"}, "v", time.Hour) + }, + ttl: time.Millisecond, + verify: func(t *testing.T, mr *miniredis.Miniredis, rc *redis.Client) { + mr.FastForward(10 * time.Millisecond) + assertKeyGone(t, rc, "k") + }, + }, + { + name: "zero TTL expires immediately", + setup: func(t *testing.T, _ *miniredis.Miniredis, rc *redis.Client) *Lock { + return setLock(t, rc, []string{"k"}, "v", time.Hour) + }, + ttl: 0, + verify: func(t *testing.T, mr *miniredis.Miniredis, rc *redis.Client) { + mr.FastForward(time.Millisecond) + assertKeyGone(t, rc, "k") + }, + }, + { + name: "multi-key fails when one key missing", + setup: func(t *testing.T, _ *miniredis.Miniredis, rc *redis.Client) *Lock { + l := setLock(t, rc, []string{"a", "b"}, "v", time.Hour) + rc.Del(context.Background(), "b") + return l + }, + ttl: time.Minute, wantErr: ErrNotObtained, + }, + { + name: "multi-key fails when one key has wrong value", + setup: func(t *testing.T, _ *miniredis.Miniredis, rc *redis.Client) *Lock { + l := setLock(t, rc, []string{"a", "b"}, "v", time.Hour) + rc.Set(context.Background(), "a", "x", time.Hour) + return l + }, + ttl: time.Minute, wantErr: ErrNotObtained, + }, + { + name: "context deadline during retry backoff", + setup: func(t *testing.T, mr *miniredis.Miniredis, rc *redis.Client) *Lock { + l := setLock(t, rc, []string{"k"}, "v", time.Millisecond) + mr.FastForward(10 * time.Millisecond) + return l + }, + ttl: time.Minute, opt: &Options{RetryStrategy: LinearBackoff(200 * time.Millisecond)}, wantErr: ErrNotObtained, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mr, rc := setupMiniRedis(t) + lock := tc.setup(t, mr, rc) + err := lock.Refresh(context.Background(), tc.ttl, tc.opt) + if tc.wantErr != nil { + if !errors.Is(err, tc.wantErr) { + t.Fatalf("expected %v, got %v", tc.wantErr, err) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc.verify != nil { + tc.verify(t, mr, rc) + } + }) + } +} + +// interceptingScripter wraps a real RedisClient, calling hook before each +// Eval/EvalSha. If hook returns a non-nil *redis.Cmd, that is returned +// instead of delegating to the real client. +type interceptingScripter struct { + real RedisClient + hook func(ctx context.Context) *redis.Cmd +} + +func (s *interceptingScripter) eval(ctx context.Context, fn func() *redis.Cmd) *redis.Cmd { + if cmd := s.hook(ctx); cmd != nil { + return cmd + } + return fn() +} +func (s *interceptingScripter) Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd { + return s.eval(ctx, func() *redis.Cmd { return s.real.Eval(ctx, script, keys, args...) }) +} +func (s *interceptingScripter) EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd { + return s.eval(ctx, func() *redis.Cmd { return s.real.EvalSha(ctx, sha1, keys, args...) }) +} +func (s *interceptingScripter) EvalRO(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd { + return s.eval(ctx, func() *redis.Cmd { return s.real.EvalRO(ctx, script, keys, args...) }) +} +func (s *interceptingScripter) EvalShaRO(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd { + return s.eval(ctx, func() *redis.Cmd { return s.real.EvalShaRO(ctx, sha1, keys, args...) }) +} +func (s *interceptingScripter) ScriptExists(ctx context.Context, hashes ...string) *redis.BoolSliceCmd { + return s.real.ScriptExists(ctx, hashes...) +} +func (s *interceptingScripter) ScriptLoad(ctx context.Context, script string) *redis.StringCmd { + return s.real.ScriptLoad(ctx, script) +} + +func errCmd(ctx context.Context, err error) *redis.Cmd { + cmd := redis.NewCmd(ctx) + cmd.SetErr(err) + return cmd +} + +func failNTimes(n int32, injectedErr error) func(context.Context) *redis.Cmd { + var count int32 + return func(ctx context.Context) *redis.Cmd { + if atomic.AddInt32(&count, 1) <= n { + return errCmd(ctx, injectedErr) + } + return nil + } +} + +func lockWithClient(client RedisClient, key, value string) *Lock { + return &Lock{Client: &Client{client: client}, keys: []string{key}, value: value, tokenLen: len(value)} +} + +func TestRefresh_TransientErrors(t *testing.T) { + loadingErr := errors.New("LOADING Redis is loading the dataset in memory") + + t.Run("succeeds after transient errors", func(t *testing.T) { + _, rc := setupMiniRedis(t) + rc.Set(context.Background(), "k", "v", time.Hour) + w := &interceptingScripter{real: rc, hook: failNTimes(2, loadingErr)} + err := lockWithClient(w, "k", "v").Refresh(context.Background(), time.Minute, &Options{ + RetryStrategy: LimitRetry(LinearBackoff(time.Millisecond), 5), + }) + if err != nil { + t.Fatalf("expected success, got %v", err) + } + assertTTLRange(t, rc, "k", 59*time.Second, time.Minute) + }) + + t.Run("exhausts retries", func(t *testing.T) { + _, rc := setupMiniRedis(t) + rc.Set(context.Background(), "k", "v", time.Hour) + w := &interceptingScripter{real: rc, hook: failNTimes(100, loadingErr)} + err := lockWithClient(w, "k", "v").Refresh(context.Background(), time.Minute, &Options{ + RetryStrategy: LimitRetry(LinearBackoff(time.Millisecond), 2), + }) + if !errors.Is(err, ErrNotObtained) { + t.Fatalf("expected ErrNotObtained, got %v", err) + } + }) + + t.Run("non-retryable error returned directly", func(t *testing.T) { + _, rc := setupMiniRedis(t) + rc.Set(context.Background(), "k", "v", time.Hour) + permErr := errors.New("NOPERM user has no permissions to run 'evalsha' command") + w := &interceptingScripter{real: rc, hook: failNTimes(1, permErr)} + err := lockWithClient(w, "k", "v").Refresh(context.Background(), time.Minute, &Options{ + RetryStrategy: LinearBackoff(time.Millisecond), + }) + if err == nil || err.Error() != permErr.Error() { + t.Fatalf("expected permission error, got %v", err) + } + }) + + t.Run("context cancelled during retry attempt", func(t *testing.T) { + _, rc := setupMiniRedis(t) + rc.Set(context.Background(), "k", "v", time.Hour) + ctx, cancel := context.WithCancel(context.Background()) + var count int32 + w := &interceptingScripter{real: rc, hook: func(c context.Context) *redis.Cmd { + if atomic.AddInt32(&count, 1) == 2 { + cancel() + return errCmd(c, loadingErr) + } + return nil + }} + err := lockWithClient(w, "k", "v").Refresh(ctx, time.Minute, &Options{ + RetryStrategy: LinearBackoff(time.Millisecond), + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } + }) +} + +func TestRefresh_RetryCallCount(t *testing.T) { + t.Run("NoRetry calls NextBackoff once on failure", func(t *testing.T) { + mr, rc := setupMiniRedis(t) + setLock(t, rc, []string{"k"}, "v", time.Millisecond) + mr.FastForward(10 * time.Millisecond) + n := 0 + spy := &spyRetry{inner: NoRetry(), onCall: func() { n++ }} + _ = (&Lock{Client: &Client{client: rc}, keys: []string{"k"}, value: "v", tokenLen: 1}). + Refresh(context.Background(), time.Minute, &Options{RetryStrategy: spy}) + if n != 1 { + t.Fatalf("expected 1 call, got %d", n) + } + }) + + t.Run("success never calls NextBackoff", func(t *testing.T) { + _, rc := setupMiniRedis(t) + setLock(t, rc, []string{"k"}, "v", time.Hour) + n := 0 + spy := &spyRetry{inner: LinearBackoff(time.Millisecond), onCall: func() { n++ }} + _ = (&Lock{Client: &Client{client: rc}, keys: []string{"k"}, value: "v", tokenLen: 1}). + Refresh(context.Background(), time.Minute, &Options{RetryStrategy: spy}) + if n != 0 { + t.Fatalf("expected 0 calls, got %d", n) + } + }) +} + +type spyRetry struct { + inner RetryStrategy + onCall func() +} + +func (s *spyRetry) NextBackoff() time.Duration { + s.onCall() + return s.inner.NextBackoff() +}