Skip to content

Commit 24e8ed3

Browse files
committed
fix(auth): delegate JWT parsing to github.com/golang-jwt/jwt
Signed-off-by: Marc Nuri <[email protected]>
1 parent 4d994d3 commit 24e8ed3

File tree

4 files changed

+169
-248
lines changed

4 files changed

+169
-248
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ require (
66
github.com/BurntSushi/toml v1.5.0
77
github.com/coreos/go-oidc/v3 v3.14.1
88
github.com/fsnotify/fsnotify v1.9.0
9+
github.com/golang-jwt/jwt/v4 v4.5.2
910
github.com/mark3labs/mcp-go v0.34.0
1011
github.com/pkg/errors v0.9.1
1112
github.com/spf13/afero v1.14.0

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
118118
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
119119
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
120120
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
121+
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
122+
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
121123
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
122124
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
123125
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=

pkg/http/authorization.go

Lines changed: 34 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,12 @@ package http
22

33
import (
44
"context"
5-
"encoding/base64"
6-
"encoding/json"
75
"fmt"
86
"net/http"
97
"strings"
10-
"time"
118

129
"github.com/coreos/go-oidc/v3/oidc"
10+
"github.com/golang-jwt/jwt/v4"
1311
"k8s.io/klog/v2"
1412

1513
"github.com/manusa/kubernetes-mcp-server/pkg/mcp"
@@ -55,7 +53,10 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp
5553
// Validate the token offline for simple sanity check
5654
// Because missing expected audience and expired tokens must be
5755
// rejected already.
58-
claims, err := validateJWTToken(token, audience)
56+
claims, err := ParseJWTClaims(token)
57+
if err == nil && claims != nil {
58+
err = claims.Validate(audience)
59+
}
5960
if err != nil {
6061
klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
6162

@@ -118,80 +119,49 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp
118119
}
119120
}
120121

121-
type JWTClaims struct {
122-
Issuer string `json:"iss"`
123-
Audience any `json:"aud"`
124-
ExpiresAt int64 `json:"exp"`
125-
Scope string `json:"scope,omitempty"`
126-
}
122+
type JWTClaims jwt.MapClaims
127123

128124
func (c *JWTClaims) GetScopes() []string {
129-
if c.Scope == "" {
130-
return nil
131-
}
132-
return strings.Fields(c.Scope)
133-
}
134-
135-
func (c *JWTClaims) ContainsAudience(audience string) bool {
136-
switch aud := c.Audience.(type) {
125+
scope := jwt.MapClaims(*c)["scope"]
126+
switch scope.(type) {
137127
case string:
138-
return aud == audience
139-
case []interface{}:
140-
for _, a := range aud {
141-
if str, ok := a.(string); ok && str == audience {
142-
return true
143-
}
144-
}
145-
case []string:
146-
for _, a := range aud {
147-
if a == audience {
148-
return true
149-
}
150-
}
128+
return strings.Fields(scope.(string))
151129
}
152-
return false
130+
return nil
153131
}
154132

155-
// validateJWTToken validates basic JWT claims without signature verification and returns the claims
156-
func validateJWTToken(token, audience string) (*JWTClaims, error) {
157-
parts := strings.Split(token, ".")
158-
if len(parts) != 3 {
159-
return nil, fmt.Errorf("invalid JWT token format")
160-
}
161-
162-
claims, err := parseJWTClaims(parts[1])
163-
if err != nil {
164-
return nil, fmt.Errorf("failed to parse JWT claims: %v", err)
165-
}
166-
167-
if claims.ExpiresAt > 0 && time.Now().Unix() > claims.ExpiresAt {
168-
return nil, fmt.Errorf("token expired")
169-
}
133+
func (c *JWTClaims) VerifyAudience(audience string) bool {
134+
return jwt.MapClaims(*c).VerifyAudience(audience, true)
135+
}
170136

171-
if !claims.ContainsAudience(audience) {
172-
return nil, fmt.Errorf("token audience mismatch: %v", claims.Audience)
173-
}
137+
func (c *JWTClaims) VerifyExpiresAt(expriesAt int64) bool {
138+
return jwt.MapClaims(*c).VerifyExpiresAt(expriesAt, true)
139+
}
174140

175-
return claims, nil
141+
func (c *JWTClaims) VerifyIssuer(issuer string) bool {
142+
return jwt.MapClaims(*c).VerifyIssuer(issuer, true)
176143
}
177144

178-
func parseJWTClaims(payload string) (*JWTClaims, error) {
179-
// Add padding if needed
180-
if len(payload)%4 != 0 {
181-
payload += strings.Repeat("=", 4-len(payload)%4)
182-
}
145+
func (c *JWTClaims) Valid() error {
146+
return jwt.MapClaims(*c).Valid()
147+
}
183148

184-
decoded, err := base64.URLEncoding.DecodeString(payload)
185-
if err != nil {
186-
return nil, fmt.Errorf("failed to decode JWT payload: %v", err)
149+
// Validate Checks if the JWT claims are valid and if the audience matches the expected one.
150+
func (c *JWTClaims) Validate(audience string) error {
151+
if err := c.Valid(); err != nil {
152+
return err
187153
}
188-
189-
var claims JWTClaims
190-
if err := json.Unmarshal(decoded, &claims); err != nil {
191-
return nil, fmt.Errorf("failed to unmarshal JWT claims: %v", err)
154+
if !c.VerifyAudience(audience) {
155+
return fmt.Errorf("token audience mismatch: %v", jwt.MapClaims(*c)["aud"])
192156
}
157+
return nil
158+
}
193159

194-
return &claims, nil
160+
func ParseJWTClaims(token string) (*JWTClaims, error) {
161+
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
162+
mapClaims := &JWTClaims{}
163+
_, _, err := parser.ParseUnverified(token, mapClaims)
164+
return mapClaims, err
195165
}
196166

197167
func validateTokenWithOIDC(ctx context.Context, provider *oidc.Provider, token, audience string) error {

0 commit comments

Comments
 (0)