Enforce unique job keys #50

Merged · 5 commits · Jan 22, 2025
Changes from 1 commit
Enforce unique job keys
This commit does two things:
1. It adds a new `SetIfNotExists` method to replicated maps that sets a value only if the key is not already present in the map.
2. It ensures the pool node `DispatchJob` method returns the new `ErrJobExists` error when attempting to queue a job whose key already exists.

Previously it was the client's responsibility to ensure that no two calls to `DispatchJob` used the same key; the pool package now enforces this. A caller-side sketch is shown below.
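The following is a minimal caller-side sketch of the new behavior, not part of the commit: the `goa.design/pulse/pool` import path and the `dispatchOnce` helper name are assumptions made for illustration, while `DispatchJob`, `ErrJobExists`, and the wrapped-error form come from the diff below.

```go
package example

import (
	"context"
	"errors"
	"fmt"
	"log"

	"goa.design/pulse/pool" // assumed import path for the pool package
)

// dispatchOnce is a hypothetical helper: it dispatches a job and treats a
// duplicate key as a non-fatal condition now that the pool enforces key
// uniqueness, instead of requiring callers to deduplicate keys themselves.
func dispatchOnce(ctx context.Context, node *pool.Node, key string, payload []byte) error {
	err := node.DispatchJob(ctx, key, payload)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, pool.ErrJobExists):
		// A job with this key already exists or is currently being
		// dispatched by another node; skip it rather than failing.
		log.Printf("skipping duplicate job %q: %v", key, err)
		return nil
	default:
		return fmt.Errorf("dispatch %q: %w", key, err)
	}
}
```

Because `DispatchJob` wraps the sentinel with `fmt.Errorf("%w: ...")`, `errors.Is` is the appropriate way to detect it, which is also how the new tests check for it.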
raphael committed Jan 21, 2025
commit 23ebcc69b4c46d7a821f90c8eab7eb4625c52223
89 changes: 79 additions & 10 deletions pool/node.go
@@ -34,6 +34,7 @@ type (
nodeReader *streaming.Reader // node event reader
workerMap *rmap.Map // worker creation times by ID
jobsMap *rmap.Map // jobs by worker ID
pendingJobsMap *rmap.Map // pending jobs by job key
jobPayloadsMap *rmap.Map // job payloads by job key
nodeKeepAliveMap *rmap.Map // node keep-alive timestamps indexed by ID
workerKeepAliveMap *rmap.Map // worker keep-alive timestamps indexed by ID
@@ -50,10 +51,10 @@ type (
wg sync.WaitGroup // allows to wait until all goroutines exit
rdb *redis.Client

localWorkers sync.Map // workers created by this node
workerStreams sync.Map // worker streams indexed by ID
pendingJobs sync.Map // channels used to send DispatchJob results, nil if event is requeued
pendingEvents sync.Map // pending events indexed by sender and event IDs
localWorkers sync.Map // workers created by this node
workerStreams sync.Map // worker streams indexed by ID
pendingJobChannels sync.Map // channels used to send DispatchJob results, nil if event is requeued
pendingEvents sync.Map // pending events indexed by sender and event IDs

lock sync.RWMutex
closing bool
@@ -88,6 +89,9 @@ const (
// pendingEventTTL is the TTL for pending events.
var pendingEventTTL = 2 * time.Minute

// ErrJobExists is returned when attempting to dispatch a job with a key that already exists.
var ErrJobExists = errors.New("job already exists")

// AddNode adds a new node to the pool with the given name and returns it. The
// node can be used to dispatch jobs and add new workers. A node also routes
// dispatched jobs to the proper worker and acks the corresponding events once
@@ -139,6 +143,7 @@ func AddNode(ctx context.Context, poolName string, rdb *redis.Client, opts ...No
wm *rmap.Map
jm *rmap.Map
jpm *rmap.Map
pjm *rmap.Map
km *rmap.Map
tm *rmap.Map

@@ -176,6 +181,12 @@ func AddNode(ctx context.Context, poolName string, rdb *redis.Client, opts ...No
return nil, fmt.Errorf("AddNode: failed to join pool ticker replicated map %q: %w", tickerMapName(poolName), err)
}

// Initialize and join pending jobs map
pjm, err = rmap.Join(ctx, pendingJobsMapName(poolName), rdb, rmap.WithLogger(logger))
if err != nil {
return nil, fmt.Errorf("AddNode: failed to join pending jobs replicated map %q: %w", pendingJobsMapName(poolName), err)
}

poolSink, err = poolStream.NewSink(ctx, "events",
soptions.WithSinkBlockDuration(o.jobSinkBlockDuration),
soptions.WithSinkAckGracePeriod(o.ackGracePeriod))
@@ -203,10 +214,11 @@ func AddNode(ctx context.Context, poolName string, rdb *redis.Client, opts ...No
workerMap: wm,
jobsMap: jm,
jobPayloadsMap: jpm,
pendingJobsMap: pjm,
shutdownMap: wsm,
tickerMap: tm,
workerStreams: sync.Map{},
pendingJobs: sync.Map{},
pendingJobChannels: sync.Map{},
pendingEvents: sync.Map{},
poolStream: poolStream,
poolSink: poolSink,
@@ -311,6 +323,7 @@ func (node *Node) PoolWorkers() []*Worker {
// the job key using consistent hashing.
// It returns:
// - nil if the job is successfully dispatched and started by a worker
// - ErrJobExists if a job with the same key already exists in the pool
// - an error returned by the worker's start handler if the job fails to start
// - an error if the pool is closed or if there's a failure in adding the job
//
@@ -320,14 +333,59 @@ func (node *Node) DispatchJob(ctx context.Context, key string, payload []byte) e
return fmt.Errorf("DispatchJob: pool %q is closed", node.PoolName)
}

// Check if job already exists in job payloads map
if _, exists := node.jobPayloadsMap.Get(key); exists {
return fmt.Errorf("%w: job %q", ErrJobExists, key)
}

// Check if there's a pending dispatch for this job
pendingTS, exists := node.pendingJobsMap.Get(key)
if exists {
ts, err := strconv.ParseInt(pendingTS, 10, 64)
if err != nil {
_, err := node.pendingJobsMap.TestAndDelete(ctx, key, pendingTS)
if err != nil {
node.logger.Error(fmt.Errorf("DispatchJob: failed to delete invalid pending timestamp for job %q: %w", key, err))
}
exists = false
} else if time.Until(time.Unix(0, ts)) > 0 {
return fmt.Errorf("%w: job %q is already being dispatched", ErrJobExists, key)
}
}

// Set pending timestamp using atomic operation
pendingUntil := time.Now().Add(2 * node.ackGracePeriod).UnixNano()
newTS := strconv.FormatInt(pendingUntil, 10)
if exists {
current, err := node.pendingJobsMap.TestAndSet(ctx, key, pendingTS, newTS)
if err != nil {
return fmt.Errorf("DispatchJob: failed to set pending timestamp for job %q: %w", key, err)
}
if current != pendingTS {
return fmt.Errorf("%w: job %q is already being dispatched", ErrJobExists, key)
}
} else {
ok, err := node.pendingJobsMap.SetIfNotExists(ctx, key, newTS)
if err != nil {
return fmt.Errorf("DispatchJob: failed to set initial pending timestamp for job %q: %w", key, err)
}
if !ok {
return fmt.Errorf("%w: job %q is already being dispatched", ErrJobExists, key)
}
}

job := marshalJob(&Job{Key: key, Payload: payload, CreatedAt: time.Now(), NodeID: node.ID})
eventID, err := node.poolStream.Add(ctx, evStartJob, job)
if err != nil {
// Clean up pending entry on failure
if _, err := node.pendingJobsMap.Delete(ctx, key); err != nil {
node.logger.Error(fmt.Errorf("DispatchJob: failed to clean up pending entry for job %q: %w", key, err))
}
return fmt.Errorf("DispatchJob: failed to add job to stream %q: %w", node.poolStream.Name, err)
}

cherr := make(chan error, 1)
node.pendingJobs.Store(eventID, cherr)
node.pendingJobChannels.Store(eventID, cherr)

timer := time.NewTimer(2 * node.ackGracePeriod)
defer timer.Stop()
@@ -340,9 +398,14 @@ func (node *Node) DispatchJob(ctx context.Context, key string, payload []byte) e
err = ctx.Err()
}

node.pendingJobs.Delete(eventID)
node.pendingJobChannels.Delete(eventID)
close(cherr)

// Clean up pending entry
if _, err := node.pendingJobsMap.Delete(ctx, key); err != nil {
node.logger.Error(fmt.Errorf("DispatchJob: failed to clean up pending entry for job %q: %w", key, err))
}

if err != nil {
node.logger.Error(fmt.Errorf("DispatchJob: failed to dispatch job: %w", err), "key", key)
return err
@@ -654,15 +717,15 @@ func (node *Node) ackWorkerEvent(ctx context.Context, ev *streaming.Event) {
// returnDispatchStatus returns the start job result to the caller.
func (node *Node) returnDispatchStatus(_ context.Context, ev *streaming.Event) {
ack := unmarshalAck(ev.Payload)
val, ok := node.pendingJobs.Load(ack.EventID)
val, ok := node.pendingJobChannels.Load(ack.EventID)
if !ok {
node.logger.Error(fmt.Errorf("returnDispatchStatus: received dispatch return for unknown event"), "id", ack.EventID)
return
}
node.logger.Debug("dispatch return", "event", ev.EventName, "id", ev.ID, "ack-id", ack.EventID)
if val == nil {
// Event was requeued, just clean up
node.pendingJobs.Delete(ack.EventID)
node.pendingJobChannels.Delete(ack.EventID)
return
}
var err error
@@ -740,7 +803,7 @@ func (node *Node) requeueJob(ctx context.Context, workerID string, job *Job) (ch
return nil, fmt.Errorf("requeueJob: failed to add job %q to stream %q: %w", job.Key, node.poolStream.Name, err)
}
cherr := make(chan error, 1)
node.pendingJobs.Store(eventID, cherr)
node.pendingJobChannels.Store(eventID, cherr)
return cherr, nil
}

@@ -1205,3 +1268,9 @@ func nodeStreamName(pool, nodeID string) string {
func pendingEventKey(workerID, eventID string) string {
return fmt.Sprintf("%s:%s", workerID, eventID)
}

// pendingJobsMapName returns the name of the replicated map used to store the
// pending jobs by job key.
func pendingJobsMapName(poolName string) string {
return poolName + ":pending-jobs"
}
130 changes: 130 additions & 0 deletions pool/node_test.go
@@ -2,6 +2,7 @@ package pool

import (
"context"
"errors"
"fmt"
"strconv"
"strings"
@@ -258,6 +259,135 @@ func TestDispatchJobTwoWorkers(t *testing.T) {
assert.NoError(t, node.Shutdown(ctx), "Failed to shutdown node")
}

func TestDispatchJobRaceCondition(t *testing.T) {
testName := strings.Replace(t.Name(), "/", "_", -1)
ctx := ptesting.NewTestContext(t)
rdb := ptesting.NewRedisClient(t)
defer ptesting.CleanupRedis(t, rdb, true, testName)

node1 := newTestNode(t, ctx, rdb, testName)
node2 := newTestNode(t, ctx, rdb, testName)
newTestWorker(t, ctx, node1)
newTestWorker(t, ctx, node2)
defer func() {
assert.NoError(t, node1.Shutdown(ctx))
assert.NoError(t, node2.Shutdown(ctx))
}()

t.Run("concurrent dispatch of same job returns error", func(t *testing.T) {
// Start dispatching same job from both nodes concurrently
errCh := make(chan error, 2)
jobKey := "concurrent-job"
payload := []byte("test payload")

var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
errCh <- node1.DispatchJob(ctx, jobKey, payload)
}()
go func() {
defer wg.Done()
errCh <- node2.DispatchJob(ctx, jobKey, payload)
}()
wg.Wait()
close(errCh)

// Collect results
var errs []error
for err := range errCh {
errs = append(errs, err)
}

// Verify that exactly one dispatch succeeded and one failed
successCount := 0
errorCount := 0
for _, err := range errs {
if err == nil {
successCount++
} else if errors.Is(err, ErrJobExists) {
errorCount++
} else {
t.Errorf("unexpected error: %v", err)
}
}
assert.Equal(t, 1, successCount, "Expected exactly one successful dispatch")
assert.Equal(t, 1, errorCount, "Expected exactly one ErrJobExists error")
})

t.Run("dispatch after existing job returns error", func(t *testing.T) {
jobKey := "sequential-job"
payload := []byte("test payload")

// First dispatch should succeed
err := node1.DispatchJob(ctx, jobKey, payload)
require.NoError(t, err, "First dispatch should succeed")

// Second dispatch should fail with ErrJobExists
err = node2.DispatchJob(ctx, jobKey, payload)
assert.True(t, errors.Is(err, ErrJobExists), "Expected ErrJobExists, got: %v", err)
})

t.Run("dispatch after pending job times out succeeds", func(t *testing.T) {
jobKey := "timeout-job"
payload := []byte("test payload")

// Set a stale pending timestamp
staleTS := time.Now().Add(-3 * node1.ackGracePeriod).UnixNano()
_, err := node1.pendingJobsMap.Set(ctx, jobKey, strconv.FormatInt(staleTS, 10))
require.NoError(t, err, "Failed to set stale pending timestamp")

// Dispatch should succeed because pending timestamp is in the past
err = node2.DispatchJob(ctx, jobKey, payload)
assert.NoError(t, err, "Dispatch should succeed after pending timeout")
})

t.Run("dispatch cleans up pending entry on failure", func(t *testing.T) {
jobKey := "cleanup-job"
payload := []byte("test payload")

// Corrupt the pool stream to force dispatch failure
err := rdb.Del(ctx, "pulse:stream:"+poolStreamName(node1.PoolName)).Err()
require.NoError(t, err, "Failed to delete pool stream")

// Attempt dispatch (should fail)
err = node1.DispatchJob(ctx, jobKey, payload)
require.Error(t, err, "Expected dispatch to fail")

// Verify pending entry was cleaned up
_, exists := node1.pendingJobsMap.Get(jobKey)
assert.False(t, exists, "Pending entry should be cleaned up after failed dispatch")
})

t.Run("dispatch cleans up pending entry on success", func(t *testing.T) {
jobKey := "success-cleanup-job"
payload := []byte("test payload")

// Dispatch job
err := node1.DispatchJob(ctx, jobKey, payload)
require.NoError(t, err, "Dispatch should succeed")
// Verify pending entry was cleaned up
require.Eventually(t, func() bool {
val, exists := node1.pendingJobsMap.Get(jobKey)
t.Logf("Got pending value: %q", val)
return !exists
}, max, delay, "Pending entry should be cleaned up after successful dispatch")
})

t.Run("dispatch with invalid pending timestamp", func(t *testing.T) {
jobKey := "invalid-timestamp-job"
payload := []byte("test payload")

// Set an invalid pending timestamp
_, err := node1.pendingJobsMap.SetAndWait(ctx, jobKey, "invalid-timestamp")
require.NoError(t, err, "Failed to set invalid pending timestamp")

// Dispatch should succeed (invalid timestamps are logged and ignored)
err = node1.DispatchJob(ctx, jobKey, payload)
assert.NoError(t, err, "Dispatch should succeed with invalid pending timestamp")
})
}

func TestNotifyWorker(t *testing.T) {
testName := strings.Replace(t.Name(), "/", "_", -1)
ctx := ptesting.NewTestContext(t)
2 changes: 1 addition & 1 deletion pool/worker.go
@@ -473,7 +473,7 @@ func (w *Worker) requeueJob(ctx context.Context, job *Job) error {
if err != nil {
return fmt.Errorf("requeueJob: failed to add job to pool stream: %w", err)
}
w.node.pendingJobs.Store(eventID, nil)
w.node.pendingJobChannels.Store(eventID, nil)
if err := w.stopJob(ctx, job.Key, true); err != nil {
return fmt.Errorf("failed to stop job: %w", err)
}