Skip to content

[8.x] permit at+jwt typ header value in jwt access tokens (#126687) #126832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/126687.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 126687
summary: Permit at+jwt typ header value in jwt access tokens
area: Authentication
type: enhancement
issues:
- 119370
Original file line number Diff line number Diff line change
Expand Up @@ -456,13 +456,13 @@ public void testFailureOnInvalidHMACSignature() throws Exception {

{
// This is the correct HMAC passphrase (from build.gradle)
final SignedJWT jwt = signHmacJwt(claimsSet, HMAC_PASSPHRASE);
final SignedJWT jwt = signHmacJwt(claimsSet, HMAC_PASSPHRASE, false);
final TestSecurityClient client = getSecurityClient(jwt, Optional.of(VALID_SHARED_SECRET));
assertThat(client.authenticate(), hasEntry(User.Fields.USERNAME.getPreferredName(), username));
}
{
// This is not the correct HMAC passphrase
final SignedJWT invalidJwt = signHmacJwt(claimsSet, "invalid-HMAC-passphrase-" + randomAlphaOfLength(12));
final SignedJWT invalidJwt = signHmacJwt(claimsSet, "invalid-HMAC-passphrase-" + randomAlphaOfLength(12), false);
final TestSecurityClient client = getSecurityClient(invalidJwt, Optional.of(VALID_SHARED_SECRET));
// This fails because the HMAC is wrong
final ResponseException exception = expectThrows(ResponseException.class, client::authenticate);
Expand All @@ -487,7 +487,7 @@ public void testFailureOnRequiredClaims() throws JOSEException, IOException {
data.put("token_use", randomValueOtherThan("access", () -> randomAlphaOfLengthBetween(3, 10)));
}
final JWTClaimsSet claimsSet = buildJwt(data, Instant.now(), false, false);
final SignedJWT jwt = signHmacJwt(claimsSet, "test-HMAC/secret passphrase-value");
final SignedJWT jwt = signHmacJwt(claimsSet, "test-HMAC/secret passphrase-value", false);
final TestSecurityClient client = getSecurityClient(jwt, Optional.of(VALID_SHARED_SECRET));
final ResponseException exception = expectThrows(ResponseException.class, client::authenticate);
assertThat(exception.getResponse(), hasStatusCode(RestStatus.UNAUTHORIZED));
Expand Down Expand Up @@ -747,18 +747,18 @@ private SignedJWT buildAndSignJwtForRealm3(String principal, Instant issueTime)

private SignedJWT signJwtForRealm1(JWTClaimsSet claimsSet) throws IOException, JOSEException, ParseException {
final RSASSASigner signer = loadRsaSigner();
return signJWT(signer, "RS256", claimsSet);
return signJWT(signer, "RS256", claimsSet, false);
}

private SignedJWT signJwtForRealm2(JWTClaimsSet claimsSet) throws JOSEException, ParseException {
private SignedJWT signJwtForRealm2(JWTClaimsSet claimsSet) throws JOSEException {
// Input string is configured in build.gradle
return signHmacJwt(claimsSet, "test-HMAC/secret passphrase-value");
return signHmacJwt(claimsSet, "test-HMAC/secret passphrase-value", true);
}

private SignedJWT signJwtForRealm3(JWTClaimsSet claimsSet) throws JOSEException, ParseException, IOException {
final int bitSize = randomFrom(384, 512);
final MACSigner signer = loadHmacSigner("test-hmac-" + bitSize);
return signJWT(signer, "HS" + bitSize, claimsSet);
return signJWT(signer, "HS" + bitSize, claimsSet, false);
}

private RSASSASigner loadRsaSigner() throws IOException, ParseException, JOSEException {
Expand All @@ -781,10 +781,10 @@ private MACSigner loadHmacSigner(String keyId) throws IOException, ParseExceptio
}
}

private SignedJWT signHmacJwt(JWTClaimsSet claimsSet, String hmacPassphrase) throws JOSEException {
private SignedJWT signHmacJwt(JWTClaimsSet claimsSet, String hmacPassphrase, boolean allowAtJwtType) throws JOSEException {
final OctetSequenceKey hmac = JwkValidateUtil.buildHmacKeyFromString(hmacPassphrase);
final JWSSigner signer = new MACSigner(hmac);
return signJWT(signer, "HS256", claimsSet);
return signJWT(signer, "HS256", claimsSet, allowAtJwtType);
}

// JWT construction
Expand Down Expand Up @@ -822,10 +822,14 @@ static JWTClaimsSet buildJwt(Map<String, Object> claims, Instant issueTime, bool
return builder.build();
}

static SignedJWT signJWT(JWSSigner signer, String algorithm, JWTClaimsSet claimsSet) throws JOSEException {
static SignedJWT signJWT(JWSSigner signer, String algorithm, JWTClaimsSet claimsSet, boolean allowAtJwtType) throws JOSEException {
final JWSHeader.Builder builder = new JWSHeader.Builder(JWSAlgorithm.parse(algorithm));
if (randomBoolean()) {
builder.type(JOSEObjectType.JWT);
if (allowAtJwtType && randomBoolean()) {
builder.type(new JOSEObjectType("at+jwt"));
} else {
builder.type(JOSEObjectType.JWT);
}
}
final JWSHeader jwtHeader = builder.build();
final SignedJWT jwt = new SignedJWT(jwtHeader, claimsSet);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ private SignedJWT buildAndSignJwt(String principal, String dn, Instant issueTime
issueTime
);
final RSASSASigner signer = loadRsaSigner();
return JwtRestIT.signJWT(signer, "RS256", claimsSet);
return JwtRestIT.signJWT(signer, "RS256", claimsSet, false);
}

private RSASSASigner loadRsaSigner() throws IOException, ParseException, JOSEException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ private static List<JwtFieldValidator> configureFieldValidatorsForIdToken(RealmC
}

return List.of(
JwtTypeValidator.INSTANCE,
JwtTypeValidator.ID_TOKEN_INSTANCE,
new JwtStringClaimValidator("iss", true, List.of(realmConfig.getSetting(JwtRealmSettings.ALLOWED_ISSUER)), List.of()),
subjectClaimValidator,
new JwtStringClaimValidator("aud", false, realmConfig.getSetting(JwtRealmSettings.ALLOWED_AUDIENCES), List.of()),
Expand All @@ -157,7 +157,7 @@ private static List<JwtFieldValidator> configureFieldValidatorsForAccessToken(
final Clock clock = Clock.systemUTC();

return List.of(
JwtTypeValidator.INSTANCE,
JwtTypeValidator.ACCESS_TOKEN_INSTANCE,
new JwtStringClaimValidator("iss", true, List.of(realmConfig.getSetting(JwtRealmSettings.ALLOWED_ISSUER)), List.of()),
getSubjectClaimValidator(realmConfig, fallbackClaimLookup),
new JwtStringClaimValidator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@

public class JwtTypeValidator implements JwtFieldValidator {

private static final JOSEObjectTypeVerifier<SecurityContext> JWT_HEADER_TYPE_VERIFIER = new DefaultJOSEObjectTypeVerifier<>(
JOSEObjectType.JWT,
null
);
private final JOSEObjectTypeVerifier<SecurityContext> JWT_HEADER_TYPE_VERIFIER;
private static final JOSEObjectType AT_PLUS_JWT = new JOSEObjectType("at+jwt");

public static final JwtTypeValidator INSTANCE = new JwtTypeValidator();
public static final JwtTypeValidator ID_TOKEN_INSTANCE = new JwtTypeValidator(JOSEObjectType.JWT, null);

private JwtTypeValidator() {}
// strictly speaking, this should only permit `at+jwt`, but removing the other two options is a breaking change
public static final JwtTypeValidator ACCESS_TOKEN_INSTANCE = new JwtTypeValidator(JOSEObjectType.JWT, AT_PLUS_JWT, null);

private JwtTypeValidator(JOSEObjectType... allowedTypes) {
JWT_HEADER_TYPE_VERIFIER = new DefaultJOSEObjectTypeVerifier<>(allowedTypes);
}

public void validate(JWSHeader jwsHeader, JWTClaimsSet jwtClaimsSet) {
final JOSEObjectType jwtHeaderType = jwsHeader.getType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,21 @@

package org.elasticsearch.xpack.security.authc.jwt;

import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;

import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;

import java.text.ParseException;
import java.util.Map;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class JwtAuthenticatorIdTokenTypeTests extends JwtAuthenticatorTests {

Expand All @@ -28,4 +38,23 @@ public void testSubjectIsRequired() throws ParseException {
public void testInvalidIssuerIsCheckedBeforeAlgorithm() throws ParseException {
doTestInvalidIssuerIsCheckedBeforeAlgorithm(buildJwtAuthenticator());
}

public void testAccessTokenHeaderTypeIsRejected() throws ParseException {
final JWTClaimsSet claimsSet = JWTClaimsSet.parse(Map.of());
final SignedJWT signedJWT = new SignedJWT(
JWSHeader.parse(Map.of("alg", allowedAlgorithm, "typ", "at+jwt")).toBase64URL(),
claimsSet.toPayload().toBase64URL(),
Base64URL.encode("signature")
);

final JwtAuthenticationToken jwtAuthenticationToken = mock(JwtAuthenticationToken.class);
when(jwtAuthenticationToken.getSignedJWT()).thenReturn(signedJWT);
when(jwtAuthenticationToken.getJWTClaimsSet()).thenReturn(signedJWT.getJWTClaimsSet());

final PlainActionFuture<JWTClaimsSet> future = new PlainActionFuture<>();
final JwtAuthenticator jwtAuthenticator = buildJwtAuthenticator();
jwtAuthenticator.authenticate(jwtAuthenticationToken, future);
final Exception e = expectThrows(IllegalArgumentException.class, future::actionGet);
assertThat(e.getMessage(), equalTo("invalid jwt typ header"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.security.authc.jwt;

import com.nimbusds.jose.JOSEObjectType;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.openid.connect.sdk.Nonce;
Expand Down Expand Up @@ -134,7 +133,7 @@ protected SecureString randomJwt(JwtIssuerAndRealm jwtIssuerAndRealm, User user)

final Instant now = Instant.now().truncatedTo(ChronoUnit.SECONDS);
unsignedJwt = JwtTestCase.buildUnsignedJwt(
randomBoolean() ? null : JOSEObjectType.JWT.toString(), // kty
randomFrom("at+jwt", "JWT", null), // typ
randomBoolean() ? null : jwk.getKeyID(), // kid
algJwkPair.alg(), // alg
randomAlphaOfLengthBetween(10, 20), // jwtID
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,54 @@

public class JwtTypeValidatorTests extends ESTestCase {

public void testValidType() throws ParseException {
public void testValidIdTokenType() throws ParseException {
final String algorithm = randomAlphaOfLengthBetween(3, 8);

// typ is allowed to be missing
final JWSHeader jwsHeader = JWSHeader.parse(
randomFrom(Map.of("alg", randomAlphaOfLengthBetween(3, 8)), Map.of("typ", "JWT", "alg", randomAlphaOfLengthBetween(3, 8)))
randomFrom(
// typ is allowed to be missing
Map.of("alg", algorithm),
Map.of("typ", "JWT", "alg", algorithm)
)
);

try {
JwtTypeValidator.INSTANCE.validate(jwsHeader, JWTClaimsSet.parse(Map.of()));
JwtTypeValidator.ID_TOKEN_INSTANCE.validate(jwsHeader, JWTClaimsSet.parse(Map.of()));
} catch (Exception e) {
throw new AssertionError("validation should have passed without exception", e);
}
}

public void testValidAccessTokenType() throws ParseException {
final String algorithm = randomAlphaOfLengthBetween(3, 8);

final JWSHeader jwsHeader = JWSHeader.parse(
randomFrom(
// typ is allowed to be missing
Map.of("alg", algorithm),
Map.of("typ", "JWT", "alg", algorithm),
Map.of("typ", "at+jwt", "alg", algorithm),
Map.of("typ", "AT+JWT", "alg", algorithm)
)
);

try {
JwtTypeValidator.ACCESS_TOKEN_INSTANCE.validate(jwsHeader, JWTClaimsSet.parse(Map.of()));
} catch (Exception e) {
throw new AssertionError("validation should have passed without exception", e);
}
}

public void testInvalidType() throws ParseException {
final JwtTypeValidator validator = randomFrom(JwtTypeValidator.ID_TOKEN_INSTANCE, JwtTypeValidator.ACCESS_TOKEN_INSTANCE);

final JWSHeader jwsHeader = JWSHeader.parse(
Map.of("typ", randomAlphaOfLengthBetween(4, 8), "alg", randomAlphaOfLengthBetween(3, 8))
);

final IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> JwtTypeValidator.INSTANCE.validate(jwsHeader, JWTClaimsSet.parse(Map.of()))
() -> validator.validate(jwsHeader, JWTClaimsSet.parse(Map.of()))
);
assertThat(e.getMessage(), containsString("invalid jwt typ header"));
}
Expand Down