Introduction
As technology evolves and becomes more prevalent, including the rise of large-scale service-oriented architectures, managing web security becomes increasingly complex. There are many more edge cases than before, and keeping users' personal information safe is getting harder. Without proactive security measures, businesses risk leaking sensitive information, and in the information age that can become a serious problem for users online.
This is why security must come first when building applications, not as an afterthought.
Many users create accounts and log in through various browsers and devices, which means we also have to keep track of the devices each user logs in from. Otherwise, we might accidentally lock users out of their own accounts, assuming someone gained unauthorized access when in reality the user just went on a trip and logged in from their phone over the hotel's Wi-Fi.
In this guide, we'll look at a common proactive security strategy: invalidating a JWT token when a user logs out of a system from a specific device.
Note: This guide assumes you've already got Spring Security authentication set up, and aims to provide guidance on invalidating JWT tokens in an implementation-agnostic way. Whether you've defined your own roles and authorities or used Spring's GrantedAuthority, and whether you've defined your own User or relied on Spring's UserDetails, won't matter much. That being said, some of the underlying filters, classes and configurations aren't shown in the guide itself, as they might differ for your application.
If you'd like to consult the specific implementation used in this guide, including all of the configuration that's not shown here, you can access the full source code on GitHub.
Spring Security
Spring Security is a simple yet powerful framework that lets software engineers impose security restrictions on Spring-based web applications through various Java EE components, such as servlet filters. It is easy to extend and customize, and centers on providing authentication and access control for Spring-based applications.
At its core, it takes care of three main hurdles:
- Authentication: Checks whether the user is the right person to access some restricted resources. It takes care of two basic processes: identification (who the user is) and verification (whether the user is who they claim to be).
- Authorization: Ensures that a user can access only the resources they've been authorized to use, via a combination of roles and permissions.
- Servlet Filters: A Spring web application is essentially a single servlet, the DispatcherServlet, that routes incoming HTTP requests to @Controller or @RestController classes. Since the DispatcherServlet has no security implementation of its own, Spring Security places a chain of servlet filters in front of it, so that authentication and authorization are taken care of before requests reach the controllers.
Note: Some people use the terms "Role" and "Permission" interchangeably, which can be confusing to learners. A role groups a set of permissions: an Admin (role) may have permissions to perform X and Y, while an Engineer may have permissions to perform Y and Z.
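As a framework-free illustration of that note, the sketch below models roles as bundles of permissions. `AppRole` and `AppPermission` are hypothetical names, unrelated to Spring Security's own types; in Spring Security, roles and permissions would typically both surface as `GrantedAuthority` instances, with roles conventionally carrying a `ROLE_` prefix.

```java
import java.util.EnumSet;
import java.util.Set;

// Hypothetical permissions; X, Y and Z mirror the example in the note.
enum AppPermission { X, Y, Z }

// Hypothetical roles, each bundling a fixed set of permissions.
enum AppRole {
    ADMIN(EnumSet.of(AppPermission.X, AppPermission.Y)),
    ENGINEER(EnumSet.of(AppPermission.Y, AppPermission.Z));

    private final Set<AppPermission> permissions;

    AppRole(Set<AppPermission> permissions) {
        this.permissions = permissions;
    }

    // A role grants a permission only if that permission is in its set.
    boolean grants(AppPermission permission) {
        return permissions.contains(permission);
    }
}

class RolePermissionDemo {
    public static void main(String[] args) {
        System.out.println(AppRole.ADMIN.grants(AppPermission.X));    // true
        System.out.println(AppRole.ENGINEER.grants(AppPermission.X)); // false
    }
}
```

Checking access against permissions rather than role names keeps the checks stable when roles are reshuffled.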
JSON Web Tokens
A JWT (JSON Web Token) is a token that enables stateless user authentication. The server doesn't need to store authentication state in a session or a database object: when it authenticates a request carrying a JWT, it needn't look up the user's session or query the database, because the token itself carries the information required. The token is generated from a payload of claims (statements about the user, such as their identifier) and is sent by the client so that the server can identify the user.
A JWT is composed of the following structure:
header.payload.signature
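- Header: Describes how the token is signed and should be interpreted, typically the token type (typ) and the signing algorithm (alg).
- Payload: Contains the claims, i.e. statements about the user or entity. There are three types of claims: registered, public and private claims.
- Signature: Computed by signing the encoded header and payload with a secret (or a private key), using the algorithm named in the header. The header and payload are only Base64URL-encoded, not encrypted, so anyone holding the token can read them; the signature guarantees that they haven't been tampered with.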
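To see how the three parts fit together, here's a minimal sketch that builds and verifies an HS256 (HMAC-SHA256) token using only the JDK. `Hs256Jwt` is a hypothetical helper for illustration only: it doesn't validate claims such as `exp`, and in practice you'd rely on a vetted library such as jjwt, which this guide uses.

```java
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.util.Base64;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

// Hypothetical JDK-only HS256 helper for illustration; use a vetted JWT library in practice.
final class Hs256Jwt {
    private static final Base64.Encoder B64URL = Base64.getUrlEncoder().withoutPadding();

    // Produces header.payload.signature, where signature = HMAC-SHA256("header.payload").
    static String sign(String headerJson, String payloadJson, byte[] secret) {
        String signingInput = encode(headerJson) + "." + encode(payloadJson);
        return signingInput + "." + B64URL.encodeToString(hmac(signingInput, secret));
    }

    // Recomputes the signature over the first two parts and compares it in constant time.
    static boolean verify(String jwt, byte[] secret) {
        String[] parts = jwt.split("\\.");
        if (parts.length != 3) {
            return false;
        }
        try {
            byte[] actual = Base64.getUrlDecoder().decode(parts[2]);
            byte[] expected = hmac(parts[0] + "." + parts[1], secret);
            return MessageDigest.isEqual(expected, actual);
        } catch (IllegalArgumentException e) {
            return false; // signature part isn't valid Base64URL
        }
    }

    // Header (index 0) and payload (index 1) are merely Base64URL-encoded JSON.
    static String decodePart(String jwt, int index) {
        byte[] bytes = Base64.getUrlDecoder().decode(jwt.split("\\.")[index]);
        return new String(bytes, StandardCharsets.UTF_8);
    }

    private static String encode(String json) {
        return B64URL.encodeToString(json.getBytes(StandardCharsets.UTF_8));
    }

    private static byte[] hmac(String data, byte[] secret) {
        try {
            Mac mac = Mac.getInstance("HmacSHA256");
            mac.init(new SecretKeySpec(secret, "HmacSHA256"));
            return mac.doFinal(data.getBytes(StandardCharsets.US_ASCII));
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException("HmacSHA256 is not available", e);
        }
    }

    public static void main(String[] args) {
        byte[] secret = "demo-secret-demo-secret-demo-secret!".getBytes(StandardCharsets.UTF_8);
        String jwt = sign("{\"alg\":\"HS256\",\"typ\":\"JWT\"}", "{\"sub\":\"adam@example.com\"}", secret);
        System.out.println(decodePart(jwt, 1)); // {"sub":"adam@example.com"}
        System.out.println(verify(jwt, secret)); // true
    }
}
```

Note that reading the payload doesn't require the secret at all, which is why sensitive data shouldn't be placed in JWT claims.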
If you'd like to read more about JWTs, read our guide on Understanding JSON Web Tokens (JWT).
JSON Web Token Lifecycle
Let's take a look at the classic JWT lifecycle - from the moment a user tries logging in:
In the diagram, the client passes their user credentials in the form of a request to the server. The server, after performing identification and verification, returns a JWT token as a response. The client will henceforth use this JWT token to request access to secured endpoints.
Typically, the user will try to access some secure endpoint or resource after logging in:
This time around, the client sends the JWT token it acquired earlier along with the request for secured data. The server introspects the token, performs stateless authentication and authorization, and sends the secured content back in the response.
Finally, once the user is done with the application, they'll typically log out:
When the user wants to log out, the client asks the server to log them out of a specific device and invalidate their active session there. The server can close the sessions it tracks, but it can't invalidate the JWT token itself, since the token is stateless and immutable: it stays valid until it expires.
This can quickly become a problem: once a user logs out, their JWT token must no longer be usable. If anyone tries to access a restricted resource with an invalidated token, they should be denied access, and the client needs a way to recover from this state, for example by asking the user to log in again.
How can we invalidate tokens? We can make them expire quickly, blacklist logged-out tokens that haven't expired yet, and/or rotate them via a refresh token issued alongside the JWT.
Let's go ahead and set up Spring Security to perform in-memory invalidation of JWT tokens, when a user logs out.
Spring Boot and Spring Security Setup
Now that we've covered JWTs and the main issue, let's initialize a simple Spring Boot application and set it up. The easiest way to start with a skeleton project is via Spring Initializr:
We've added the Spring Security dependency because we'd like to include and leverage the module to handle security for us. We've also included the Spring Web and Spring Data JPA modules since we're ultimately creating a web application that has a persistence layer. The use of Lombok is optional, as it's a convenience library that helps us reduce boilerplate code such as getters, setters and constructors, just by annotating our entities with Lombok annotations.
We'll also need to import a few extra dependencies that aren't available on Spring Initializr: the JJWT library for working with JWTs, and the ExpiringMap library. ExpiringMap provides a high-performance, thread-safe ConcurrentMap implementation that expires entries, which we'll use to expire certain tokens:
<!--Jwt-->
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt</artifactId>
<version>0.9.1</version>
</dependency>
<!--Expiring Map-->
<dependency>
<groupId>net.jodah</groupId>
<artifactId>expiringmap</artifactId>
<version>0.5.9</version>
</dependency>
Implementing a Spring Boot Web Application
Mapping Devices to Users while Logging In
Users increasingly log into systems through different devices. A common scenario is a user logging in through a desktop website and a smartphone. By default, the back-end identifies both sessions only by the user's email, since the email is the identifier, so it can't tell them apart: once the user logs out of the application on their desktop, they'll also be logged out on their phone.
If that's not the behavior you want, you can pass device information along with the username and password when sending the login request. To generate a unique ID for the device the first time a user logs in, the frontend client can leverage the Fingerprint.js library.
Since a user might use more than one device, we'll need a mechanism to map each device to a user login session. We'll also generate a refresh token to maintain the same session (by refreshing the JWT's expiration) for as long as the user stays logged in. Once they log out, we invalidate the JWT and let it expire.
In other words, we'll map both a device and a refresh token to each user session. Since we already have a way to identify devices, let's implement that mapping.
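Put differently, a session is identified by the (user, device) pair rather than by the user alone. The framework-free sketch below illustrates that idea with nested maps; `DeviceSessionRegistry` is hypothetical, standing in for the UserDevice and RefreshToken tables defined below, and the device IDs stand in for whatever the frontend fingerprinting library produces.

```java
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical in-memory registry: one refresh token per (user, device) pair.
final class DeviceSessionRegistry {
    private final Map<String, Map<String, String>> sessionsByUser = new ConcurrentHashMap<>();

    // Logging in from a device creates (or replaces) only that device's session.
    String login(String email, String deviceId) {
        String refreshToken = UUID.randomUUID().toString();
        sessionsByUser.computeIfAbsent(email, e -> new ConcurrentHashMap<>()).put(deviceId, refreshToken);
        return refreshToken;
    }

    // Logging out removes that device's session and leaves the user's other devices intact.
    void logout(String email, String deviceId) {
        sessionsByUser.computeIfPresent(email, (e, devices) -> {
            devices.remove(deviceId);
            return devices.isEmpty() ? null : devices; // returning null drops the empty entry
        });
    }

    Optional<String> refreshTokenFor(String email, String deviceId) {
        Map<String, String> devices = sessionsByUser.get(email);
        return devices == null ? Optional.empty() : Optional.ofNullable(devices.get(deviceId));
    }

    public static void main(String[] args) {
        DeviceSessionRegistry registry = new DeviceSessionRegistry();
        registry.login("adam@example.com", "desktop-fingerprint");
        registry.login("adam@example.com", "phone-fingerprint");
        registry.logout("adam@example.com", "desktop-fingerprint");
        System.out.println(registry.refreshTokenFor("adam@example.com", "desktop-fingerprint").isPresent()); // false
        System.out.println(registry.refreshTokenFor("adam@example.com", "phone-fingerprint").isPresent());   // true
    }
}
```

Logging out of the desktop removes only that device's refresh token, so the phone session keeps working.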
Domain Model - Defining Entities
Let's start with the domain model and the entities we'll be using, namely the User and UserDevice:
// Lombok annotations for getters, setters and constructor
@Entity
public class User {
@Id
@GeneratedValue(strategy = GenerationType.SEQUENCE, generator = "user_seq")
private Long id;
private String email;
private String password;
private String name;
private Boolean active;
@ManyToMany(fetch = FetchType.LAZY)
@JoinTable(name = "user_roles", joinColumns = @JoinColumn(name = "user_id"), inverseJoinColumns = @JoinColumn(name = "role_id"))
private Set<Role> roles = new HashSet<>();
public void activate() {
this.active = true;
}
public void deactivate() {
this.active = false;
}
}
This User will use some sort of device to send a login request. Let's define the UserDevice model as well:
// Lombok annotations for getters, setters and constructor
@Entity
public class UserDevice {
@Id
@GeneratedValue(strategy = GenerationType.SEQUENCE, generator = "user_device_seq")
private Long id;
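// Many devices can belong to one user; without a mapping annotation, JPA can't persist an entity-typed field
@ManyToOne
@JoinColumn(name = "user_id")
private User user;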
private String deviceType;
private String deviceId;
@OneToOne(optional = false, mappedBy = "userDevice")
private RefreshToken refreshToken;
private Boolean isRefreshActive;
}
Finally, we'd also like to have a RefreshToken for each device:
// Lombok annotations
@Entity
public class RefreshToken {
@Id
@GeneratedValue(strategy = GenerationType.SEQUENCE, generator = "refresh_token_seq")
private Long id;
private String token;
@OneToOne(optional = false, cascade = CascadeType.ALL)
@JoinColumn(name = "USER_DEVICE_ID", unique = true)
private UserDevice userDevice;
private Long refreshCount;
private Instant expiryDate;
public void incrementRefreshCount() {
refreshCount = refreshCount + 1;
}
}
Data Transfer Objects - Defining Request Payload
Now, let's define the Data Transfer Objects for the incoming API request payload. We'll need a DeviceInfo DTO that simply contains the deviceId and deviceType for our UserDevice model. We'll also have a LoginForm DTO that contains the user's credentials and the DeviceInfo DTO.
Using both of these allows us to send the minimally required information to authenticate a user given their device, and map the device to their session:
// Lombok annotations
public class DeviceInfo {
// Payload Validators
private String deviceId;
private String deviceType;
}
// Lombok annotations
public class LoginForm {
// Payload Validators
private String email;
private String password;
private DeviceInfo deviceInfo;
}
Let's also create the JwtResponse payload that contains the tokens and the expiry duration. This is the response the server sends back to the client, which the client then uses to make requests to secured endpoints:
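// Lombok annotations
public class JwtResponse {
private String accessToken;
private String refreshToken;
private String tokenType = "Bearer";
private Long expiryDuration;
// Three-argument constructor used by the controllers; tokenType keeps its "Bearer" default
public JwtResponse(String accessToken, String refreshToken, Long expiryDuration) {
this.accessToken = accessToken;
this.refreshToken = refreshToken;
this.expiryDuration = expiryDuration;
}
}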
Since we've defined two new entities, UserDevice and RefreshToken, let's define their repositories so we can perform CRUD operations on them.
Persistence Layer - Defining Repositories
public interface UserDeviceRepository extends JpaRepository<UserDevice, Long> {
@Override
Optional<UserDevice> findById(Long id);
Optional<UserDevice> findByRefreshToken(RefreshToken refreshToken);
Optional<UserDevice> findByUserId(Long userId);
}
public interface RefreshTokenRepository extends JpaRepository<RefreshToken, Long> {
@Override
Optional<RefreshToken> findById(Long id);
Optional<RefreshToken> findByToken(String token);
}
Service Layer - Defining Services
Now, we'll want a service layer that sits between the controllers and the repositories. Let's create services to handle CRUD operation requests for the UserDevice and RefreshToken entities:
@Service
public class UserDeviceService {
// Autowire Repositories
public Optional<UserDevice> findByUserId(Long userId) {
return userDeviceRepository.findByUserId(userId);
}
// Other Read Services
public UserDevice createUserDevice(DeviceInfo deviceInfo) {
UserDevice userDevice = new UserDevice();
userDevice.setDeviceId(deviceInfo.getDeviceId());
userDevice.setDeviceType(deviceInfo.getDeviceType());
userDevice.setIsRefreshActive(true);
return userDevice;
}
public void verifyRefreshAvailability(RefreshToken refreshToken) {
UserDevice userDevice = findByRefreshToken(refreshToken)
.orElseThrow(() -> new TokenRefreshException(refreshToken.getToken(), "No device found for the matching token. Please login again"));
if (!userDevice.getIsRefreshActive()) {
throw new TokenRefreshException(refreshToken.getToken(), "Refresh blocked for the device. Please login through a different device");
}
}
}
@Service
public class RefreshTokenService {
// Autowire Repositories
public Optional<RefreshToken> findByToken(String token) {
return refreshTokenRepository.findByToken(token);
}
// other CRUD methods
public RefreshToken createRefreshToken() {
RefreshToken refreshToken = new RefreshToken();
// Refresh tokens stay valid for one hour (3,600,000 ms)
refreshToken.setExpiryDate(Instant.now().plusMillis(3600000));
refreshToken.setToken(UUID.randomUUID().toString());
refreshToken.setRefreshCount(0L);
return refreshToken;
}
public void verifyExpiration(RefreshToken token) {
if (token.getExpiryDate().compareTo(Instant.now()) < 0) {
throw new TokenRefreshException(token.getToken(), "Expired token. Please issue a new request");
}
}
public void increaseCount(RefreshToken refreshToken) {
refreshToken.incrementRefreshCount();
save(refreshToken);
}
}
With these two, we can go ahead and focus on the controllers.
Controllers
With our entities defined, their repositories and services ready, and DTOs ready to transfer data, we can finally create a controller for signing in. During the sign-in process, we'll generate a UserDevice and RefreshToken for the user, and map them to the user's session.
Once we save these to the database, we can return a JwtResponse containing these tokens and expiry information to the user:
@PostMapping("/signin")
public ResponseEntity<?> authenticateUser(@Valid @RequestBody LoginForm loginRequest) {
User user = userRepository.findByEmail(loginRequest.getEmail())
.orElseThrow(() -> new RuntimeException("Fail! -> Cause: User not found."));
if (user.getActive()) {
Authentication authentication = authenticationManager.authenticate(
new UsernamePasswordAuthenticationToken(
loginRequest.getEmail(),
loginRequest.getPassword()
)
);
SecurityContextHolder.getContext().setAuthentication(authentication);
String jwtToken = jwtProvider.generateJwtToken(authentication);
userDeviceService.findByUserId(user.getId())
.map(UserDevice::getRefreshToken)
.map(RefreshToken::getId)
.ifPresent(refreshTokenService::deleteById);
UserDevice userDevice = userDeviceService.createUserDevice(loginRequest.getDeviceInfo());
RefreshToken refreshToken = refreshTokenService.createRefreshToken();
userDevice.setUser(user);
userDevice.setRefreshToken(refreshToken);
refreshToken.setUserDevice(userDevice);
refreshToken = refreshTokenService.save(refreshToken);
return ResponseEntity.ok(new JwtResponse(jwtToken, refreshToken.getToken(), jwtProvider.getExpiryDuration()));
}
return ResponseEntity.badRequest().body(new ApiResponse(false, "User has been deactivated/locked !!"));
}
Here, we've verified that a user with the given email exists, throwing an exception if not. If the user is active, we authenticate them with their credentials. Then, using the JwtProvider (see GitHub, assuming you don't have your own JWT provider already implemented), we generate the JWT token for the user, based on the Spring Security Authentication.
If there's already a RefreshToken associated with the user's session, it's deleted, as we're forming a new session.
Finally, we create a user device via the UserDeviceService and generate a new refresh token for the user, saving both to the database (saving the RefreshToken cascades to its UserDevice), and return a JwtResponse containing the jwtToken, the refreshToken and the expiry duration used to expire the user's session. If the user isn't active, we return a badRequest() instead.
To keep issuing fresh JWT tokens for as long as the user is actually using the application, the client periodically sends a refresh request containing its refresh token:
public class TokenRefreshRequest {
@NotBlank(message = "Refresh token cannot be blank")
private String refreshToken;
// Getters, Setters, Constructor
}
When the request arrives, we verify that the refresh token exists in the database and, if it does, check its expiration and whether refreshing is still allowed for the device. If the session can be refreshed, we issue a new JWT; otherwise, we prompt the user to log in again:
@PostMapping("/refresh")
public ResponseEntity<?> refreshJwtToken(@Valid @RequestBody TokenRefreshRequest tokenRefreshRequest) {
String requestRefreshToken = tokenRefreshRequest.getRefreshToken();
String token = refreshTokenService.findByToken(requestRefreshToken)
.map(refreshToken -> {
// verifyExpiration and verifyRefreshAvailability throw a TokenRefreshException if the session can't be refreshed
refreshTokenService.verifyExpiration(refreshToken);
userDeviceService.verifyRefreshAvailability(refreshToken);
refreshTokenService.increaseCount(refreshToken);
return refreshToken;
})
.map(RefreshToken::getUserDevice)
.map(UserDevice::getUser)
.map(user -> jwtProvider.generateTokenFromUser(user))
.orElseThrow(() -> new TokenRefreshException(requestRefreshToken, "Missing refresh token in database. Please login again"));
return ResponseEntity.ok(new JwtResponse(token, requestRefreshToken, jwtProvider.getExpiryDuration()));
}
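The order of those checks matters: a missing token, an expired token and a device with refresh disabled each end the flow before the refresh count is incremented or a new JWT is issued. The framework-free sketch below models that decision sequence; `RefreshFlow`, its `Outcome` values and `StoredToken` are hypothetical stand-ins for the repositories, services and `TokenRefreshException` above, and the one-hour validity mirrors `createRefreshToken`.

```java
import java.time.Duration;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;

// Hypothetical model of the /refresh checks, in the same order as the controller.
final class RefreshFlow {
    enum Outcome { MISSING, EXPIRED, BLOCKED, REFRESHED }

    // Stand-in for a RefreshToken row joined with its UserDevice.
    static final class StoredToken {
        final Instant expiryDate;
        final boolean deviceRefreshActive;
        long refreshCount;

        StoredToken(Instant expiryDate, boolean deviceRefreshActive) {
            this.expiryDate = expiryDate;
            this.deviceRefreshActive = deviceRefreshActive;
        }
    }

    private final Map<String, StoredToken> tokens = new HashMap<>();

    // Mirrors createRefreshToken: valid for one hour from the time it's issued.
    void issue(String token, Instant issuedAt, boolean deviceRefreshActive) {
        tokens.put(token, new StoredToken(issuedAt.plus(Duration.ofHours(1)), deviceRefreshActive));
    }

    Outcome refresh(String token, Instant now) {
        StoredToken stored = tokens.get(token);
        if (stored == null) {
            return Outcome.MISSING;   // not in the database: log in again
        }
        if (stored.expiryDate.isBefore(now)) {
            return Outcome.EXPIRED;   // verifyExpiration
        }
        if (!stored.deviceRefreshActive) {
            return Outcome.BLOCKED;   // verifyRefreshAvailability
        }
        stored.refreshCount++;        // increaseCount
        return Outcome.REFRESHED;     // a new JWT would be generated here
    }

    long refreshCount(String token) {
        return tokens.get(token).refreshCount;
    }

    public static void main(String[] args) {
        RefreshFlow flow = new RefreshFlow();
        Instant t0 = Instant.parse("2024-01-01T00:00:00Z");
        flow.issue("rt-1", t0, true);
        System.out.println(flow.refresh("rt-1", t0.plus(Duration.ofMinutes(30)))); // REFRESHED
        System.out.println(flow.refresh("rt-1", t0.plus(Duration.ofMinutes(61)))); // EXPIRED
        System.out.println(flow.refresh("unknown", t0));                           // MISSING
    }
}
```

Passing `now` in explicitly, rather than reading the system clock, keeps the expiry behavior easy to test.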
What Happens When We Log Out?
Now let's consider logging out of the system. The simplest thing the client can do is delete the token from the browser's local or session storage, so that it's no longer forwarded to the backend APIs. But is that enough? Although the client no longer sends the token, the token itself is still valid and can be used by anyone who holds a copy to access the APIs. So we need to invalidate the user's session on the backend as well.
Remember how we mapped the user device and refresh token objects to manage the session? We can delete those records from the database, so that the backend no longer finds an active session for the user on that device and the refresh token can no longer be used.
Again, we should ask: is that really enough? Someone who holds the JWT can still use it to authenticate, since we only invalidated the session, and the JWT is validated without consulting it. We need to invalidate the JWT token as well so that it can't be misused. But wait, aren't JWTs stateless and immutable objects?
They are: you can't manually expire a JWT token once it has been created. One way to invalidate it anyway is to keep an in-memory store called a "blacklist", which holds all the tokens that are no longer valid but haven't expired yet.
We can use a datastore with TTL (time-to-live) support, setting each entry's TTL to the time left until the token expires. Once the token expires, its entry is removed from memory; at that point the token is rejected anyway because it has expired, so the blacklist no longer needs to track it.
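Note: Redis or Memcached could serve this purpose as well, but here we're looking for a solution that stores data in the application's own memory, without introducing yet another external datastore. Keep in mind that an in-process store is local to each application instance: if you run several instances, a token blacklisted on one instance would still be accepted by the others, so a shared store is the safer choice there.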
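The core idea can be sketched without any library: the minimal, JDK-only blacklist below stores each token alongside its own expiry time and treats entries as gone once that time has passed. `TokenBlacklist` is hypothetical and hand-rolled with `ConcurrentHashMap` purely for illustration; it doesn't reproduce ExpiringMap's API or internals, and the clock is injected as a `Supplier<Instant>` so that expiry can be tested without waiting.

```java
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

// Hypothetical JDK-only TTL blacklist: each entry lives only until its token expires.
final class TokenBlacklist {
    private final Map<String, Instant> expiryByToken = new ConcurrentHashMap<>();
    private final Supplier<Instant> now;

    TokenBlacklist(Supplier<Instant> now) {
        this.now = now;
    }

    // Blacklists a token until its own expiry; tokens that already expired are skipped.
    void add(String token, Instant tokenExpiry) {
        if (tokenExpiry.isAfter(now.get())) {
            expiryByToken.put(token, tokenExpiry);
        }
    }

    boolean isBlacklisted(String token) {
        Instant expiry = expiryByToken.get(token);
        if (expiry == null) {
            return false;
        }
        if (!expiry.isAfter(now.get())) {
            // The token has expired and is rejected anyway; drop the stale entry lazily.
            expiryByToken.remove(token, expiry);
            return false;
        }
        return true;
    }

    // Evicts all stale entries; could be called periodically, e.g. from a scheduler.
    void purgeExpired() {
        Instant current = now.get();
        expiryByToken.values().removeIf(expiry -> !expiry.isAfter(current));
    }

    int size() {
        return expiryByToken.size();
    }

    public static void main(String[] args) {
        AtomicReference<Instant> clock = new AtomicReference<>(Instant.parse("2024-01-01T00:00:00Z"));
        TokenBlacklist blacklist = new TokenBlacklist(clock::get);
        blacklist.add("jwt-1", clock.get().plus(Duration.ofMinutes(15)));
        System.out.println(blacklist.isBlacklisted("jwt-1")); // true
        clock.set(clock.get().plus(Duration.ofMinutes(16)));
        System.out.println(blacklist.isBlacklisted("jwt-1")); // false
    }
}
```

Stale entries are dropped lazily on lookup or in bulk via `purgeExpired()`, so as long as the purge runs periodically, memory stays proportional to the number of logged-out tokens that haven't expired yet.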
This is exactly why we added the ExpiringMap dependency earlier: it expires entries automatically, so the server can cache each token with its own TTL:
On each request to a secured endpoint, the JWTAuthenticationFilter can additionally check whether the token is present in the blacklisted/cached map. This way, we can effectively invalidate an immutable JWT token that is going to expire soon, but hasn't yet:
Blacklisting JWT Tokens Before They Expire
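Let's implement the logic to cache each non-expired token on a logout request in an ExpiringMap, where the TTL for each token is the number of seconds remaining until it expires. To prevent the cache from growing indefinitely, we'll also set a maximum size. Be aware that once this size is reached, ExpiringMap expires the entry next in line for expiration to make room, so a logged-out token could drop out of the blacklist early; size the cache for your expected logout volume: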
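import java.time.Instant;
import java.util.Date;
import java.util.concurrent.TimeUnit;
import net.jodah.expiringmap.ExpiringMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
@Component
public class LoggedOutJwtTokenCache {
private static final Logger logger = LoggerFactory.getLogger(LoggedOutJwtTokenCache.class);
private ExpiringMap<String, OnUserLogoutSuccessEvent> tokenEventMap;
private JwtProvider tokenProvider;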
@Autowired
public LoggedOutJwtTokenCache(JwtProvider tokenProvider) {
this.tokenProvider = tokenProvider;
this.tokenEventMap = ExpiringMap.builder()
.variableExpiration()
.maxSize(1000)
.build();
}
public void markLogoutEventForToken(OnUserLogoutSuccessEvent event) {
String token = event.getToken();
if (tokenEventMap.containsKey(token)) {
logger.info(String.format("Log out token for user [%s] is already present in the cache", event.getUserEmail()));
} else {
Date tokenExpiryDate = tokenProvider.getTokenExpiryFromJWT(token);
long ttlForToken = getTTLForToken(tokenExpiryDate);
logger.info(String.format("Logout token cache set for [%s] with a TTL of [%s] seconds. Token is due expiry at [%s]", event.getUserEmail(), ttlForToken, tokenExpiryDate));
tokenEventMap.put(token, event, ttlForToken, TimeUnit.SECONDS);
}
}
public OnUserLogoutSuccessEvent getLogoutEventForToken(String token) {
return tokenEventMap.get(token);
}
private long getTTLForToken(Date date) {
long secondAtExpiry = date.toInstant().getEpochSecond();
long secondAtLogout = Instant.now().getEpochSecond();
return Math.max(0, secondAtExpiry - secondAtLogout);
}
}
We also need to define a Data Transfer Object for the client to send when they'd like to log out:
// Lombok annotations
public class LogOutRequest {
private DeviceInfo deviceInfo;
private String token;
}
We'll also need an event listener that listens for logout events, so that the token is marked for caching in the blacklist as soon as the user logs out. Let's define the OnUserLogoutSuccessEvent event and the OnUserLogoutSuccessEventListener event listener:
// Lombok annotations
public class OnUserLogoutSuccessEvent extends ApplicationEvent {
private static final long serialVersionUID = 1L;
private final String userEmail;
private final String token;
private final transient LogOutRequest logOutRequest;
private final Date eventTime;
// All Arguments Constructor with modifications
}
@Component
public class OnUserLogoutSuccessEventListener implements ApplicationListener<OnUserLogoutSuccessEvent> {
private static final Logger logger = LoggerFactory.getLogger(OnUserLogoutSuccessEventListener.class);
private final LoggedOutJwtTokenCache tokenCache;
@Autowired
public OnUserLogoutSuccessEventListener(LoggedOutJwtTokenCache tokenCache) {
this.tokenCache = tokenCache;
}
@Override
public void onApplicationEvent(OnUserLogoutSuccessEvent event) {
if (null != event) {
DeviceInfo deviceInfo = event.getLogOutRequest().getDeviceInfo();
logger.info(String.format("Log out success event received for user [%s] for device [%s]", event.getUserEmail(), deviceInfo));
// Blacklist the token until its natural expiry
tokenCache.markLogoutEventForToken(event);
}
}
}
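The logout endpoint itself isn't shown here. A logout handler would presumably delete the device's session (its UserDevice and RefreshToken records) and then publish an OnUserLogoutSuccessEvent, which the listener above turns into a blacklist entry. The sketch below models that ordering in plain Java; `LogoutFlow` and `LogoutEvent` are hypothetical, and in a Spring controller the publishing step would typically go through `ApplicationEventPublisher.publishEvent(...)`.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;

// Hypothetical, framework-free model of the logout flow. In Spring, publishing is done
// by an ApplicationEventPublisher and listening by an ApplicationListener.
final class LogoutFlow {
    // Minimal stand-in for OnUserLogoutSuccessEvent.
    static final class LogoutEvent {
        final String userEmail;
        final String deviceId;
        final String token;

        LogoutEvent(String userEmail, String deviceId, String token) {
            this.userEmail = userEmail;
            this.deviceId = deviceId;
            this.token = token;
        }
    }

    private final List<Consumer<LogoutEvent>> listeners = new ArrayList<>();
    // Active sessions keyed by "email|deviceId"; stands in for the UserDevice/RefreshToken rows.
    private final Set<String> activeSessions = new HashSet<>();

    void subscribe(Consumer<LogoutEvent> listener) {
        listeners.add(listener);
    }

    void login(String email, String deviceId) {
        activeSessions.add(email + "|" + deviceId);
    }

    boolean hasSession(String email, String deviceId) {
        return activeSessions.contains(email + "|" + deviceId);
    }

    // First delete the device's session, then notify listeners so they can blacklist the JWT.
    void logout(String email, String deviceId, String token) {
        activeSessions.remove(email + "|" + deviceId);
        LogoutEvent event = new LogoutEvent(email, deviceId, token);
        listeners.forEach(listener -> listener.accept(event));
    }

    public static void main(String[] args) {
        Set<String> blacklist = new HashSet<>();
        LogoutFlow flow = new LogoutFlow();
        flow.subscribe(event -> blacklist.add(event.token)); // plays the token cache's role
        flow.login("adam@example.com", "desktop");
        flow.logout("adam@example.com", "desktop", "jwt-abc");
        System.out.println(flow.hasSession("adam@example.com", "desktop")); // false
        System.out.println(blacklist.contains("jwt-abc"));                  // true
    }
}
```

Decoupling the two steps through an event means the logout handler doesn't need to know how, or where, tokens are blacklisted.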
Finally, in the JwtProvider, we'll extend JWT token validation with an extra check to see whether the incoming token is present in the blacklist:
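public boolean validateJwtToken(String authToken) {
try {
// Demo key only: load the signing key from configuration in a real application.
// jjwt 0.9.x treats a String signing key as Base64-encoded bytes.
Jwts.parser().setSigningKey("HelloWorld").parseClaimsJws(authToken);
// Throws InvalidTokenRequestException (not caught here) if the token was blacklisted on logout
validateTokenIsNotForALoggedOutDevice(authToken);
return true;
} catch (SignatureException e) {
// io.jsonwebtoken.SignatureException in jjwt 0.9.x: tampered token or signed with another key
logger.error("Invalid JWT signature -> Message: {}", e.getMessage());
} catch (MalformedJwtException e) {
logger.error("Invalid JWT token -> Message: {}", e.getMessage());
} catch (ExpiredJwtException e) {
logger.error("Expired JWT token -> Message: {}", e.getMessage());
} catch (UnsupportedJwtException e) {
logger.error("Unsupported JWT token -> Message: {}", e.getMessage());
} catch (IllegalArgumentException e) {
logger.error("JWT claims string is empty -> Message: {}", e.getMessage());
}
return false;
}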
private void validateTokenIsNotForALoggedOutDevice(String authToken) {
OnUserLogoutSuccessEvent previouslyLoggedOutEvent = loggedOutJwtTokenCache.getLogoutEventForToken(authToken);
if (previouslyLoggedOutEvent != null) {
String userEmail = previouslyLoggedOutEvent.getUserEmail();
Date logoutEventDate = previouslyLoggedOutEvent.getEventTime();
String errorMessage = String.format("Token corresponds to an already logged out user [%s] at [%s]. Please login again", userEmail, logoutEventDate);
throw new InvalidTokenRequestException("JWT", authToken, errorMessage);
}
}
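For context, a JWT authentication filter typically pulls the token out of the `Authorization: Bearer <token>` header before calling a method like `validateJwtToken`. The sketch below models that per-request decision without Spring or jjwt: `BearerTokenCheck` is hypothetical, the signature and expiry check is passed in as a predicate, and a plain `Set` stands in for the logged-out token cache.

```java
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;

// Hypothetical model of the per-request decision a JWT authentication filter makes.
final class BearerTokenCheck {
    // Simple case-sensitive prefix match; HTTP auth scheme names are formally case-insensitive.
    private static final String PREFIX = "Bearer ";

    // Extracts the token from an "Authorization: Bearer <token>" header value.
    static Optional<String> extractToken(String authorizationHeader) {
        if (authorizationHeader == null || !authorizationHeader.startsWith(PREFIX)) {
            return Optional.empty();
        }
        String token = authorizationHeader.substring(PREFIX.length()).trim();
        return token.isEmpty() ? Optional.empty() : Optional.of(token);
    }

    // Allowed only if a token is present, passes the signature/expiry check,
    // and is not on the logout blacklist.
    static boolean isAllowed(String authorizationHeader,
                             Predicate<String> signatureAndExpiryValid,
                             Set<String> blacklist) {
        return extractToken(authorizationHeader)
                .filter(signatureAndExpiryValid)
                .filter(token -> !blacklist.contains(token))
                .isPresent();
    }

    public static void main(String[] args) {
        Predicate<String> validSignature = token -> true; // stand-in for jjwt parsing
        System.out.println(isAllowed("Bearer abc.def.ghi", validSignature, Set.of()));              // true
        System.out.println(isAllowed("Bearer abc.def.ghi", validSignature, Set.of("abc.def.ghi"))); // false
    }
}
```

A blacklisted token is rejected even though its signature and expiry are still fine, which is exactly the gap the cache closes.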
Running In-Memory Invalidation of JWT Tokens
With the implementation done, let's walk through a user's session cycle and see what happens when we log in and then log out: we'll sign up, log in, refresh our tokens and then log out of the system. Finally, we'll try to access a secured endpoint using a previously generated JWT token and see what happens.
Henceforth, we'll be using Postman to test the functionality of our API. If you're unfamiliar with Postman - read our guide on Getting Started with Postman.
Let's first sign up a new user, Adam Smith, as the administrator on our application:
It's critical that the JWT gets invalidated after the administrator logs out, as a malicious user could gain destructive authority over the application if they snatch the JWT before expiry.
Naturally, Adam will want to log into the application:
The server responds with an accessToken (JWT), a refreshToken and the expiryDuration. Since Adam has a lot of work to do on the app, he might want to refresh the JWT token assigned to him at some point to extend his access while he's still online.
This is done by sending the refreshToken in the body of the refresh request, passing the access token from above as a Bearer token in the Authorization header:
Finally, Adam logs out of the application, passing the device info and access token to do so:
Once logged out, let's try to hit the /users/me endpoint with the previously used JWT token, which hasn't expired yet, to see whether we can still access it:
The API responds with a 401 Unauthorized error, since the JWT token is now in the cached blacklist.
Conclusion
As you can see, the logout flow using JSON Web Tokens isn't so straightforward. We need to follow a few best practices to handle these scenarios:
- Define a reasonable expiration time for tokens. It's often advised to keep the expiry time as short as practical, so as not to overfill the blacklist with tokens.
- Delete the token that's stored within the browser local or session storage.
- Use an in-memory or high performance TTL-based store to cache the token which is yet to expire.
- Check incoming tokens against the blacklist on every authorized request.
As mentioned at the beginning of the guide, you can find the full source code on GitHub: https://github.com/arpendu11/spring-security-jwt-jpa