
Allow heartbeating a request #7

Merged 3 commits on Jun 20, 2025

6 changes: 1 addition & 5 deletions .golangci.yml
@@ -47,8 +47,6 @@ linters-settings:
      - name: var-naming
        arguments: [['ID', 'URL', 'HTTP', 'API'], []]

-  tenv:
-    all: true

  varcheck:
    exported-fields: false # this appears to improperly detect exported variables as unused when they are used from a package with the same name
@@ -69,9 +67,7 @@
    - bodyclose # checks whether HTTP response body is closed successfully
    - durationcheck # check for two durations multiplied together
    - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13.
-    - execinquery # execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds
    - exhaustive # check exhaustiveness of enum switch statements
-    - exportloopref # checks for pointers to enclosing loop variables
    - forbidigo # Forbids identifiers
    - gochecknoinits # Checks that no init functions are present in Go code
    - goconst # Finds repeated strings that could be replaced by a constant
@@ -89,11 +85,11 @@
    - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL.
    - predeclared # find code that shadows one of Go's predeclared identifiers
    - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
-    - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
    - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
    - unconvert # Remove unnecessary type conversions
    - usestdlibvars # detect the possibility to use variables/constants from the Go standard library
    - whitespace # Tool for detection of leading and trailing whitespace
+    - usetesting # Replaced tenv (detecting os.Setenv vs t.Setenv)

issues:
  max-same-issues: 50
122 changes: 122 additions & 0 deletions concurrency.go
@@ -46,6 +46,14 @@ func (tk *Limiter) Take(ctx context.Context, key string, requestID string, limit
return rv[key], nil
}

func (tk *Limiter) Heartbeat(ctx context.Context, key string, requestID string, limit ConcurrencyLimit) (ConcurrencyResult, error) {
rv, err := tk.heartbeatMulti(ctx, requestID, map[string]ConcurrencyLimit{key: limit}, 0)
if err != nil {
return ConcurrencyResult{}, err
}
return rv[key], nil
}

func (p *pipeline) takePipe(ctx context.Context, pipe redis.Pipeliner, rv *ConcurrencyResult) func() error {
p.buf.Reset()
_, _ = p.buf.WriteString(p.l.concurrentPrefix)
@@ -75,6 +83,35 @@ func (p *pipeline) takePipe(ctx context.Context, pipe redis.Pipeliner, rv *Concu
}
}

func (p *pipeline) heartbeatPipe(ctx context.Context, pipe redis.Pipeliner, rv *ConcurrencyResult) func() error {
p.buf.Reset()
_, _ = p.buf.WriteString(p.l.concurrentPrefix)
_, _ = p.buf.WriteString(rv.Key)

reqPeriod := rv.Limit.RequestMaxDuration.Round(time.Second) / time.Second
if reqPeriod <= 0 {
reqPeriod = 60
}

values := []interface{}{rv.RequestID, reqPeriod}

eval := concurrencyHeartbeat.EvalSha(ctx, pipe, []string{p.buf.String()}, values...)
return func() error {
v, err := eval.Result()
if err != nil {
return err
}
values := v.([]interface{})

ok := values[0].(int64) == 1
current := values[1].(int64)
rv.Allowed = ok
rv.Used = current
rv.Remaining = rv.Limit.Max - current
return nil
}
}

func (tk *Limiter) Release(ctx context.Context, key string, requestID string, limit ConcurrencyLimit) error {
err := tk.releaseMulti(ctx, requestID, map[string]ConcurrencyLimit{key: limit})
if err != nil {
@@ -200,3 +237,88 @@ func (tk *Limiter) takeMulti(ctx context.Context, requestID string, limits map[s

return rv, nil
}

type heartbeatResult struct {
key string
limit ConcurrencyLimit
cmd *redis.Cmd
}

func (tk *Limiter) heartbeatMulti(ctx context.Context, requestID string, limits map[string]ConcurrencyLimit, depth int) (map[string]ConcurrencyResult, error) {
if depth > 10 {
return nil, ErrTooManyRetries
}

results := make([]*heartbeatResult, 0, len(limits))
buf := bytes.Buffer{}
pl := tk.rdb.Pipeline()
existsCmd := concurrencyHeartbeat.Exists(ctx, pl)
for key, limit := range limits {
reqPeriod := limit.RequestMaxDuration.Round(time.Second) / time.Second
if reqPeriod <= 0 {
reqPeriod = 60
}
values := []interface{}{requestID, reqPeriod}

buf.Reset()
_, _ = buf.WriteString(tk.concurrentPrefix)
_, _ = buf.WriteString(key)

results = append(results, &heartbeatResult{
key: key,
limit: limit,
cmd: concurrencyHeartbeat.EvalSha(
ctx,
pl,
[]string{buf.String()},
values...,
),
})
}
if len(results) == 0 {
return nil, nil
}
_, err := pl.Exec(ctx)
if err != nil {
return nil, err
}

exists, err := existsCmd.Result()
if err != nil {
return nil, err
}
if len(exists) != 1 {
return nil, ErrScriptFailed
}

if !exists[0] {
err = tk.LoadScripts(ctx)
if err != nil {
return nil, err
}
return tk.heartbeatMulti(ctx, requestID, limits, depth+1)
}

rv := make(map[string]ConcurrencyResult, len(results))
for _, result := range results {
v, err := result.cmd.Result()
if err != nil {
return nil, err
}
values := v.([]interface{})

ok := values[0].(int64) == 1
current := values[1].(int64)
cr := ConcurrencyResult{
RequestID: requestID,
Key: result.key,
Allowed: ok,
Limit: result.limit,
Used: current,
Remaining: result.limit.Max - current,
}
rv[result.key] = cr
}

return rv, nil
}
5 changes: 5 additions & 0 deletions limiter.go
@@ -47,6 +47,11 @@ func (l *Limiter) LoadScripts(ctx context.Context) error {
return fmt.Errorf("redis_rate: failed to load 'script_concurrency_take.lua': %w", err)
}

_, err = concurrencyHeartbeat.Load(ctx, l.rdb).Result()
if err != nil {
return fmt.Errorf("redis_rate: failed to load 'script_concurrency_heartbeat.lua': %w", err)
}

_, err = allowN.Load(ctx, l.rdb).Result()
if err != nil {
return fmt.Errorf("redis_rate: failed to load 'script_allow_n.lua': %w", err)
32 changes: 27 additions & 5 deletions pipeline.go
@@ -19,6 +19,8 @@ type Pipeline interface {

Take(ctx context.Context, key string, requestID string, limit ConcurrencyLimit) *ConcurrencyResult

Heartbeat(ctx context.Context, key string, requestID string, limit ConcurrencyLimit) *ConcurrencyResult

Release(ctx context.Context, key string, requestID string)

Exec(ctx context.Context) error
@@ -36,11 +38,12 @@ type pair[TA any, TB any] struct {
}

type pipeline struct {
-    l               *Limiter
-    buf             bytes.Buffer
-    releaseCommands []pair[string, string]
-    allowCommands   []*Result
-    takeCommands    []*ConcurrencyResult
+    l                 *Limiter
+    buf               bytes.Buffer
+    releaseCommands   []pair[string, string]
+    allowCommands     []*Result
+    takeCommands      []*ConcurrencyResult
+    heartbeatCommands []*ConcurrencyResult
}

func (p *pipeline) Allow(ctx context.Context,
@@ -66,6 +69,18 @@ func (p *pipeline) Take(ctx context.Context,
return rv
}

func (p *pipeline) Heartbeat(ctx context.Context,
key string, requestID string,
limit ConcurrencyLimit) *ConcurrencyResult {
rv := &ConcurrencyResult{
Key: key,
Limit: limit,
RequestID: requestID,
}
p.heartbeatCommands = append(p.heartbeatCommands, rv)
return rv
}

func (p *pipeline) Release(ctx context.Context,
key string, requestID string) {
p.releaseCommands = append(p.releaseCommands, pair[string, string]{key, requestID})
@@ -99,6 +114,13 @@ func (p *pipeline) exec(ctx context.Context, depth int) error {
}
}

if len(p.heartbeatCommands) > 0 {
scriptExistChecks = append(scriptExistChecks, concurrencyHeartbeat.Exists(ctx, pipe))
for _, v := range p.heartbeatCommands {
finishFuncs = append(finishFuncs, p.heartbeatPipe(ctx, pipe, v))
}
}

if len(p.releaseCommands) > 0 {
p.l.releasePipe(ctx, pipe, p.releaseCommands)
}
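
For illustration only, the pipelined Heartbeat added above could be used to refresh a request's lease on several keys in a single Redis round trip. This sketch is not part of the PR: the package name and the tk.Pipeline() constructor are assumptions (the diff does not show how a Pipeline is obtained); Heartbeat, Exec, ConcurrencyLimit and ConcurrencyResult are the ones declared above.

package redisrate // assumed package name, not shown in the diff

import (
	"context"
	"fmt"
)

// heartbeatAll refreshes requestID's entry for every key in one pipelined
// call and reports any key whose entry was no longer present.
func heartbeatAll(ctx context.Context, tk *Limiter, requestID string, keys []string, limit ConcurrencyLimit) error {
	p := tk.Pipeline() // assumption: some constructor on Limiter returns a Pipeline
	results := make([]*ConcurrencyResult, 0, len(keys))
	for _, k := range keys {
		results = append(results, p.Heartbeat(ctx, k, requestID, limit))
	}
	if err := p.Exec(ctx); err != nil {
		return err
	}
	var lost []string
	for _, r := range results {
		if !r.Allowed { // the script returned {0, 0}: the request ID was not found for this key
			lost = append(lost, r.Key)
		}
	}
	if len(lost) > 0 {
		return fmt.Errorf("heartbeat lost for keys: %v", lost)
	}
	return nil
}
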
41 changes: 41 additions & 0 deletions script_concurrency_heartbeat.lua
@@ -0,0 +1,41 @@
-- Heartbeat to extend the expiration time for a request ID.
local rate_limit_key = KEYS[1]
local request_id = ARGV[1]
local max_request_time_seconds = tonumber(ARGV[2])

-- redis returns time as an array containing two integers: seconds of the epoch
-- time (10 digits) and microseconds (6 digits). for convenience we need to
-- convert them to a floating point number. the resulting number is 16 digits,
-- bordering on the limits of a 64-bit double-precision floating point number.
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
local jan_1_2017 = 1483228800
local now = redis.call("TIME")
now = (now[1] - jan_1_2017) + (now[2] / 1000000)

-- Check if the request ID exists in the hash
local exists = redis.call("HEXISTS", rate_limit_key, request_id)
if exists == 0 then
return {0, 0} -- Request ID not found
end

-- Update the expiration time for the request ID
redis.call("HSET", rate_limit_key, request_id, now + max_request_time_seconds)
redis.call("EXPIRE", rate_limit_key, 5 * max_request_time_seconds)

-- Get the current count of active requests
local count = 0
local bulk = redis.call('HGETALL', rate_limit_key)
local nextkey
for i, v in ipairs(bulk) do
    if i % 2 == 1 then
        nextkey = v
    else
        if tonumber(v) >= now then
            count = count + 1
        end
    end
end

return {1, count} -- Success, return the current count

Review comment (on the counting loop): This is O(N) on the number of requests for a key - is that OK?

Author: It's not what I'd love to do, but the N is going to be small and it's the same implementation as the way we calculate concurrency for take.
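
To show how the new API might be exercised end to end, here is a minimal caller-side sketch (not part of this PR): take a slot, heartbeat it at a fraction of RequestMaxDuration while the work runs, and release it when done. The package name and the concrete limit values are assumptions; Take, Heartbeat, Release, ConcurrencyLimit and ConcurrencyResult come from the diff above.

package redisrate // assumed package name, not shown in the diff

import (
	"context"
	"errors"
	"time"
)

// runWithHeartbeat holds a concurrency slot for the duration of job,
// refreshing it periodically so the Lua script keeps counting it as active.
func runWithHeartbeat(ctx context.Context, tk *Limiter, key, requestID string, job func(context.Context) error) error {
	limit := ConcurrencyLimit{Max: 5, RequestMaxDuration: 30 * time.Second} // illustrative values

	res, err := tk.Take(ctx, key, requestID, limit)
	if err != nil {
		return err
	}
	if !res.Allowed {
		return errors.New("concurrency limit reached") // caller decides how to back off
	}
	defer func() { _ = tk.Release(ctx, key, requestID, limit) }()

	hbCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	go func() {
		// Refresh well inside RequestMaxDuration so the entry never lapses mid-job.
		t := time.NewTicker(limit.RequestMaxDuration / 3)
		defer t.Stop()
		for {
			select {
			case <-hbCtx.Done():
				return
			case <-t.C:
				_, _ = tk.Heartbeat(hbCtx, key, requestID, limit)
			}
		}
	}()

	return job(ctx)
}

Heartbeating at a third of RequestMaxDuration leaves a couple of chances to refresh before the script treats the entry as expired and stops counting it.
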
5 changes: 5 additions & 0 deletions scripts.go
@@ -18,8 +18,13 @@ var allowAtMostScript string
//go:embed script_concurrency_take.lua
var concurrencyTakeScript string

//go:embed script_concurrency_heartbeat.lua
var concurrencyHeartbeatScript string

var allowN = redis.NewScript(alloNScript)

var allowAtMost = redis.NewScript(allowAtMostScript)

var concurrencyTake = redis.NewScript(concurrencyTakeScript)

var concurrencyHeartbeat = redis.NewScript(concurrencyHeartbeatScript)