Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -404,11 +404,9 @@ public void init(B http) throws Exception {
oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper);
oidcAuthorizedClientRefreshedEventListener.setAuthoritiesMapper(userAuthoritiesMapper);
}
oidcAuthorizationCodeAuthenticationProvider = this.postProcess(oidcAuthorizationCodeAuthenticationProvider);
http.authenticationProvider(oidcAuthorizationCodeAuthenticationProvider);
http.authenticationProvider(this.postProcess(oidcAuthorizationCodeAuthenticationProvider));

oidcAuthorizedClientRefreshedEventListener = this.postProcess(oidcAuthorizedClientRefreshedEventListener);
registerDelegateApplicationListener(oidcAuthorizedClientRefreshedEventListener);
registerDelegateApplicationListener(this.postProcess(oidcAuthorizedClientRefreshedEventListener));
configureOidcUserRefreshedEventListener(http);
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mockito;

import org.springframework.beans.factory.BeanCreationException;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
Expand All @@ -43,10 +44,13 @@
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.authentication.event.AuthenticationSuccessEvent;
import org.springframework.security.config.ObjectPostProcessor;
import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configurers.oauth2.client.OAuth2LoginConfigurerTests.OAuth2LoginConfigCustomWithPostProcessor.SpyObjectPostProcessor;
import org.springframework.security.config.oauth2.client.CommonOAuth2Provider;
import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension;
Expand Down Expand Up @@ -709,6 +713,22 @@ public void oidcLoginWhenOAuth2ClientBeansConfiguredThenNotShared() throws Excep
verifyNoInteractions(clientRegistrationRepository, authorizedClientRepository);
}

// gh-17175
@Test
public void oauth2LoginWhenAuthenticationProviderPostProcessorThenUses() throws Exception {
loadConfig(OAuth2LoginConfigCustomWithPostProcessor.class);
// setup authorization request
OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response);
// setup authentication parameters
this.request.setParameter("code", "code123");
this.request.setParameter("state", authorizationRequest.getState());
// perform test
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
// assertions
verify(this.context.getBean(SpyObjectPostProcessor.class).spy).authenticate(any());
}

private void loadConfig(Class<?>... configs) {
AnnotationConfigWebApplicationContext applicationContext = new AnnotationConfigWebApplicationContext();
applicationContext.register(configs);
Expand Down Expand Up @@ -1307,6 +1327,52 @@ OAuth2AuthorizedClientRepository authorizedClientRepository() {

}

@Configuration
@EnableWebSecurity
static class OAuth2LoginConfigCustomWithPostProcessor {

private final ClientRegistrationRepository clientRegistrationRepository = new InMemoryClientRegistrationRepository(
GOOGLE_CLIENT_REGISTRATION);

private final ObjectPostProcessor<AuthenticationProvider> postProcessor = new SpyObjectPostProcessor();

@Bean
SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.oauth2Login((oauth2Login) -> oauth2Login
.clientRegistrationRepository(this.clientRegistrationRepository)
.withObjectPostProcessor(this.postProcessor)
);
// @formatter:on
return http.build();
}

@Bean
ObjectPostProcessor<AuthenticationProvider> mockPostProcessor() {
return this.postProcessor;
}

@Bean
HttpSessionOAuth2AuthorizationRequestRepository oauth2AuthorizationRequestRepository() {
return new HttpSessionOAuth2AuthorizationRequestRepository();
}

static class SpyObjectPostProcessor implements ObjectPostProcessor<AuthenticationProvider> {

AuthenticationProvider spy;

@Override
public <O extends AuthenticationProvider> O postProcess(O object) {
O spy = Mockito.spy(object);
this.spy = spy;
return spy;
}

}

}

private abstract static class CommonSecurityFilterChainConfig {

SecurityFilterChain configureFilterChain(HttpSecurity http) throws Exception {
Expand Down