Skip to content

feat: Introducing StreamingCredentialsProvider for token based authentication #3320

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 37 commits into from
May 27, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
5410adb
wip
ndyakov Mar 18, 2025
9ef438b
Merge remote-tracking branch 'origin/master' into ndyakov/token-based…
ndyakov Mar 24, 2025
df9bfce
update documentation
ndyakov Mar 24, 2025
140a278
add streamingcredentialsprovider in options
ndyakov Mar 24, 2025
d3a25f9
Merge branch 'master' into ndyakov/token-based-auth
ndyakov Mar 24, 2025
7f5d87b
fix: put back option in pool creation
ndyakov Mar 24, 2025
fa59cce
add package level comment
ndyakov Mar 24, 2025
c248425
Merge branch 'master' into ndyakov/token-based-auth
ndyakov Mar 31, 2025
847f1f9
Merge branch 'master' into ndyakov/token-based-auth
ndyakov Apr 3, 2025
40a89c5
Initial re authentication implementation
ndyakov Apr 15, 2025
d0a8c76
Change function type name
ndyakov Apr 16, 2025
e0c224d
Merge branch 'master' into ndyakov/token-based-auth
ndyakov Apr 22, 2025
44628c5
add tests
ndyakov Apr 22, 2025
4ab4980
fix race in tests
ndyakov Apr 22, 2025
420c4fb
fix example tests
ndyakov Apr 22, 2025
5fac913
wip, hooks refactor
ndyakov Apr 22, 2025
2a97f2e
fix build
ndyakov Apr 22, 2025
f103a7d
update README.md
ndyakov Apr 22, 2025
6e17fb4
Merge branch 'master' into ndyakov/token-based-auth
ndyakov Apr 22, 2025
3acfb1c
update wordlist
ndyakov Apr 22, 2025
7eea9e7
update README.md
ndyakov Apr 22, 2025
5f91e66
Merge branch 'master' into ndyakov/token-based-auth
ndyakov Apr 23, 2025
d0bfdab
refactor(auth): early returns in cred listener
ndyakov Apr 24, 2025
f6f892d
Merge branch 'master' into ndyakov/token-based-auth
ndyakov Apr 24, 2025
544bdb2
fix(doctest): simulate some delay
ndyakov Apr 24, 2025
cff6b9b
Merge branch 'master' into ndyakov/token-based-auth
ndyakov Apr 29, 2025
8f05aef
Merge branch 'master' into ndyakov/token-based-auth
ndyakov May 12, 2025
c5054e2
feat(conn): add close hook on conn
ndyakov May 13, 2025
8b51596
fix(tests): simulate start/stop in mock credentials provider
ndyakov May 13, 2025
b80969f
fix(auth): don't double close the conn
ndyakov May 14, 2025
a6a2c9d
docs(README): mark streaming credentials provider as experimental
ndyakov May 14, 2025
5228611
fix(auth): streamline auth err proccess
ndyakov May 14, 2025
3345fd1
fix(auth): check err on close conn
ndyakov May 14, 2025
1628b87
Merge branch 'master' into ndyakov/token-based-auth
ndyakov May 19, 2025
57584be
Merge branch 'master' into ndyakov/token-based-auth
ndyakov May 19, 2025
45e5ee9
chore(entraid): use the repo under redis org
ndyakov May 20, 2025
b7ce3cd
Merge branch 'master' into ndyakov/token-based-auth
ndyakov May 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add tests
  • Loading branch information
ndyakov committed Apr 22, 2025
commit 44628c5dbd47868d8ad62626269a6cebfbba6bfd
302 changes: 302 additions & 0 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
package auth

import (
"errors"
"testing"
"time"
)

type mockStreamingProvider struct {
credentials Credentials
err error
updates chan Credentials
}

func newMockStreamingProvider(initialCreds Credentials) *mockStreamingProvider {
return &mockStreamingProvider{
credentials: initialCreds,
updates: make(chan Credentials, 10),
}
}

func (m *mockStreamingProvider) Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) {
if m.err != nil {
return nil, nil, m.err
}

// Send initial credentials
listener.OnNext(m.credentials)

// Start goroutine to handle updates
go func() {
for creds := range m.updates {
listener.OnNext(creds)
}
}()

return m.credentials, func() error {
close(m.updates)
return nil
}, nil
}

func TestStreamingCredentialsProvider(t *testing.T) {
t.Run("successful subscription", func(t *testing.T) {
initialCreds := NewBasicCredentials("user1", "pass1")
provider := newMockStreamingProvider(initialCreds)

var receivedCreds []Credentials
var receivedErrors []error

listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
receivedCreds = append(receivedCreds, creds)
return nil
},
func(err error) {
receivedErrors = append(receivedErrors, err)
},
)

creds, cancel, err := provider.Subscribe(listener)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cancel == nil {
t.Fatal("expected cancel function to be non-nil")
}
if creds != initialCreds {
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
}
if len(receivedCreds) != 1 {
t.Fatalf("expected 1 received credential, got %d", len(receivedCreds))
}
if receivedCreds[0] != initialCreds {
t.Fatalf("expected received credential %v, got %v", initialCreds, receivedCreds[0])
}
if len(receivedErrors) != 0 {
t.Fatalf("expected no errors, got %d", len(receivedErrors))
}

// Send an update
newCreds := NewBasicCredentials("user2", "pass2")
provider.updates <- newCreds

// Wait for update to be processed
time.Sleep(100 * time.Millisecond)
if len(receivedCreds) != 2 {
t.Fatalf("expected 2 received credentials, got %d", len(receivedCreds))
}
if receivedCreds[1] != newCreds {
t.Fatalf("expected received credential %v, got %v", newCreds, receivedCreds[1])
}

// Cancel subscription
if err := cancel(); err != nil {
t.Fatalf("unexpected error cancelling subscription: %v", err)
}
})

t.Run("subscription error", func(t *testing.T) {
provider := &mockStreamingProvider{
err: errors.New("subscription failed"),
}

var receivedCreds []Credentials
var receivedErrors []error

listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
receivedCreds = append(receivedCreds, creds)
return nil
},
func(err error) {
receivedErrors = append(receivedErrors, err)
},
)

creds, cancel, err := provider.Subscribe(listener)
if err == nil {
t.Fatal("expected error, got nil")
}
if cancel != nil {
t.Fatal("expected cancel function to be nil")
}
if creds != nil {
t.Fatalf("expected nil credentials, got %v", creds)
}
if len(receivedCreds) != 0 {
t.Fatalf("expected no received credentials, got %d", len(receivedCreds))
}
if len(receivedErrors) != 0 {
t.Fatalf("expected no errors, got %d", len(receivedErrors))
}
})

t.Run("re-auth error", func(t *testing.T) {
initialCreds := NewBasicCredentials("user1", "pass1")
provider := newMockStreamingProvider(initialCreds)

reauthErr := errors.New("re-auth failed")
var receivedErrors []error

listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
return reauthErr
},
func(err error) {
receivedErrors = append(receivedErrors, err)
},
)

creds, cancel, err := provider.Subscribe(listener)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cancel == nil {
t.Fatal("expected cancel function to be non-nil")
}
if creds != initialCreds {
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
}
if len(receivedErrors) != 1 {
t.Fatalf("expected 1 error, got %d", len(receivedErrors))
}
if receivedErrors[0] != reauthErr {
t.Fatalf("expected error %v, got %v", reauthErr, receivedErrors[0])
}

if err := cancel(); err != nil {
t.Fatalf("unexpected error cancelling subscription: %v", err)
}
})
}

func TestBasicCredentials(t *testing.T) {
t.Run("basic auth", func(t *testing.T) {
creds := NewBasicCredentials("user1", "pass1")
username, password := creds.BasicAuth()
if username != "user1" {
t.Fatalf("expected username 'user1', got '%s'", username)
}
if password != "pass1" {
t.Fatalf("expected password 'pass1', got '%s'", password)
}
})

t.Run("raw credentials", func(t *testing.T) {
creds := NewBasicCredentials("user1", "pass1")
raw := creds.RawCredentials()
expected := "user1:pass1"
if raw != expected {
t.Fatalf("expected raw credentials '%s', got '%s'", expected, raw)
}
})

t.Run("empty username", func(t *testing.T) {
creds := NewBasicCredentials("", "pass1")
username, password := creds.BasicAuth()
if username != "" {
t.Fatalf("expected empty username, got '%s'", username)
}
if password != "pass1" {
t.Fatalf("expected password 'pass1', got '%s'", password)
}
})
}

func TestReAuthCredentialsListener(t *testing.T) {
t.Run("successful re-auth", func(t *testing.T) {
var reAuthCalled bool
var onErrCalled bool
var receivedCreds Credentials

listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
reAuthCalled = true
receivedCreds = creds
return nil
},
func(err error) {
onErrCalled = true
},
)

creds := NewBasicCredentials("user1", "pass1")
listener.OnNext(creds)

if !reAuthCalled {
t.Fatal("expected reAuth to be called")
}
if onErrCalled {
t.Fatal("expected onErr not to be called")
}
if receivedCreds != creds {
t.Fatalf("expected credentials %v, got %v", creds, receivedCreds)
}
})

t.Run("re-auth error", func(t *testing.T) {
var reAuthCalled bool
var onErrCalled bool
var receivedErr error
expectedErr := errors.New("re-auth failed")

listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
reAuthCalled = true
return expectedErr
},
func(err error) {
onErrCalled = true
receivedErr = err
},
)

creds := NewBasicCredentials("user1", "pass1")
listener.OnNext(creds)

if !reAuthCalled {
t.Fatal("expected reAuth to be called")
}
if !onErrCalled {
t.Fatal("expected onErr to be called")
}
if receivedErr != expectedErr {
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
}
})

t.Run("on error", func(t *testing.T) {
var onErrCalled bool
var receivedErr error
expectedErr := errors.New("provider error")

listener := NewReAuthCredentialsListener(
func(creds Credentials) error {
return nil
},
func(err error) {
onErrCalled = true
receivedErr = err
},
)

listener.OnError(expectedErr)

if !onErrCalled {
t.Fatal("expected onErr to be called")
}
if receivedErr != expectedErr {
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
}
})

t.Run("nil callbacks", func(t *testing.T) {
listener := NewReAuthCredentialsListener(nil, nil)

// Should not panic
listener.OnNext(NewBasicCredentials("user1", "pass1"))
listener.OnError(errors.New("test error"))
})
}
86 changes: 86 additions & 0 deletions command_recorder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package redis_test

import (
"context"
"strings"
"sync"

"github.com/redis/go-redis/v9"
)

// commandRecorder records the last N commands executed by a Redis client.
type commandRecorder struct {
mu sync.Mutex
commands []string
maxSize int
}

// newCommandRecorder creates a new command recorder with the specified maximum size.
func newCommandRecorder(maxSize int) *commandRecorder {
return &commandRecorder{
commands: make([]string, 0, maxSize),
maxSize: maxSize,
}
}

// Record adds a command to the recorder.
func (r *commandRecorder) Record(cmd string) {
cmd = strings.ToLower(cmd)
r.mu.Lock()
defer r.mu.Unlock()

r.commands = append(r.commands, cmd)
if len(r.commands) > r.maxSize {
r.commands = r.commands[1:]
}
}

// LastCommands returns a copy of the recorded commands.
func (r *commandRecorder) LastCommands() []string {
r.mu.Lock()
defer r.mu.Unlock()
return append([]string(nil), r.commands...)
}

// Contains checks if the recorder contains a specific command.
func (r *commandRecorder) Contains(cmd string) bool {
cmd = strings.ToLower(cmd)
r.mu.Lock()
defer r.mu.Unlock()
for _, c := range r.commands {
if strings.Contains(c, cmd) {
return true
}
}
return false
}

// Hook returns a Redis hook that records commands.
func (r *commandRecorder) Hook() redis.Hook {
return &commandHook{recorder: r}
}

// commandHook implements the redis.Hook interface to record commands.
type commandHook struct {
recorder *commandRecorder
}

func (h *commandHook) DialHook(next redis.DialHook) redis.DialHook {
return next
}

func (h *commandHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
return func(ctx context.Context, cmd redis.Cmder) error {
h.recorder.Record(cmd.String())
return next(ctx, cmd)
}
}

func (h *commandHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook {
return func(ctx context.Context, cmds []redis.Cmder) error {
for _, cmd := range cmds {
h.recorder.Record(cmd.String())
}
return next(ctx, cmds)
}
}
2 changes: 2 additions & 0 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"github.com/redis/go-redis/v9/internal/rand"
)

type ParentHooksMixinKey struct{}

func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration {
if retry < 0 {
panic("not reached")
Expand Down
Loading
Loading