Skip to content

Commit 8941653

Browse files
authored
Merge pull request #3 from nmaltais/master
Add --jwt-claims flag to add custom claims to the generated JWTs via CLI
2 parents c17c4aa + 30fa3ba commit 8941653

File tree

7 files changed

+202
-148
lines changed

7 files changed

+202
-148
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ Flags:
8989
--jwt-key string JWT signing private key path
9090
--jwt-kid string JWT KID
9191
--jwt-sub string JWT subject (sub) claim
92+
--jwt-claims string JWT custom claims as a JSON string, ex: {"iat": 1719410063, "browser": "chrome"}
9293
-m, --method string request method (default "GET")
9394
--mtls-cert string mTLS cert path
9495
--mtls-key string mTLS cert private key path

cmd/payloader/run.go

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,29 @@ import (
99
)
1010

1111
const (
12-
argMethod = "method"
13-
argConnections = "connections"
14-
argRequests = "requests"
15-
argKeepAlive = "disable-keep-alive"
16-
argVerifySigner = "skip-verify"
17-
argTime = "time"
18-
argMTLSKey = "mtls-key"
19-
argMTLSCert = "mtls-cert"
20-
argReadTimeout = "read-timeout"
21-
argWriteTimeout = "write-timeout"
22-
argVerbose = "verbose"
23-
argTicker = "ticker"
24-
argJWTKey = "jwt-key"
25-
argJWTSUb = "jwt-sub"
26-
argJWTIss = "jwt-iss"
27-
argJWTAud = "jwt-aud"
28-
argJWTHeader = "jwt-header"
29-
argJWTKid = "jwt-kid"
30-
argHeaders = "headers"
31-
argBody = "body"
32-
argBodyFile = "body-file"
33-
argClient = "client"
12+
argMethod = "method"
13+
argConnections = "connections"
14+
argRequests = "requests"
15+
argKeepAlive = "disable-keep-alive"
16+
argVerifySigner = "skip-verify"
17+
argTime = "time"
18+
argMTLSKey = "mtls-key"
19+
argMTLSCert = "mtls-cert"
20+
argReadTimeout = "read-timeout"
21+
argWriteTimeout = "write-timeout"
22+
argVerbose = "verbose"
23+
argTicker = "ticker"
24+
argJWTKey = "jwt-key"
25+
argJWTSUb = "jwt-sub"
26+
argJWTCustomClaims = "jwt-claims"
27+
argJWTIss = "jwt-iss"
28+
argJWTAud = "jwt-aud"
29+
argJWTHeader = "jwt-header"
30+
argJWTKid = "jwt-kid"
31+
argHeaders = "headers"
32+
argBody = "body"
33+
argBodyFile = "body-file"
34+
argClient = "client"
3435
)
3536

3637
var (
@@ -49,6 +50,7 @@ var (
4950
ticker time.Duration
5051
jwtKey string
5152
jwtSub string
53+
jwtCustomClaims string
5254
jwtIss string
5355
jwtAud string
5456
jwtHeader string
@@ -86,6 +88,7 @@ var runCmd = &cobra.Command{
8688
jwtKID,
8789
jwtKey,
8890
jwtSub,
91+
jwtCustomClaims,
8992
jwtIss,
9093
jwtAud,
9194
jwtHeader,
@@ -124,6 +127,7 @@ func init() {
124127
runCmd.Flags().StringVar(&jwtAud, argJWTAud, "", "JWT audience (aud) claim")
125128
runCmd.Flags().StringVar(&jwtIss, argJWTIss, "", "JWT issuer (iss) claim")
126129
runCmd.Flags().StringVar(&jwtSub, argJWTSUb, "", "JWT subject (sub) claim")
130+
runCmd.Flags().StringVar(&jwtCustomClaims, argJWTCustomClaims, "", "JWT custom claims")
127131
runCmd.Flags().StringVar(&jwtHeader, argJWTHeader, "", "JWT header field name")
128132

129133
runCmd.MarkFlagsRequiredTogether(argMTLSCert, argMTLSKey)

config/config.go

Lines changed: 77 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -9,62 +9,65 @@ import (
99
"regexp"
1010
"strings"
1111
"time"
12+
"encoding/json"
1213
)
1314

1415
type Config struct {
15-
Ctx context.Context
16-
ReqURI string
17-
DisableKeepAlive bool
18-
ReqTarget int64
19-
Conns uint
20-
Duration time.Duration
21-
MTLSKey string
22-
MTLSCert string
23-
SkipVerify bool
24-
ReadTimeout time.Duration
25-
WriteTimeout time.Duration
26-
Method string
27-
Verbose bool
28-
VerboseTicker time.Duration
29-
JwtKID string
30-
JwtKey string
31-
JwtSub string
32-
JwtIss string
33-
JwtAud string
34-
JwtHeader string
35-
SendJWT bool
36-
Headers []string
37-
Body string
38-
BodyFile string
39-
Client string
16+
Ctx context.Context
17+
ReqURI string
18+
DisableKeepAlive bool
19+
ReqTarget int64
20+
Conns uint
21+
Duration time.Duration
22+
MTLSKey string
23+
MTLSCert string
24+
SkipVerify bool
25+
ReadTimeout time.Duration
26+
WriteTimeout time.Duration
27+
Method string
28+
Verbose bool
29+
VerboseTicker time.Duration
30+
JwtKID string
31+
JwtKey string
32+
JwtSub string
33+
JwtCustomClaimsJSON string
34+
JwtIss string
35+
JwtAud string
36+
JwtHeader string
37+
SendJWT bool
38+
Headers []string
39+
Body string
40+
BodyFile string
41+
Client string
4042
}
4143

42-
func NewConfig(ctx context.Context, reqURI, mTLScert, mTLSKey string, disableKeepAlive bool, reqs int64, conns uint, totalTime time.Duration, skipVerify bool, readTimeout, writeTimeout time.Duration, method string, verbose bool, ticker time.Duration, jwtKID, jwtKey, jwtSub, jwtIss, jwtAud, jwtHeader string, headers []string, body, bodyFile string, client string) *Config {
44+
func NewConfig(ctx context.Context, reqURI, mTLScert, mTLSKey string, disableKeepAlive bool, reqs int64, conns uint, totalTime time.Duration, skipVerify bool, readTimeout, writeTimeout time.Duration, method string, verbose bool, ticker time.Duration, jwtKID, jwtKey, jwtSub, jwtCustomClaimsJSON, jwtIss, jwtAud, jwtHeader string, headers []string, body, bodyFile string, client string) *Config {
4345
return &Config{
44-
Ctx: ctx,
45-
ReqURI: reqURI,
46-
MTLSKey: mTLSKey,
47-
MTLSCert: mTLScert,
48-
DisableKeepAlive: disableKeepAlive,
49-
ReqTarget: reqs,
50-
Conns: conns,
51-
Duration: totalTime,
52-
SkipVerify: skipVerify,
53-
ReadTimeout: readTimeout,
54-
WriteTimeout: writeTimeout,
55-
Method: method,
56-
Verbose: verbose,
57-
VerboseTicker: ticker,
58-
JwtKID: jwtKID,
59-
JwtKey: jwtKey,
60-
JwtSub: jwtSub,
61-
JwtIss: jwtIss,
62-
JwtAud: jwtAud,
63-
JwtHeader: jwtHeader,
64-
Headers: headers,
65-
Body: body,
66-
BodyFile: bodyFile,
67-
Client: client,
46+
Ctx: ctx,
47+
ReqURI: reqURI,
48+
MTLSKey: mTLSKey,
49+
MTLSCert: mTLScert,
50+
DisableKeepAlive: disableKeepAlive,
51+
ReqTarget: reqs,
52+
Conns: conns,
53+
Duration: totalTime,
54+
SkipVerify: skipVerify,
55+
ReadTimeout: readTimeout,
56+
WriteTimeout: writeTimeout,
57+
Method: method,
58+
Verbose: verbose,
59+
VerboseTicker: ticker,
60+
JwtKID: jwtKID,
61+
JwtKey: jwtKey,
62+
JwtSub: jwtSub,
63+
JwtCustomClaimsJSON: jwtCustomClaimsJSON,
64+
JwtIss: jwtIss,
65+
JwtAud: jwtAud,
66+
JwtHeader: jwtHeader,
67+
Headers: headers,
68+
Body: body,
69+
BodyFile: bodyFile,
70+
Client: client,
6871
}
6972
}
7073

@@ -83,6 +86,22 @@ var allowedMethods = [4]string{
8386
"DELETE",
8487
}
8588

89+
// Converts jwtCustomClaimsJSON from string to map[string]interface{}
90+
func JwtCustomClaimsJSONStringToMap(jwtCustomClaimsJSON string) (map[string]interface{}, error) {
91+
if jwtCustomClaimsJSON == "" {
92+
return nil, nil
93+
}
94+
95+
jwtCustomClaimsMap := map[string]interface{}{}
96+
97+
err := json.Unmarshal([]byte(jwtCustomClaimsJSON), &jwtCustomClaimsMap)
98+
if err != nil {
99+
return nil, err
100+
}
101+
102+
return jwtCustomClaimsMap, nil
103+
}
104+
86105
func (c *Config) Validate() error {
87106
if _, err := url.ParseRequestURI(c.ReqURI); err != nil {
88107
return fmt.Errorf("config: invalid request uri, got error %v", err)
@@ -180,6 +199,14 @@ func (c *Config) Validate() error {
180199
if c.ReqTarget == 0 && c.Duration == 0 {
181200
return errors.New("config: ReqTarget 0 and Duration 0")
182201
}
202+
203+
if c.JwtCustomClaimsJSON != "" {
204+
_, err := JwtCustomClaimsJSONStringToMap(c.JwtCustomClaimsJSON)
205+
if err != nil {
206+
return fmt.Errorf("config: failed to parse custom json in --jwt-claims, got error; %v", err)
207+
}
208+
}
209+
183210
return nil
184211
}
185212

pkgs/jwt-generator/jwt.go

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ import (
66
"encoding/hex"
77
"errors"
88
"fmt"
9+
"strings"
910
jwt_signer "github.com/domsolutions/gopayloader/pkgs/jwt-signer"
1011
"github.com/domsolutions/gopayloader/pkgs/jwt-signer/definition"
12+
config "github.com/domsolutions/gopayloader/config"
1113
"github.com/golang-jwt/jwt"
1214
"github.com/google/uuid"
1315
"github.com/pterm/pterm"
@@ -22,15 +24,16 @@ const (
2224
)
2325

2426
type Config struct {
25-
Ctx context.Context
26-
Kid string
27-
JwtKeyPath string
28-
jwtKeyBlob []byte
29-
JwtSub string
30-
JwtIss string
31-
JwtAud string
32-
signer definition.Signer
33-
store *cache
27+
Ctx context.Context
28+
Kid string
29+
JwtKeyPath string
30+
jwtKeyBlob []byte
31+
JwtSub string
32+
JwtCustomClaimsJSON string
33+
JwtIss string
34+
JwtAud string
35+
signer definition.Signer
36+
store *cache
3437
}
3538

3639
type JWTGenerator struct {
@@ -60,7 +63,9 @@ func (j *JWTGenerator) getFileName(dir string) string {
6063
hash.Write([]byte(j.config.JwtAud))
6164
hash.Write([]byte(j.config.JwtIss))
6265
hash.Write([]byte(j.config.JwtSub))
63-
hash.Write(j.config.jwtKeyBlob)
66+
hash.Write([]byte(j.config.JwtCustomClaimsJSON))
67+
strippedKey := strings.ReplaceAll(strings.ReplaceAll(string(j.config.jwtKeyBlob), "\r", ""), "\n", "") // Replace \r and \n to have the same value in Windows and Linux
68+
hash.Write([]byte(strippedKey))
6469
hash.Write([]byte(j.config.Kid))
6570
return filepath.Join(dir, "gopayloader-jwtstore-"+hex.EncodeToString(hash.Sum(nil))+".txt")
6671
}
@@ -163,8 +168,8 @@ func (j *JWTGenerator) generate(limit int64, errs chan<- error, response chan<-
163168
var err error
164169
var i int64 = 0
165170

171+
claims := j.commonClaims() // Claims common to all JWTs, computed only once
166172
for i = 0; i < limit; i++ {
167-
claims := j.commonClaims()
168173
claims["jti"] = uuid.New().String()
169174
tokens[i], err = j.config.signer.Generate(claims)
170175
if err != nil {
@@ -187,5 +192,18 @@ func (j *JWTGenerator) commonClaims() jwt.MapClaims {
187192
claims["iss"] = j.config.JwtIss
188193
}
189194
claims["exp"] = time.Now().Add(24 * time.Hour * 365).Unix()
195+
196+
if j.config.JwtCustomClaimsJSON != "" {
197+
// At this point the JSON in JwtCustomClaimsJSON has already been validated, but checking for errors again in case the workflow changes in the future
198+
jwtCustomClaimsMap, err := config.JwtCustomClaimsJSONStringToMap(j.config.JwtCustomClaimsJSON)
199+
if err != nil {
200+
return claims // Return claims if there's an error
201+
}
202+
for key, value := range jwtCustomClaimsMap {
203+
if key != "" {
204+
claims[key] = value
205+
}
206+
}
207+
}
190208
return claims
191209
}

pkgs/payloader/payloader.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,13 @@ func (p *PayLoader) handleReqs() (*GoPayloaderResults, error) {
101101
pterm.Info.Printf("Sending jwts with requests, checking for jwts in cache\n")
102102

103103
jwt := jwt_generator.NewJWTGenerator(&jwt_generator.Config{
104-
Ctx: p.config.Ctx,
105-
Kid: p.config.JwtKID,
106-
JwtKeyPath: p.config.JwtKey,
107-
JwtSub: p.config.JwtSub,
108-
JwtIss: p.config.JwtIss,
109-
JwtAud: p.config.JwtAud,
104+
Ctx: p.config.Ctx,
105+
Kid: p.config.JwtKID,
106+
JwtKeyPath: p.config.JwtKey,
107+
JwtSub: p.config.JwtSub,
108+
JwtCustomClaimsJSON: p.config.JwtCustomClaimsJSON,
109+
JwtIss: p.config.JwtIss,
110+
JwtAud: p.config.JwtAud,
110111
})
111112

112113
if err := os.MkdirAll(JwtCacheDir, 0755); err != nil {

0 commit comments

Comments
 (0)