Skip to content

Commit bbf64f4

Browse files
authored
Merge pull request #17 from domsolutions/fix-user-supplied-jwts
Fix user supplied jwts
2 parents 2c94e56 + c2d5af6 commit bbf64f4

File tree

13 files changed

+276
-89
lines changed

13 files changed

+276
-89
lines changed

README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,13 @@ Flags:
8787
-H, --headers strings headers to send in request, can have multiple i.e -H 'content-type:application/json' -H' connection:close'
8888
-h, --help help for run
8989
--jwt-aud string JWT audience (aud) claim
90+
--jwt-claims string JWT custom claims
9091
--jwt-header string JWT header field name
9192
--jwt-iss string JWT issuer (iss) claim
9293
--jwt-key string JWT signing private key path
9394
--jwt-kid string JWT KID
9495
--jwt-sub string JWT subject (sub) claim
95-
--jwt-claims string JWT custom claims as a JSON string, ex: {"iat": 1719410063, "browser": "chrome"}
96+
-f, --jwts-filename string File path for pre-generated JWTs, separated by new lines
9697
-m, --method string request method (default "GET")
9798
--mtls-cert string mTLS cert path
9899
--mtls-key string mTLS cert private key path
@@ -221,6 +222,13 @@ https://github.com/domsolutions/gopayloader
221222
+-----------------------+-------------------------------+
222223
```
223224
225+
If you have your own JWTs you want to test, you can supply a file to send the JWTs i.e. `./my-jwts.txt` where each jwt is separated by a new line.
226+
227+
```shell
228+
./gopayloader run http://localhost:8081 -c 1 -r 1000000 --jwt-header "my-jwt" -f ./my-jwts.txt
229+
```
230+
231+
224232
To remove all generated jwts;
225233
226234
```shell

cmd/payloader/run.go

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,30 @@ import (
99
)
1010

1111
const (
12-
argMethod = "method"
13-
argConnections = "connections"
14-
argRequests = "requests"
15-
argKeepAlive = "disable-keep-alive"
16-
argVerifySigner = "skip-verify"
17-
argTime = "time"
18-
argMTLSKey = "mtls-key"
19-
argMTLSCert = "mtls-cert"
20-
argReadTimeout = "read-timeout"
21-
argWriteTimeout = "write-timeout"
22-
argVerbose = "verbose"
23-
argTicker = "ticker"
24-
argJWTKey = "jwt-key"
25-
argJWTSUb = "jwt-sub"
12+
argMethod = "method"
13+
argConnections = "connections"
14+
argRequests = "requests"
15+
argKeepAlive = "disable-keep-alive"
16+
argVerifySigner = "skip-verify"
17+
argTime = "time"
18+
argMTLSKey = "mtls-key"
19+
argMTLSCert = "mtls-cert"
20+
argReadTimeout = "read-timeout"
21+
argWriteTimeout = "write-timeout"
22+
argVerbose = "verbose"
23+
argTicker = "ticker"
24+
argJWTKey = "jwt-key"
25+
argJWTSUb = "jwt-sub"
2626
argJWTCustomClaims = "jwt-claims"
27-
argJWTIss = "jwt-iss"
28-
argJWTAud = "jwt-aud"
29-
argJWTHeader = "jwt-header"
30-
argJWTKid = "jwt-kid"
31-
argHeaders = "headers"
32-
argBody = "body"
33-
argBodyFile = "body-file"
34-
argClient = "client"
27+
argJWTIss = "jwt-iss"
28+
argJWTAud = "jwt-aud"
29+
argJWTHeader = "jwt-header"
30+
argJWTKid = "jwt-kid"
31+
argJWTsFilename = "jwts-filename"
32+
argHeaders = "headers"
33+
argBody = "body"
34+
argBodyFile = "body-file"
35+
argClient = "client"
3536
)
3637

3738
var (
@@ -55,6 +56,7 @@ var (
5556
jwtAud string
5657
jwtHeader string
5758
jwtKID string
59+
jwtsFilename string
5860
headers *[]string
5961
body string
6062
bodyFile string
@@ -92,6 +94,7 @@ var runCmd = &cobra.Command{
9294
jwtIss,
9395
jwtAud,
9496
jwtHeader,
97+
jwtsFilename,
9598
*headers,
9699
body,
97100
bodyFile,
@@ -128,11 +131,16 @@ func init() {
128131
runCmd.Flags().StringVar(&jwtIss, argJWTIss, "", "JWT issuer (iss) claim")
129132
runCmd.Flags().StringVar(&jwtSub, argJWTSUb, "", "JWT subject (sub) claim")
130133
runCmd.Flags().StringVar(&jwtCustomClaims, argJWTCustomClaims, "", "JWT custom claims")
134+
runCmd.Flags().StringVarP(&jwtsFilename, argJWTsFilename, "f", "", "File path for pre-generated JWTs, separated by new lines")
131135
runCmd.Flags().StringVar(&jwtHeader, argJWTHeader, "", "JWT header field name")
132136

133137
runCmd.MarkFlagsRequiredTogether(argMTLSCert, argMTLSKey)
134-
runCmd.MarkFlagsRequiredTogether(argJWTKey, argJWTHeader)
135138
runCmd.MarkFlagsMutuallyExclusive(argBody, argBodyFile)
136-
139+
runCmd.MarkFlagsMutuallyExclusive(argJWTsFilename, argJWTKid)
140+
runCmd.MarkFlagsMutuallyExclusive(argJWTsFilename, argJWTAud)
141+
runCmd.MarkFlagsMutuallyExclusive(argJWTsFilename, argJWTIss)
142+
runCmd.MarkFlagsMutuallyExclusive(argJWTsFilename, argJWTCustomClaims)
143+
runCmd.MarkFlagsMutuallyExclusive(argJWTsFilename, argJWTSUb)
144+
runCmd.MarkFlagsMutuallyExclusive(argJWTsFilename, argJWTKey)
137145
rootCmd.AddCommand(runCmd)
138146
}

cmd/payloader/test-server.go

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ import (
1111
"log"
1212
"net/http"
1313
"os"
14+
"os/signal"
1415
"path/filepath"
1516
"strconv"
1617
"strings"
18+
"syscall"
1719
"time"
1820
)
1921

@@ -87,9 +89,25 @@ var runServerCmd = &cobra.Command{
8789
},
8890
}
8991

90-
if err := server.ListenAndServe(addr); err != nil {
91-
return err
92+
errs := make(chan error)
93+
go func() {
94+
if err := server.ListenAndServe(addr); err != nil {
95+
log.Println(err)
96+
errs <- err
97+
}
98+
}()
99+
100+
c := make(chan os.Signal, 1)
101+
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
102+
103+
select {
104+
case <-c:
105+
log.Println("User cancelled, shutting down")
106+
server.Shutdown()
107+
case err := <-errs:
108+
log.Printf("Got error from server; %v \n", err)
92109
}
110+
93111
return nil
94112
}
95113

config/config.go

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,15 @@ type Config struct {
3434
JwtIss string
3535
JwtAud string
3636
JwtHeader string
37+
JwtsFilename string
3738
SendJWT bool
3839
Headers []string
3940
Body string
4041
BodyFile string
4142
Client string
4243
}
4344

44-
func NewConfig(ctx context.Context, reqURI, mTLScert, mTLSKey string, disableKeepAlive bool, reqs int64, conns uint, totalTime time.Duration, skipVerify bool, readTimeout, writeTimeout time.Duration, method string, verbose bool, ticker time.Duration, jwtKID, jwtKey, jwtSub, jwtCustomClaimsJSON, jwtIss, jwtAud, jwtHeader string, headers []string, body, bodyFile string, client string) *Config {
45+
func NewConfig(ctx context.Context, reqURI, mTLScert, mTLSKey string, disableKeepAlive bool, reqs int64, conns uint, totalTime time.Duration, skipVerify bool, readTimeout, writeTimeout time.Duration, method string, verbose bool, ticker time.Duration, jwtKID, jwtKey, jwtSub, jwtCustomClaimsJSON, jwtIss, jwtAud, jwtHeader, jwtsFilename string, headers []string, body, bodyFile string, client string) *Config {
4546
return &Config{
4647
Ctx: ctx,
4748
ReqURI: reqURI,
@@ -64,6 +65,7 @@ func NewConfig(ctx context.Context, reqURI, mTLScert, mTLSKey string, disableKee
6465
JwtIss: jwtIss,
6566
JwtAud: jwtAud,
6667
JwtHeader: jwtHeader,
68+
JwtsFilename: jwtsFilename,
6769
Headers: headers,
6870
Body: body,
6971
BodyFile: bodyFile,
@@ -139,14 +141,14 @@ func (c *Config) Validate() error {
139141
}
140142
}
141143

142-
if (c.JwtHeader == "") != (c.JwtKey == "") {
143-
if c.JwtHeader == "" {
144-
return errors.New("config: empty jwt header")
145-
}
144+
// Require JwtHeader if JwtKey or JwtsFilename is present
145+
if (c.JwtsFilename != "" || c.JwtKey != "") && c.JwtHeader == "" {
146+
return errors.New("config: empty jwt header")
147+
}
146148

147-
if c.JwtKey == "" {
148-
return errors.New("empty jwt key")
149-
}
149+
// Require JwtKey or JwtsFilename if JwtHeader is present
150+
if c.JwtHeader != "" && c.JwtsFilename == "" && c.JwtKey == "" {
151+
return errors.New("config: empty jwt filename and jwt key, one of those is needed to send requests with JWTs")
150152
}
151153

152154
if c.JwtKey != "" {
@@ -163,6 +165,20 @@ func (c *Config) Validate() error {
163165
c.SendJWT = true
164166
}
165167

168+
if c.JwtsFilename != "" {
169+
_, err := os.OpenFile(c.JwtsFilename, os.O_RDONLY, os.ModePerm)
170+
if err != nil {
171+
if os.IsNotExist(err) {
172+
return errors.New("config: jwt file does not exist: " + c.JwtsFilename)
173+
}
174+
return fmt.Errorf("config: jwt file error checking file exists; %v", err)
175+
}
176+
if c.ReqTarget == 0 {
177+
return errors.New("can only send jwts when request number is specified")
178+
}
179+
c.SendJWT = true
180+
}
181+
166182
if len(c.Headers) > 0 {
167183
for _, h := range c.Headers {
168184
if !strings.Contains(h, ":") {

pkgs/http-clients/definitions.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ type Config struct {
4141
Method string
4242
Verbose bool
4343
JwtStreamReceiver <-chan string
44-
JwtStreamErr <-chan error
4544
JWTHeader string
4645
Headers []string
4746
Body string

pkgs/jwt-generator/cache.go

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@ package jwt_generator
22

33
import (
44
"bufio"
5-
"encoding/binary"
65
"errors"
76
"fmt"
7+
"github.com/pterm/pterm"
88
"os"
9+
"strconv"
910
"strings"
1011
"time"
1112
)
1213

14+
const byteSizeCounter = 20
15+
1316
type cache struct {
1417
f *os.File
1518
count int64
@@ -20,14 +23,22 @@ func newCache(f *os.File) (*cache, error) {
2023
c := cache{f: f}
2124

2225
c.scanner = bufio.NewScanner(c.f)
26+
// Get count found on first line of the file
2327
c.scanner.Split(bufio.ScanLines)
2428
if c.scanner.Scan() {
25-
bb := make([]byte, 8)
29+
bb := make([]byte, byteSizeCounter)
2630
_, err := f.ReadAt(bb, 0)
2731
if err != nil {
2832
return nil, err
2933
}
30-
c.count = int64(binary.LittleEndian.Uint64(bb))
34+
35+
count, err := getCount(bb)
36+
if err != nil {
37+
pterm.Error.Printf("Got error reading jwt count from cache; %v", err)
38+
return nil, err
39+
}
40+
41+
c.count = count
3142
return &c, nil
3243
}
3344
return &c, nil
@@ -37,6 +48,23 @@ func (c *cache) getJwtCount() int64 {
3748
return c.count
3849
}
3950

51+
func getCount(bb []byte) (int64, error) {
52+
num := make([]byte, 0)
53+
for _, m := range bb {
54+
if m == 0 {
55+
break
56+
}
57+
num = append(num, m)
58+
}
59+
60+
s := string(num)
61+
i, err := strconv.ParseInt(s, 10, 64)
62+
if err != nil {
63+
return 0, err
64+
}
65+
return i, nil
66+
}
67+
4068
func (c *cache) get(count int64) (<-chan string, <-chan error) {
4169
recv := make(chan string, 1000000)
4270
errs := make(chan error, 1)
@@ -61,14 +89,22 @@ func (c *cache) get(count int64) (<-chan string, <-chan error) {
6189
}
6290

6391
meta := c.scanner.Bytes()
64-
if len(meta) < 8 {
92+
if len(meta) < byteSizeCounter {
6593
errs <- fmt.Errorf("jwt_generator: retrieving; corrupt jwt cache, wanted 8 bytes got %d", len(meta))
6694
close(errs)
6795
close(recv)
6896
return recv, errs
6997
}
7098

71-
if count > int64(binary.LittleEndian.Uint64(meta[0:8])) {
99+
i, err := getCount(meta)
100+
if err != nil {
101+
errs <- fmt.Errorf("failed to get jwt count; %v", err)
102+
close(errs)
103+
close(recv)
104+
return recv, errs
105+
}
106+
107+
if count > i {
72108
errs <- errors.New("jwt_generator: retrieving; not enough jwts stored in cache")
73109
close(errs)
74110
close(recv)
@@ -83,20 +119,25 @@ func (c *cache) get(count int64) (<-chan string, <-chan error) {
83119

84120
func (c *cache) retrieve(count int64, recv chan<- string, errs chan<- error) {
85121
var i int64 = 0
122+
defer func() {
123+
close(errs)
124+
close(recv)
125+
}()
86126

87127
for i = 0; i < count; i++ {
88128
if c.scanner.Scan() {
89129
recv <- string(c.scanner.Bytes())
90130
continue
91131
}
92-
// reached EOF or err
132+
93133
if err := c.scanner.Err(); err != nil {
94134
errs <- err
95-
close(errs)
135+
return
96136
}
97-
break
137+
138+
errs <- errors.New("unable to read anymore jwts from file")
139+
return
98140
}
99-
close(recv)
100141
}
101142

102143
func (c *cache) save(tokens []string) error {
@@ -110,19 +151,25 @@ func (c *cache) save(tokens []string) error {
110151
if stat.Size() > 0 {
111152
pos = stat.Size()
112153
}
154+
113155
if _, err := c.f.WriteAt([]byte(strings.Join(tokens, "\n")+"\n"), pos); err != nil {
114156
return err
115157
}
116158

117-
b := make([]byte, 8)
118-
newCount := uint64(int64(add) + c.count)
119-
binary.LittleEndian.PutUint64(b, newCount)
159+
newCount := int64(add) + c.count
160+
s := strconv.FormatInt(newCount, 10)
161+
162+
b := make([]byte, byteSizeCounter)
163+
for i, ss := range s {
164+
b[i] = byte(ss)
165+
}
166+
120167
_, err = c.f.WriteAt(b, 0)
121168
if err != nil {
122169
return err
123170
}
124171

125-
_, err = c.f.WriteAt([]byte{byte('\n')}, 9)
172+
_, err = c.f.WriteAt([]byte{byte('\n')}, byteSizeCounter)
126173
if err != nil {
127174
return err
128175
}

0 commit comments

Comments
 (0)