Storing tokens in Redis when using spring-authorization-server

2k views Asked by At

Is it possible to store tokens (for example, in Redis) when using spring-authorization-server? In spring-security-oauth we could define a `TokenStore`:

@Bean
public TokenStore redisTokenStore() {
    // Legacy spring-security-oauth TokenStore backed by the injected Redis connection factory;
    // setPrefix namespaces the Redis keys this store uses.
    final RedisTokenStore tokenStore = new RedisTokenStore(redisConnectionFactory);
    tokenStore.setPrefix(redisTokenPrefix);
    return tokenStore;
}

In theory, I could implement the `OAuth2AuthorizationService` interface myself, but maybe there is an easier and more elegant solution?

2

There are 2 answers

0
Steve Riesenberg On

Implementing OAuth2AuthorizationService is the correct way to do this. There is no built-in support for an integration with Redis. See #558 for more info.

0
shashank joshi On

I did the following to store the token information in a Redis cluster (using Spring Authorization Server 1.2.3):

MyTokenAuthService: This class implements OAuth2AuthorizationService and calls the cache methods to store the OAuth2Authorization object.

JedisRefreshAccessToken: This cache serializes the OAuth2Authorization into a byte array with Java's ObjectOutputStream for storage in Redis. It deserializes the stored bytes back with ObjectInputStream when reading from the Redis cache.

MyJedis: This class uses the JedisCluster class to write byte arrays to, and read them from, the Redis cache.

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.stereotype.Component;

import in.org.cris.superapp.authserver.cache.CRISJedisRefreshAccessToken;




@Component
/**
 * {@link OAuth2AuthorizationService} that delegates all persistence to the Redis-backed
 * {@link JedisRefreshAccessToken} cache.
 *
 * <p>Token values are bearer credentials, so they are never written to the log; only
 * authorization ids and token types are.
 */
public class MyTokenAuthService implements OAuth2AuthorizationService {

    private static final Logger logger = LoggerFactory.getLogger(MyTokenAuthService.class);

    /** Redis-backed store that performs the serialization and token indexing. */
    @Autowired JedisRefreshAccessToken oauth2TokenCache;

    /**
     * Stores {@code authorization} by delegating to {@link JedisRefreshAccessToken#saveToken}.
     *
     * @param authorization the authorization to persist
     */
    @Override
    public void save(OAuth2Authorization authorization) {
        logger.info("save id={}", authorization.getId());
        oauth2TokenCache.saveToken(authorization);
    }

    /**
     * Removes the stored authorization with the same id as {@code authorization}.
     *
     * @param authorization the authorization to remove; only its id is used
     */
    @Override
    public void remove(OAuth2Authorization authorization) {
        logger.info("remove id={}", authorization.getId());
        oauth2TokenCache.removeToken(authorization.getId());
    }

    /**
     * Looks up an authorization by its id.
     *
     * @param id the authorization id
     * @return the authorization, or {@code null} if none is cached under that id
     */
    @Override
    public OAuth2Authorization findById(String id) {
        logger.info("findById id={}", id);
        return oauth2TokenCache.findByIdFromCache(id);
    }

    /**
     * Looks up an authorization by one of its token values.
     *
     * @param token the token value (never logged: it is a credential)
     * @param tokenType the token type, or {@code null} when the caller does not know it; the
     *     {@code OAuth2AuthorizationService} contract permits {@code null} here
     * @return the matching authorization, or {@code null} if none is found
     */
    @Override
    public OAuth2Authorization findByToken(String token, OAuth2TokenType tokenType) {
        // tokenType may legitimately be null; dereferencing it unconditionally threw an NPE.
        logger.info("findByToken type={}", tokenType != null ? tokenType.getValue() : null);
        return oauth2TokenCache.findByToken(token, tokenType);
    }
}

import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.ObjectInputStream;
    import java.io.ObjectOutputStream;
    
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.security.oauth2.core.OAuth2AccessToken;
    import org.springframework.security.oauth2.core.OAuth2DeviceCode;
    import org.springframework.security.oauth2.core.OAuth2RefreshToken;
    import org.springframework.security.oauth2.core.OAuth2UserCode;
    import org.springframework.security.oauth2.core.oidc.OidcIdToken;
    import org.springframework.security.oauth2.jwt.Jwt;
    import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
    import org.springframework.security.oauth2.server.authorization.OAuth2Authorization.Token;
    import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode;
    import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
    import org.springframework.stereotype.Component;
    
@Component
/**
 * Redis cache for {@link OAuth2Authorization} objects.
 *
 * <p>Key layout (all keys share {@link #keyPrefix}):
 * <ul>
 *   <li>{@code Oauth2Token_<id>}: the Java-serialized authorization once it has an access token;</li>
 *   <li>{@code Oauth2Token_init_<id>}: the serialized authorization while it has no access token yet;</li>
 *   <li>{@code Oauth2Token_<value>}: the authorization id, indexed by every token value and by the
 *       {@code state} attribute.</li>
 * </ul>
 * TTLs are fixed constants. They do not follow the lifetimes configured in the server's token
 * settings.
 */
public class JedisRefreshAccessToken {

    private static final Logger logger = LoggerFactory.getLogger(JedisRefreshAccessToken.class);

    /** TTL in seconds of a stored authorization and of its refresh-token index (50 hours). */
    private static final long AUTHORIZATION_TTL_SECONDS = 100 * 30 * 60;
    /** TTL in seconds of the access-token indexes (30 minutes). */
    private static final long ACCESS_TOKEN_TTL_SECONDS = 30 * 60;
    /** TTL in seconds of the state index, which covers the consent step (30 minutes). */
    private static final long STATE_TTL_SECONDS = 30 * 60;
    /** TTL in seconds of the authorization-code, ID-token, user-code and device-code indexes (5 minutes). */
    private static final long SHORT_LIVED_TTL_SECONDS = 300;

    // Values of OAuth2TokenType#getValue() used by Spring Authorization Server. They are literals so
    // they can serve as switch labels. They mirror OAuth2ParameterNames.STATE, CODE, ACCESS_TOKEN,
    // REFRESH_TOKEN, DEVICE_CODE, USER_CODE and OidcParameterNames.ID_TOKEN.
    private static final String STATE = "state";
    private static final String CODE = "code";
    private static final String ACCESS_TOKEN = "access_token";
    private static final String REFRESH_TOKEN = "refresh_token";
    private static final String ID_TOKEN = "id_token";
    private static final String DEVICE_CODE = "device_code";
    private static final String USER_CODE = "user_code";

    @Autowired MyJedis cache; // It has JedisCluster like methods (get, set with expire seconds). 

    // NOTE(review): authorization records and token indexes share this prefix. It is kept
    // unchanged for compatibility with data already in Redis. findByToken re-checks every hit
    // against the loaded authorization.
    static final String keyPrefix = "Oauth2Token_";

    private static String cacheKey(String id) {
        return keyPrefix + id;
    }
    private static String cacheKey_init(String id) {
        return keyPrefix + "init_" + id;
    }
    private static String cacheTokenKey(String token) {
        return keyPrefix + token;
    }

    /** Receives one secondary-index entry: a token value and the TTL its index key is written with. */
    private interface IndexEntryVisitor {
        void visit(String tokenValue, long ttlSeconds);
    }

    /**
     * Deletes the authorization with the given id together with every token index that points at it.
     *
     * @param id the authorization id
     */
    public void removeToken(final String id) {
        OAuth2Authorization auth = findByIdFromCache(id);
        if (auth == null) {
            logger.info("No token found to remove. Id {}", id);
            return;
        }
        // Null-safe: an in-progress authorization has no access token, and client_credentials
        // has no refresh token. Previously both cases threw a NullPointerException here.
        forEachIndexedValue(auth, (value, ttl) -> cache.unlink(cacheTokenKey(value)));
        cache.unlink(cacheKey(auth.getId()).getBytes());
        cache.unlink(cacheKey_init(auth.getId()).getBytes());
    }

    /**
     * Loads an authorization by id. The completed record is preferred over the in-progress one.
     *
     * @param id the authorization id
     * @return the deserialized authorization, or {@code null} if neither key holds a value
     */
    public OAuth2Authorization findByIdFromCache(final String id) {
        byte[] v = cache.get(cacheKey(id).getBytes());
        if (v == null || v.length == 0) {
            v = cache.get(cacheKey_init(id).getBytes());
        }
        if (v == null || v.length == 0) {
            return null;
        }
        return readFromBytes(v);
    }

    /**
     * Serializes and stores {@code auth}, then (re)writes the index entry for each of its token values.
     * An authorization that has an access token is stored under the "complete" key and its
     * now-superseded in-progress copy is deleted. Otherwise it is stored under the "init" key.
     *
     * @param auth the authorization to store
     */
    public void saveToken(final OAuth2Authorization auth) {
        final boolean isComplete = auth.getAccessToken() != null;
        final String key = isComplete ? cacheKey(auth.getId()) : cacheKey_init(auth.getId());
        byte[] v = getBytes(auth);
        logger.info("saveToken Id {}", auth.getId());
        cache.set(key.getBytes(), v, AUTHORIZATION_TTL_SECONDS);
        if (isComplete) {
            // findByIdFromCache already prefers the complete key, so the init copy is dead weight.
            cache.unlink(cacheKey_init(auth.getId()).getBytes());
        }
        storeTokensByValue(auth);
    }

    /** Writes an index entry "token value -> authorization id" for every value the authorization holds. */
    private void storeTokensByValue(final OAuth2Authorization auth) {
        final String id = auth.getId();
        forEachIndexedValue(auth, (value, ttl) -> cache.set(cacheTokenKey(value), id, ttl));
    }

    /**
     * Finds the authorization holding {@code token}.
     *
     * <p>Index keys can outlive the value they were written for. Examples are a refresh token
     * rotated out by a refresh grant and a {@code state} consumed by the consent step. The
     * authorization that such a stale key points to now holds a different, active token. Returning
     * it would let the old credential keep working. The loaded authorization is therefore only
     * returned if it still holds {@code token} under {@code tokenType}.
     *
     * @param token the token value
     * @param tokenType the token type, or {@code null} to match any type
     * @return the matching authorization, or {@code null}
     */
    public OAuth2Authorization findByToken(String token, OAuth2TokenType tokenType) {
        if (token == null || token.isEmpty()) {
            return null;
        }
        String id = cache.get(cacheTokenKey(token));
        if (id == null || id.isEmpty()) {
            // Never log the token value itself: it is a bearer credential.
            logger.info("findByToken return null: token not indexed");
            return null;
        }
        OAuth2Authorization auth = findByIdFromCache(id);
        if (auth == null || !hasToken(auth, token, tokenType)) {
            return null;
        }
        return auth;
    }

    /** Calls {@code visitor} for every non-null indexable value of {@code auth}, with the TTL of its index. */
    private static void forEachIndexedValue(final OAuth2Authorization auth, final IndexEntryVisitor visitor) {
        final String state = auth.getAttribute(STATE);
        visitIfPresent(visitor, state, STATE_TTL_SECONDS);
        visitIfPresent(visitor, tokenValue(auth.getToken(Jwt.class)), ACCESS_TOKEN_TTL_SECONDS);
        visitIfPresent(visitor, tokenValue(auth.getToken(OAuth2AuthorizationCode.class)), SHORT_LIVED_TTL_SECONDS);
        visitIfPresent(visitor, tokenValue(auth.getToken(OAuth2AccessToken.class)), ACCESS_TOKEN_TTL_SECONDS);
        visitIfPresent(visitor, tokenValue(auth.getToken(OidcIdToken.class)), SHORT_LIVED_TTL_SECONDS);
        visitIfPresent(visitor, tokenValue(auth.getToken(OAuth2UserCode.class)), SHORT_LIVED_TTL_SECONDS);
        visitIfPresent(visitor, tokenValue(auth.getToken(OAuth2DeviceCode.class)), SHORT_LIVED_TTL_SECONDS);
        visitIfPresent(visitor, tokenValue(auth.getRefreshToken()), AUTHORIZATION_TTL_SECONDS);
    }

    private static void visitIfPresent(final IndexEntryVisitor visitor, final String value, final long ttlSeconds) {
        if (value != null) {
            visitor.visit(value, ttlSeconds);
        }
    }

    /** Returns the raw value of a token holder, or {@code null} if the holder or its token is absent. */
    private static String tokenValue(final Token<?> holder) {
        return holder != null && holder.getToken() != null ? holder.getToken().getTokenValue() : null;
    }

    /** True if {@code auth} currently holds {@code token} as a value of {@code tokenType} (any type if null). */
    private static boolean hasToken(final OAuth2Authorization auth, final String token, final OAuth2TokenType tokenType) {
        if (tokenType == null) {
            return matchesState(auth, token)
                    || matches(auth.getToken(OAuth2AuthorizationCode.class), token)
                    || matchesAccessToken(auth, token)
                    || matches(auth.getToken(OidcIdToken.class), token)
                    || matches(auth.getRefreshToken(), token)
                    || matches(auth.getToken(OAuth2DeviceCode.class), token)
                    || matches(auth.getToken(OAuth2UserCode.class), token);
        }
        switch (tokenType.getValue()) {
            case STATE:
                return matchesState(auth, token);
            case CODE:
                return matches(auth.getToken(OAuth2AuthorizationCode.class), token);
            case ACCESS_TOKEN:
                return matchesAccessToken(auth, token);
            case ID_TOKEN:
                return matches(auth.getToken(OidcIdToken.class), token);
            case REFRESH_TOKEN:
                return matches(auth.getRefreshToken(), token);
            case DEVICE_CODE:
                return matches(auth.getToken(OAuth2DeviceCode.class), token);
            case USER_CODE:
                return matches(auth.getToken(OAuth2UserCode.class), token);
            default:
                return false;
        }
    }

    private static boolean matchesState(final OAuth2Authorization auth, final String token) {
        return token.equals(auth.getAttribute(STATE));
    }

    private static boolean matchesAccessToken(final OAuth2Authorization auth, final String token) {
        return matches(auth.getToken(OAuth2AccessToken.class), token) || matches(auth.getToken(Jwt.class), token);
    }

    private static boolean matches(final Token<?> holder, final String token) {
        return token.equals(tokenValue(holder));
    }

    /**
     * Deserializes an authorization written by {@link #getBytes}.
     *
     * @throws IllegalStateException if the bytes cannot be deserialized
     */
    private OAuth2Authorization readFromBytes(final byte[] bytes) {
        // SECURITY NOTE(review): this is native Java deserialization of bytes read from Redis.
        // Anyone who can write these keys can feed arbitrary object graphs to ObjectInputStream.
        // Consider an ObjectInputFilter or a JSON codec instead.
        try (ByteArrayInputStream bin = new ByteArrayInputStream(bytes);
             ObjectInputStream objIn = new ObjectInputStream(bin)) {
            return (OAuth2Authorization) objIn.readObject();
        } catch (Exception e) {
            logger.error("Failed to deserialize OAuth2Authorization", e);
            throw new IllegalStateException("Failed to deserialize OAuth2Authorization", e);
        }
    }

    /**
     * Serializes an authorization with Java serialization.
     *
     * @throws IllegalStateException if serialization fails
     */
    private byte[] getBytes(final OAuth2Authorization auth) {
        final ByteArrayOutputStream bout = new ByteArrayOutputStream();
        // Close the ObjectOutputStream before reading the buffer so all block data is flushed.
        try (ObjectOutputStream objOut = new ObjectOutputStream(bout)) {
            objOut.writeObject(auth);
        } catch (Exception e) {
            logger.error("Failed to serialize OAuth2Authorization {}", auth.getId(), e);
            throw new IllegalStateException("Failed to serialize OAuth2Authorization", e);
        }
        return bout.toByteArray();
    }

}

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;

import redis.clients.jedis.HostAndPort;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisCluster;
import redis.clients.jedis.Tuple;
import redis.clients.jedis.params.SetParams;



 @Component
/**
 * Thin wrapper around a process-wide {@link JedisCluster}.
 *
 * <p>The cluster client is created asynchronously by {@link #initializeRedisCache()}. The data
 * methods throw {@link IllegalStateException} until that has succeeded.
 */
public class MyJedis {
    private static final Logger logger = LoggerFactory.getLogger(MyJedis.class);

    /**
     * Shared cluster client. It is {@code volatile} because the background initializer thread
     * writes it and request threads read it.
     */
    private static volatile JedisCluster jedisCluster;

    /** True while a background initialization attempt is running; guarded by {@code MyJedis.class}. */
    private static boolean initInProgress;

    /**
     * Tries each seed node in turn and publishes the first client whose probe call succeeds.
     * A client whose probe fails is closed instead of being published and leaked.
     */
    private static void initCluster() { 
        List<String> hostPorts = Arrays.asList("ip1:port1,ip2:port2".split(","));
        for (final String hp : hostPorts) {
            JedisCluster candidate = null;
            try {
                String[] hpArr = hp.trim().split(":");
                final String host = hpArr[0];
                final int port = Integer.parseInt(hpArr[1]);
                candidate = new JedisCluster(new HostAndPort(host, port), 
                        120, 120, 3, getRedisPassword(), null, getGenericPoolConfig(), false);
                // Probe the connection before making the client visible to other threads.
                candidate.exists("some-key");
                jedisCluster = candidate;
                logger.info("jedis cluster inited.");
                return;
            }
            catch (Exception e) {
                logger.warn("Could not initialize Redis cluster via seed node {}", hp, e);
                closeQuietly(candidate);
            }
        }
        logger.error("jedis cluster initialization failed: no seed node was reachable.");
    }

    /** Best-effort close of a client that will not be used; failures are only logged. */
    private static void closeQuietly(final JedisCluster cluster) {
        if (cluster == null) {
            return;
        }
        try {
            cluster.close();
        }
        catch (Exception e) {
            logger.debug("Ignoring failure while closing an unused JedisCluster", e);
        }
    }

    /**
     * Starts a background attempt to create the cluster client. The call is ignored if a client
     * already exists or an attempt is still running.
     *
     * <p>Each call uses a fresh thread. Reusing a single Thread instance made a second call
     * (for example after a failed attempt) throw {@link IllegalThreadStateException}.
     */
    static void initializeRedisCache() {
        synchronized (MyJedis.class) {
            if (jedisCluster != null || initInProgress) {
                return;
            }
            initInProgress = true;
        }
        final Thread initerThread = new Thread(() -> {
            try {
                if (jedisCluster == null) {
                    initCluster();
                }
            }
            finally {
                synchronized (MyJedis.class) {
                    initInProgress = false;
                }
            }
        }, "MyJedis-init");
        initerThread.start();
    }

    /** Returns the initialized client, or fails with a descriptive error instead of a bare NPE. */
    private static JedisCluster cluster() {
        final JedisCluster current = jedisCluster;
        if (current == null) {
            throw new IllegalStateException("Redis cluster is not initialized; call initializeRedisCache() first");
        }
        return current;
    }

    private static GenericObjectPoolConfig<Jedis> getGenericPoolConfig(){
        GenericObjectPoolConfig<Jedis> genericObjectPoolConfig = 
                new GenericObjectPoolConfig<Jedis>();
        // Validate pooled connections before handing them out.
        genericObjectPoolConfig.setTestOnBorrow(true);
        return genericObjectPoolConfig;
    }

    long unlink(String key) {
        return cluster().unlink(key);
    }
    long unlink(byte[] key) {
        return cluster().unlink(key);
    }
    String get(String key) {
        return cluster().get(key);
    }
    byte[] get(byte[] key) {
        return cluster().get(key);
    }
    boolean exists(String key) {
        return cluster().exists(key);
    }
    long ttl(String key) {
        return cluster().ttl(key);
    }
    String set(String key, String val) {
        return cluster().set(key, val);
    }
    /** Sets {@code key} to {@code val} with an expiry of {@code expireSeconds} (Redis {@code SET ... EX}). */
    String set(String key, String val, long expireSeconds) {
        return cluster().set(key, val, new SetParams().ex(expireSeconds));
    }
    /** Binary variant of {@link #set(String, String, long)}. */
    String set(byte[] key, byte[] val, long expireSeconds) {
        return cluster().set(key, val, new SetParams().ex(expireSeconds));
    }
    
}