Spring Boot + Amazon Cognito OAuth 2.0 / JWT

I recently had to implement Amazon Cognito JWT authentication with Spring Boot, and since I couldn't find a clean, simple guide after some quick googling, I decided to write my own. In this blog post, I'm going to walk you through the process I used.

We are going to implement a Spring Boot application that authenticates the user against Amazon Cognito using the OAuth 2.0 authorization code grant and JSON Web Tokens. All code examples are written in Kotlin.

This post is not going to cover Cognito itself. I expect you to know what Amazon Cognito is and how to configure it.

NOTE: This is a practical guide with lots of code examples. For the sake of simplicity, the code contains only the necessary components for the authentication to work, and you are expected to add more features for it to be safe & efficient in a production environment.

Ok, so let’s get started.

Settings & dependencies

First, we need to set up some dependencies. The Spring Boot web and security starters are the obvious ones, and Nimbus JOSE + JWT is the library we are going to use to handle the JSON Web Tokens.

build.gradle
dependencies {
    // 'compile' was removed in Gradle 7; use 'implementation' there instead
    compile('org.springframework.boot:spring-boot-starter-web')
    compile('org.springframework.boot:spring-boot-starter-security')
    compile('com.nimbusds:nimbus-jose-jwt:5.12')
}

 

Next, let's define some properties:

application-properties.yml
urls:
  cognito:    # cognito root auth url
endpoints:
  authorize: ${urls.cognito}/oauth2/authorize?response_type=code&client_id=${cognito.client}&redirect_uri=${cognito.callback}
  token: ${urls.cognito}/oauth2/token
cognito:
  client:     # cognito client id
  secret:     # cognito client secret
  callback:   # valid callback url set in cognito
  keys:       # url for cognito jwt keys

Here we specify:

  • Base URL for Cognito authentication
  • Endpoint URLs for authorization and token requests
  • Cognito client_id
  • Cognito client_secret
  • Cognito callback URL (sent as redirect_uri)
  • URL of Cognito public keys

You'll get all these values from your Cognito configuration.
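To make the placeholder substitution concrete, here is a minimal sketch of how the authorize URL is assembled from those values. The CognitoSettings class and the function name are hypothetical and only serve the illustration. Unlike the YAML template above, this version also URL-encodes the parameter values.

```kotlin
import java.net.URLEncoder

// Hypothetical holder for the configured values; the names are illustrative only.
data class CognitoSettings(
    val baseUrl: String,
    val clientId: String,
    val callbackUrl: String
)

// URL-encode a single query parameter value.
fun encodeParam(value: String): String = URLEncoder.encode(value, "UTF-8")

// Build the /oauth2/authorize URL for the authorization code grant.
fun authorizeUrl(settings: CognitoSettings): String =
    "${settings.baseUrl}/oauth2/authorize" +
        "?response_type=code" +
        "&client_id=${encodeParam(settings.clientId)}" +
        "&redirect_uri=${encodeParam(settings.callbackUrl)}"
```

Encoding the callback URL keeps characters such as ':' and '/' from being misread as part of the outer URL.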

Authentication

The next step is to define a processor bean for tokens and configure it to use the specified keys URL as its key source. This bean is responsible for processing and verifying the token and for extracting the authentication details. Most of the work is done under the hood, so not much manual configuration is needed at this point: we just need to set the key source and the algorithm (RS256 in this example).

JwtProcessor.kt
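import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.jwk.source.JWKSource
import com.nimbusds.jose.jwk.source.RemoteJWKSet
import com.nimbusds.jose.proc.JWSVerificationKeySelector
import com.nimbusds.jose.proc.SecurityContext
import com.nimbusds.jose.util.DefaultResourceRetriever
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor
import com.nimbusds.jwt.proc.DefaultJWTProcessor
import org.springframework.beans.factory.annotation.Value
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import java.net.URL

@Configuration
class JwtProcessor {
    @Value("\${cognito.keys}")
    private val keysUrl: String = ""
    @Bean
    fun configurableJWTProcessor(): ConfigurableJWTProcessor<SecurityContext> {
        // Connect and read timeouts for fetching the key set, in milliseconds
        val resourceRetriever = DefaultResourceRetriever(5000, 5000)
        val keySource: JWKSource<SecurityContext> = RemoteJWKSet(URL(keysUrl), resourceRetriever)
        // Accept only RS256-signed tokens, verified against the keys from the key source
        val keySelector = JWSVerificationKeySelector(JWSAlgorithm.RS256, keySource)
        val jwtProcessor: ConfigurableJWTProcessor<SecurityContext> = DefaultJWTProcessor()
        jwtProcessor.setJWSKeySelector(keySelector)
        return jwtProcessor
    }
}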
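
To give a rough idea of what RS256 verification means, here is a standalone sketch that uses only the JDK. It is not how Nimbus is implemented, and it checks the signature only (no expiry or other claim validation). The function names are hypothetical, and signRs256 exists just so the example can produce a token to verify.

```kotlin
import java.security.PrivateKey
import java.security.PublicKey
import java.security.Signature
import java.security.SignatureException
import java.util.Base64

// Demo helper: sign "header.payload" with a private key to produce a compact JWS.
fun signRs256(headerJson: String, payloadJson: String, privateKey: PrivateKey): String {
    val enc = Base64.getUrlEncoder().withoutPadding()
    val signingInput = enc.encodeToString(headerJson.toByteArray()) + "." +
        enc.encodeToString(payloadJson.toByteArray())
    val signature = Signature.getInstance("SHA256withRSA").run {
        initSign(privateKey)
        update(signingInput.toByteArray(Charsets.US_ASCII))
        sign()
    }
    return signingInput + "." + enc.encodeToString(signature)
}

// Verify the signature of a compact JWS ("header.payload.signature") with a public key.
fun verifyRs256(token: String, publicKey: PublicKey): Boolean {
    val parts = token.split(".")
    if (parts.size != 3) return false
    return try {
        val signature = Base64.getUrlDecoder().decode(parts[2])
        Signature.getInstance("SHA256withRSA").run {
            initVerify(publicKey)
            // The signed data is the first two segments exactly as transmitted
            update("${parts[0]}.${parts[1]}".toByteArray(Charsets.US_ASCII))
            verify(signature)
        }
    } catch (e: IllegalArgumentException) {
        false // signature segment is not valid Base64URL
    } catch (e: SignatureException) {
        false // signature is malformed for this key
    }
}
```

In the real setup the public key comes from the Cognito keys URL, and the processor also selects the right key for each token.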

We also need a filter that intercepts all authenticated requests, extracts the token from the headers, and sends it for processing. The filter is also responsible for denying any request that doesn't contain a valid token. We first try to extract the token from the Authorization header and then extract the actual authentication and claims. If the token is valid, we manually set the Spring Security context and let the request go forward. We also catch any exception thrown by the processor when the token is not valid, and respond with 401.

For the sake of simplicity in this tutorial, we implement a quick and dirty CognitoAuthenticationToken class for Spring Security:

CognitoAuthenticationToken.kt
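import com.nimbusds.jwt.JWTClaimsSet
import org.springframework.security.authentication.AbstractAuthenticationToken
import org.springframework.security.core.GrantedAuthority

class CognitoAuthenticationToken(
    private val token: String,
    details: JWTClaimsSet,
    authorities: List<GrantedAuthority> = listOf()
) : AbstractAuthenticationToken(authorities) {
    init {
        // Store the verified claims as details and mark the token as authenticated
        setDetails(details)
        isAuthenticated = true
    }
    // The raw JWT string
    override fun getCredentials(): Any {
        return token
    }
    // The claims set stored in details
    override fun getPrincipal(): Any {
        return details
    }
}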

And then the actual filter:

AuthFilter.kt
import com.nimbusds.jose.proc.SecurityContext
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor
import org.slf4j.LoggerFactory
import org.springframework.security.access.AccessDeniedException
import org.springframework.security.authentication.AuthenticationManager
import org.springframework.security.core.context.SecurityContextHolder
import org.springframework.security.web.authentication.www.BasicAuthenticationFilter
import javax.servlet.FilterChain
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse

class AuthFilter(
    val processor: ConfigurableJWTProcessor<SecurityContext>,
    authenticationManager: AuthenticationManager
) : BasicAuthenticationFilter(authenticationManager) {
    override fun doFilterInternal(
        req: HttpServletRequest,
        res: HttpServletResponse,
        chain: FilterChain
    ) {
        try {
            val token = extractToken(req.getHeader("Authorization"))
            val authentication = extractAuthentication(token)
            SecurityContextHolder.getContext().authentication = authentication
            chain.doFilter(req, res)
        } catch (e: AccessDeniedException) {
            LoggerFactory.getLogger(this.javaClass.simpleName).error("Access denied: ${e.message ?: "No message"}")
            res.status = 401
            res.writer.write("Access denied")
        }
    }
    /**
     * Extract token from header
     */
    private fun extractToken(header: String?): String? {
        val headers = header?.split("Bearer ")
        return if (headers == null || headers.size < 2) {
            null
        } else {
            headers[1]
        }
    }
    /**
     * Extract authentication details from token
     */
    @Throws(AccessDeniedException::class)
    private fun extractAuthentication(token: String?): CognitoAuthenticationToken? {
        if (token == null)
            return null
        return try {
            // Verifies the signature and returns the claims, or throws if the token is invalid
            val claims = processor.process(token, null)
            CognitoAuthenticationToken(token, claims)
        } catch (e: Exception) {
            throw AccessDeniedException("${e.javaClass.simpleName} (${e.message ?: "No message"})")
        }
    }
}
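
The extractToken helper above splits on "Bearer ", so it also accepts headers where that text appears somewhere in the middle. If you want something stricter, a possible variant is sketched below. The function name is hypothetical, and treating the scheme as case-insensitive is my own choice here, not something the demo does.

```kotlin
// Return the token from an Authorization header value, or null when the header
// is missing, uses another scheme, or carries an empty token.
fun bearerToken(header: String?): String? {
    val prefix = "Bearer "
    if (header == null || !header.startsWith(prefix, ignoreCase = true)) return null
    return header.substring(prefix.length).trim().takeIf { it.isNotEmpty() }
}
```

With this variant a header such as "Basic dXNlcjpwYXNz" yields null, so the request continues unauthenticated and is rejected by the security configuration.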

Now that we have a working processor and filter, we can implement the configuration for Spring Security as follows (note that we want our /auth endpoints to be unprotected, since they are used for the actual authentication requests):

AuthConfig.kt
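import com.nimbusds.jose.proc.SecurityContext
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor
import org.springframework.security.config.annotation.web.builders.HttpSecurity
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
import org.springframework.security.config.http.SessionCreationPolicy

@EnableWebSecurity
class AuthConfig(val processor: ConfigurableJWTProcessor<SecurityContext>) : WebSecurityConfigurerAdapter() {
    override fun configure(http: HttpSecurity) {
        http
            .authorizeRequests()
            // Login and token endpoints must be reachable without a token
            .antMatchers("/auth/**").permitAll()
            .anyRequest().authenticated()
            .and()
            .addFilter(AuthFilter(processor, authenticationManager()))
            // Every request carries its own token, so no HTTP session is needed
            .sessionManagement().sessionCreationPolicy(SessionCreationPolicy.STATELESS)
    }
}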

Next, we need to define the user-facing endpoints for authentication requests. We need two: one that redirects the user to the Cognito login form (which, after a successful login, redirects the user to the callback URI with an authorization code), and another that retrieves the actual token with that authorization code. This way the client needs to know almost nothing about Cognito beforehand. The redirect to the login form is a basic redirect response, and the token is retrieved by sending a POST request to Cognito's /oauth2/token endpoint with the authorization code and client ID. We'll set up a service for that. We also need a model for the Cognito JWT:

CognitoJWT.kt
data class CognitoJWT(
    val id_token: String = "",
    val access_token: String = "",
    val refresh_token: String = "",
    val expires_in: Int = 0,
    val token_type: String = ""
)

AuthService.kt
import org.springframework.beans.factory.annotation.Value
import org.springframework.http.HttpEntity
import org.springframework.stereotype.Component
import org.springframework.util.LinkedMultiValueMap
import org.springframework.web.client.RestTemplate
import java.util.Base64

@Component
class AuthService {
    @Value("\${endpoints.token}")
    private val tokenUrl: String = ""
    @Value("\${cognito.client}")
    private val clientId: String = ""
    @Value("\${cognito.secret}")
    private val clientSecret: String = ""
    @Value("\${cognito.callback}")
    private val callbackUrl: String = ""
    /**
     * Get token with authorization code
     */
    fun getToken(code: String): CognitoJWT? {
        val client = RestTemplate()
        val headers = LinkedMultiValueMap<String, String>()
        // HTTP Basic credentials: base64("client_id:client_secret")
        val auth = Base64.getEncoder().encodeToString("$clientId:$clientSecret".toByteArray())
        headers.add("Authorization", "Basic $auth")
        headers.add("Content-Type", "application/x-www-form-urlencoded")
        val req = HttpEntity<Nothing?>(null, headers)
        val url = "$tokenUrl?grant_type=authorization_code&client_id=$clientId&code=$code&redirect_uri=$callbackUrl"
        return client.postForObject(url, req, CognitoJWT::class.java)
    }
}
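
The two pieces of the token request that are easiest to get wrong are the Basic credentials and the parameter encoding. Here is a small JDK-only sketch of both. The helper names are hypothetical, and the parameters are URL-encoded here, which the getToken method above does not do.

```kotlin
import java.net.URLEncoder
import java.util.Base64

// Value of the Authorization header for HTTP Basic auth: "Basic base64(clientId:clientSecret)".
fun basicAuthHeader(clientId: String, clientSecret: String): String {
    val credentials = "$clientId:$clientSecret".toByteArray(Charsets.UTF_8)
    return "Basic " + Base64.getEncoder().encodeToString(credentials)
}

// Parameters of an authorization code token request, in a fixed order.
fun tokenRequestParams(clientId: String, code: String, callbackUrl: String): Map<String, String> =
    linkedMapOf(
        "grant_type" to "authorization_code",
        "client_id" to clientId,
        "code" to code,
        "redirect_uri" to callbackUrl
    )

// Encode parameters as application/x-www-form-urlencoded, preserving insertion order.
fun formEncode(params: Map<String, String>): String =
    params.entries.joinToString("&") { (key, value) ->
        URLEncoder.encode(key, "UTF-8") + "=" + URLEncoder.encode(value, "UTF-8")
    }
```

The encoded string can be used as the query string, as in getToken, or sent as the request body.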
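
AuthController.kt
import org.springframework.beans.factory.annotation.Value
import org.springframework.http.HttpHeaders
import org.springframework.http.HttpStatus
import org.springframework.http.ResponseEntity
import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RequestParam
import org.springframework.web.bind.annotation.RestController

@RestController
@RequestMapping("/auth")
class AuthController(val authService: AuthService) {
    @Value("\${endpoints.authorize}")
    private val authorizeUrl: String = ""
    /**
     * Redirect user to correct url for authorization code
     */
    @GetMapping("/login")
    fun login(): ResponseEntity<Any> =
        ResponseEntity
            .status(HttpStatus.SEE_OTHER)
            .header(HttpHeaders.LOCATION, authorizeUrl)
            .build()
    /**
     * Get aws tokens with authorization code
     */
    @GetMapping("/token")
    fun token(@RequestParam("code") code: String): CognitoJWT? =
        authService.getToken(code)
}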

Accessing the claims

Now all that is left is to access the token claims inside the application. We'll specify a model class for the claims and add a method to our AuthService for extracting them from the security context (note that the claims must, of course, match the ones you've set up when configuring Cognito).

TokenClaims.kt
import java.util.Date

data class TokenClaims(
    val uuid: String,
    val auth_time: Long,
    val issued: Date,
    val expire: Date,
    val name: String,
    val cognitoUserName: String,
    val email: String
)
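
AuthService.kt
// Additional imports needed in AuthService.kt
import com.nimbusds.jwt.JWTClaimsSet
import org.springframework.security.core.context.SecurityContextHolder

// Add this method to the AuthService class
fun getClaims(): TokenClaims {
    val authentication = SecurityContextHolder.getContext().authentication
    val details = authentication.details as JWTClaimsSet
    return TokenClaims(
        uuid = details.getStringClaim("sub"),
        // getLongClaim accepts any numeric JSON value, unlike a direct cast to Long
        auth_time = details.getLongClaim("auth_time"),
        // iat and exp are registered claims that Nimbus exposes as Date values
        issued = details.issueTime,
        expire = details.expirationTime,
        name = details.getStringClaim("name"),
        cognitoUserName = details.getStringClaim("cognito:username"),
        email = details.getStringClaim("email")
    )
}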
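
If you are unsure which claims your tokens actually contain, it helps to look at the raw payload before writing the model class. Below is a minimal JDK-only sketch for that (the function name is hypothetical). It does not verify the signature, so it is for inspection only and not for authentication.

```kotlin
import java.util.Base64

// Decode the payload (second segment) of a compact JWT into its JSON text.
// The signature is NOT verified here, so use this for inspection only.
fun decodePayload(token: String): String {
    val parts = token.split(".")
    require(parts.size == 3) { "Expected header.payload.signature" }
    return String(Base64.getUrlDecoder().decode(parts[1]), Charsets.UTF_8)
}
```

The claim names in the decoded JSON are the ones to pass to getStringClaim and friends.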

To access the claims anywhere inside the Spring context, we simply inject our AuthService class and use the getClaims method. Let's test it by writing a basic controller which returns the extracted claims back to the client:

UserController.kt
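import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RestController

@RestController
@RequestMapping("/user")
class UserController(val authService: AuthService) {
    // Protected endpoint: returns the claims of the token sent with the request
    @GetMapping("/me")
    fun getCurrentUser(): TokenClaims {
        return authService.getClaims()
    }
}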

Logging in and accessing protected endpoints

You should now be able to log in and access the protected /user/me endpoint.

 

Steps for logging in and using protected endpoints:
1. Send a GET request to /auth/login.

2. Follow the redirect and log in to Cognito to get the authorization code.

3. Send a GET request to /auth/token?code={code}, and copy the id_token parameter from the response.

4. Send a GET request to /user/me containing an Authorization header with the value 'Bearer {id_token}'.
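
If you script these steps instead of clicking through them, you need to pick the authorization code out of the URL that Cognito redirects to. Here is a small sketch of that, assuming the code arrives as a `code` query parameter as in the standard authorization code grant. The function name and the example URLs are hypothetical.

```kotlin
import java.net.URI
import java.net.URLDecoder

// Extract the "code" query parameter from the callback URL, or null if it is absent.
fun authorizationCode(callbackUrl: String): String? {
    val query = URI(callbackUrl).rawQuery ?: return null
    return query.split("&")
        .map { it.split("=", limit = 2) }
        .firstOrNull { it.size == 2 && it[0] == "code" }
        ?.let { URLDecoder.decode(it[1], "UTF-8") }
}
```

A null result usually means the login failed or was cancelled, in which case the redirect carries an error parameter instead.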

 

Code for a fully working demo is available at:

https://github.com/akselip/spring-cognito-demo
