Skip to content

Commit 51e1998

Browse files
authored
Merge pull request go-redis#8 from heynemann/v4
Added support for verifying the allow limit without incrementing it.
2 parents 423ef21 + e3420f4 commit 51e1998

File tree

3 files changed

+72
-20
lines changed

3 files changed

+72
-20
lines changed

README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ func handler(w http.ResponseWriter, req *http.Request, rateLimiter *rate.Limiter
3535
fmt.Fprint(w, "Rate limit remaining: ", strconv.FormatInt(limit-rate, 10))
3636
}
3737

38+
func statusHandler(w http.ResponseWriter, req *http.Request, rateLimiter *rate.Limiter) {
39+
userID := "user-12345"
40+
limit := int64(5)
41+
42+
//With increment 0, we just retrieve the current limit
43+
rate, reset, allowed := rateLimiter.AllowN(userID, limit, time.Minute, 0)
44+
fmt.Fprintf(w, "Current rate: %d", rate)
45+
fmt.Fprintf(w, "Reset: %d", reset)
46+
fmt.Fprintf(w, "Allowed: %v", allowed)
47+
}
48+
3849
func main() {
3950
ring := redis.NewRing(&redis.RingOptions{
4051
Addrs: map[string]string{
@@ -48,6 +59,10 @@ func main() {
4859
handler(w, req, limiter)
4960
})
5061

62+
http.HandleFunc("/status", func(w http.ResponseWriter, req *http.Request) {
63+
statusHandler(w, req, limiter)
64+
})
65+
5166
http.HandleFunc("/favicon.ico", http.NotFound)
5267
log.Println("listening on localhost:8888...")
5368
log.Println(http.ListenAndServe("localhost:8888", nil))

rate.go

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
package rate // import "gopkg.in/go-redis/rate.v4"
1+
package rate
22

33
import (
44
"fmt"
55
"time"
66

7-
timerate "golang.org/x/time/rate"
7+
redis "gopkg.in/redis.v4"
88

9-
"gopkg.in/redis.v4"
9+
timerate "golang.org/x/time/rate"
1010
)
1111

1212
const redisPrefix = "rate"
@@ -15,38 +15,46 @@ type rediser interface {
1515
Pipelined(func(pipe *redis.Pipeline) error) ([]redis.Cmder, error)
1616
}
1717

18+
//Limiter type
1819
type Limiter struct {
1920
fallbackLimiter *timerate.Limiter
2021
redis rediser
2122
}
2223

23-
// A Limiter controls how frequently events are allowed to happen. It uses
24-
// the redis to store data and fallbacks to the fallbackLimiter
25-
// when Redis Server is not available.
24+
// NewLimiter creates a limiter that controls how frequently events
25+
// are allowed to happen. It uses redis to store data and fallbacks
26+
// to the fallbackLimiter when Redis Server is not available.
2627
func NewLimiter(redis rediser, fallbackLimiter *timerate.Limiter) *Limiter {
2728
return &Limiter{
2829
fallbackLimiter: fallbackLimiter,
2930
redis: redis,
3031
}
3132
}
3233

33-
// Allow reports whether an event with given name may happen at time now.
34-
// It allows up to maxn events within duration dur.
35-
func (l *Limiter) Allow(name string, maxn int64, dur time.Duration) (count, reset int64, allow bool) {
34+
// AllowN reports whether an event with given name may happen at time now.
35+
// It allows up to maxn events within duration dur, with each interaction incrementing
36+
// the limit by n.
37+
func (l *Limiter) AllowN(name string, maxn int64, dur time.Duration, n int64) (count, reset int64, allow bool) {
3638
udur := int64(dur / time.Second)
3739
slot := time.Now().Unix() / udur
3840
reset = (slot + 1) * udur
3941
allow = l.fallbackLimiter.Allow()
4042

4143
name = fmt.Sprintf("%s:%s-%d", redisPrefix, name, slot)
42-
count, err := l.incr(name, dur)
44+
count, err := l.incr(name, dur, n)
4345
if err == nil {
4446
allow = count <= maxn
4547
}
4648

4749
return count, reset, allow
4850
}
4951

52+
// Allow reports whether an event with given name may happen at time now.
53+
// It allows up to maxn events within duration dur.
54+
func (l *Limiter) Allow(name string, maxn int64, dur time.Duration) (count, reset int64, allow bool) {
55+
return l.AllowN(name, maxn, dur, 1)
56+
}
57+
5058
// AllowMinute is shorthand for Allow(name, maxn, time.Minute).
5159
func (l *Limiter) AllowMinute(name string, maxn int64) (int64, int64, bool) {
5260
return l.Allow(name, maxn, time.Minute)
@@ -79,7 +87,7 @@ func (l *Limiter) AllowRate(name string, rateLimit timerate.Limit) (delay time.D
7987
allow = l.fallbackLimiter.Allow()
8088

8189
name = fmt.Sprintf("%s:%s-%d-%d", redisPrefix, name, dur, slot)
82-
count, err := l.incr(name, dur)
90+
count, err := l.incr(name, dur, 1)
8391
if err == nil {
8492
allow = count <= limit
8593
}
@@ -91,10 +99,10 @@ func (l *Limiter) AllowRate(name string, rateLimit timerate.Limit) (delay time.D
9199
return delay, allow
92100
}
93101

94-
func (l *Limiter) incr(name string, dur time.Duration) (int64, error) {
102+
func (l *Limiter) incr(name string, dur time.Duration, n int64) (int64, error) {
95103
var incr *redis.IntCmd
96104
_, err := l.redis.Pipelined(func(pipe *redis.Pipeline) error {
97-
incr = pipe.Incr(name)
105+
incr = pipe.IncrBy(name, n)
98106
pipe.Expire(name, dur)
99107
return nil
100108
})

rate_test.go

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,23 @@
1-
package rate_test
1+
package rate
22

33
import (
44
"testing"
55
"time"
66

7-
timerate "golang.org/x/time/rate"
8-
"gopkg.in/redis.v4"
7+
redis "gopkg.in/redis.v4"
98

10-
"gopkg.in/go-redis/rate.v4"
9+
timerate "golang.org/x/time/rate"
1110
)
1211

13-
func rateLimiter() *rate.Limiter {
12+
func rateLimiter() *Limiter {
1413
ring := redis.NewRing(&redis.RingOptions{
1514
Addrs: map[string]string{"server0": ":6379"},
1615
})
1716
if err := ring.FlushDb().Err(); err != nil {
1817
panic(err)
1918
}
2019
fallbackLimiter := timerate.NewLimiter(timerate.Every(time.Millisecond), 100)
21-
return rate.NewLimiter(ring, fallbackLimiter)
20+
return NewLimiter(ring, fallbackLimiter)
2221
}
2322

2423
func TestAllow(t *testing.T) {
@@ -90,7 +89,7 @@ func TestAllowRateSecond(t *testing.T) {
9089
func TestRedisIsDown(t *testing.T) {
9190
ring := redis.NewRing(&redis.RingOptions{})
9291
limiter := timerate.NewLimiter(timerate.Every(time.Second), 1)
93-
l := rate.NewLimiter(ring, limiter)
92+
l := NewLimiter(ring, limiter)
9493

9594
rate, _, allow := l.AllowMinute("test_id", 1)
9695
if !allow {
@@ -109,6 +108,36 @@ func TestRedisIsDown(t *testing.T) {
109108
}
110109
}
111110

111+
func TestAllowN(t *testing.T) {
112+
l := rateLimiter()
113+
114+
rate, reset, allow := l.AllowN("test_allow_n", 1, time.Minute, 1)
115+
if !allow {
116+
t.Fatalf("rate limited with rate %d", rate)
117+
}
118+
if rate != 1 {
119+
t.Fatalf("got %d, wanted 1", rate)
120+
}
121+
dur := time.Duration(reset-time.Now().Unix()) * time.Second
122+
if dur > time.Minute {
123+
t.Fatalf("got %s, wanted <= %s", dur, time.Minute)
124+
}
125+
126+
l.AllowN("test_allow_n", 1, time.Minute, 2)
127+
128+
rate, reset, allow = l.AllowN("test_allow_n", 1, time.Minute, 0)
129+
if allow {
130+
t.Fatalf("should rate limit with rate %d", rate)
131+
}
132+
if rate != 3 {
133+
t.Fatalf("got %d, wanted 3", rate)
134+
}
135+
dur = time.Duration(reset-time.Now().Unix()) * time.Second
136+
if dur > time.Minute {
137+
t.Fatalf("got %s, wanted <= %s", dur, time.Minute)
138+
}
139+
}
140+
112141
func durEqual(got, wanted time.Duration) bool {
113142
return got > 0 && got < wanted
114143
}

0 commit comments

Comments
 (0)