NestJS Fastify JWKS Validation

2.7k views · Asked by (asker name and date are missing from this copy)

I am using the Fastify adapter in my NestJS application and would like to add JWKS validation, similar to the Passport example on the Auth0 website:

// src/authz/jwt.strategy.ts

import { Injectable } from '@nestjs/common';
import { PassportStrategy } from '@nestjs/passport';
import { ExtractJwt, Strategy } from 'passport-jwt';
import { passportJwtSecret } from 'jwks-rsa';
import * as dotenv from 'dotenv';

dotenv.config();

/**
 * Passport JWT strategy from the Auth0 sample.
 *
 * Tokens are read from the `Authorization: Bearer …` header. They are verified with signing keys
 * fetched from the issuer's `.well-known/jwks.json` document. The `audience` and `issuer` come
 * from the AUTH0_AUDIENCE / AUTH0_ISSUER_URL environment variables, and only RS256 is accepted.
 */
@Injectable()
export class JwtStrategy extends PassportStrategy(Strategy) {
  constructor() {
    // NOTE(review): the JWKS URI is built by plain concatenation, so AUTH0_ISSUER_URL is
    // presumably expected to end with '/' — confirm against your env configuration.
    const issuerUrl = process.env.AUTH0_ISSUER_URL;

    super({
      // Signing keys are fetched from the JWKS endpoint with caching and a
      // 5-requests-per-minute rate limit.
      secretOrKeyProvider: passportJwtSecret({
        jwksUri: `${issuerUrl}.well-known/jwks.json`,
        jwksRequestsPerMinute: 5,
        rateLimit: true,
        cache: true,
      }),

      jwtFromRequest: ExtractJwt.fromAuthHeaderAsBearerToken(),
      audience: process.env.AUTH0_AUDIENCE,
      issuer: `${issuerUrl}`,
      algorithms: ['RS256'],
    });
  }

  /**
   * Invoked by Passport once the token has been verified; the decoded payload is passed through
   * unchanged as the authenticated user.
   */
  validate(payload: unknown): unknown {
    const authenticatedUser = payload;
    return authenticatedUser;
  }
}

It is my understanding that Passport only works with Express and will not work with Fastify. Does anyone know how to do something like this with Fastify and NestJS?

1

There is 1 answer

0
mh377 · BEST ANSWER

I couldn't find a library like Passport to do JWKS validation with Fastify, so I wrote my own validation using the `jsonwebtoken` library (with `@types/jsonwebtoken` for the typings).

Below is a sample of my solution for anybody else who is interested :)

File structure is as follows:

src 
 |__ auth
       |__jwks
            |_ jwks.client.ts
            |_ jwks.service.ts
            |_ jwt-auth.guard.ts
            |_ jwt-auth.module.ts
 |__ caching
           |_ redis-cache.module.ts
 |__ models
       |__ json-web-key.model.ts
       |__ jwks-response.model.ts
 |__ my.controller.ts
 |__ app.module.ts

Models for the JWKS response:

// src/models/jwks-response.model.ts

import { JsonWebKey } from "src/models/json-web-key.model";

/**
 * Body of a JWKS (JSON Web Key Set) endpoint response.
 *
 * The instance is populated from the HTTP response body, not by a constructor, hence the
 * definite-assignment assertion (`!`). Without it the field fails to compile under
 * `strictPropertyInitialization` (TS2564).
 */
export class JwksResponse {
    /** The published keys; may be empty. */
    keys!: JsonWebKey[];
}

// src/models/json-web-key.model.ts

/**
 * One JSON Web Key (RFC 7517) from the `keys` array of a JWKS response.
 *
 * Instances are populated from the HTTP response body, not by a constructor, so required fields
 * carry a definite-assignment assertion (`!`). Without it they fail to compile under
 * `strictPropertyInitialization` (TS2564).
 *
 * NOTE(review): RFC 7517 marks `kid`, `use`, `x5t` and `x5c` as OPTIONAL even though they are
 * typed as required here. Code reading them should still guard against their absence.
 */
export class JsonWebKey {
    /** Key type (e.g. "RSA" or "EC"). */
    kty!: string;
    /** Key ID; compared with the `kid` in a JWT header to pick the verification key. */
    kid!: string;
    /** Intended public-key use (e.g. "sig"). */
    use!: string;
    /** X.509 certificate SHA-1 thumbprint. */
    x5t!: string;
    /** X.509 certificate chain (base64 DER); the first entry holds this key. */
    x5c!: string[];
    /** RSA modulus. */
    n?: string;
    /** RSA public exponent. */
    e?: string;
    /** EC x coordinate. */
    x?: string;
    /** EC y coordinate. */
    y?: string;
    /** EC curve name. */
    crv?: string;
}

Client that calls the JWKS endpoint and processes the response:

//src/auth/jwks/jwks.client.ts

import { HttpException, Injectable, Logger } from "@nestjs/common";
import { ConfigService} from "@nestjs/config";
import { HttpService } from "@nestjs/axios";
import { map, lastValueFrom } from "rxjs";
import { JwksResponse } from "src/models/jwks-response.model";
import { JsonWebKey } from "src/models/json-web-key.model";

/**
 * HTTP client for the JWKS (JSON Web Key Set) endpoint.
 *
 * The endpoint URL and request timeout are read from configuration
 * (`services.jwks.url` and `services.timeout`).
 */
@Injectable()
export class JwksClient {

    private readonly logger: Logger = new Logger(JwksClient.name);
    private readonly JWKS_URL: string;
    private readonly TIMEOUT: number;

    constructor(private configService: ConfigService, private httpService: HttpService) {
        // Config is read here rather than in field initialisers. With `useDefineForClassFields`
        // (TS target ES2022+), field initialisers run before the `configService` parameter
        // property is assigned.
        this.JWKS_URL = this.configService.get<string>('services.jwks.url');
        this.TIMEOUT = parseInt(this.configService.get<string>('services.timeout'), 10);
    }

    /**
     * Fetches the JSON Web Key Set and returns its `keys` array.
     *
     * @returns the keys from the JWKS response, or an empty array when the body has no non-empty
     *          `keys` array.
     * @throws HttpException when the request fails. It carries the upstream HTTP status when the
     *         endpoint answered, and 502 (Bad Gateway) for timeouts and network errors.
     */
    async getJsonWebKeySet(): Promise<Array<JsonWebKey>> {
        this.logger.log(`Attempting to retrieve json web keys from Jwks endpoint`);

        const config = {
            timeout: this.TIMEOUT,
        };

        let response: JwksResponse = null;
        try {
            response = await lastValueFrom(
                this.httpService
                    .get<JwksResponse>(this.JWKS_URL, config)
                    .pipe(map((res) => res.data)),
            );
        } catch (e) {
            this.logger.error(`An error occurred invoking Jwks endpoint to retrieve public keys`);
            this.logger.error(e.stack);
            // Timeouts and connection errors have no `response`. Dereferencing it unguarded would
            // throw a TypeError here and mask the real failure, so fall back to 502 Bad Gateway.
            const status: number = e.response?.status ?? 502;
            throw new HttpException(e.message, status);
        }

        // Guard against bodies that are not a JWKS document, e.g. an HTML error page or a
        // `keys` member that is not an array.
        if (!response || !Array.isArray(response.keys) || response.keys.length === 0) {
            this.logger.error('No json web keys were returned from Jwks endpoint')
            return [];
        }

        return response.keys;
    }
}

Service containing the logic to call the JWKS endpoint and verify the JWT with the public key.

A JWT consists of a header, a payload and a signature.

The header should also contain a `kid` (key ID) field that matches the `kid` of one of the JSON Web Keys, which tells you which key to verify the token with.

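The `x5c` array contains an X.509 certificate chain. Its first element is the certificate holding the public key used to verify the token; RFC 7517 requires the certificate containing the key to come first. Note that `x5c` is optional in a JWK, so some providers publish only the raw key parameters (e.g. `n` and `e` for RSA keys).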

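Note: I had to wrap the certificate as `\n-----BEGIN CERTIFICATE-----\n${key.x5c[0]}\n-----END CERTIFICATE-----` to be able to create the public key. Each `x5c` entry is a bare base64-encoded DER certificate, whereas `crypto.createPublicKey` parses a string argument as PEM by default. Depending on your implementation, you may not need to do this.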

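You will also need to add logic to verify the claims of your JWT, such as the audience (`aud`) and issuer (`iss`). `jwt.verify` checks `exp` and `nbf` by default and can check `aud`/`iss` through its `audience` and `issuer` options. Also restrict `algorithms` to the asymmetric algorithm your keys actually use (e.g. `RS256`). Never allow an HMAC algorithm such as `HS256` when verifying with a public key, because an attacker could then sign tokens using the public key as the HMAC secret.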

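I also cache the verification result for a valid JWT for a period of time, so that verification is not repeated on every request (which would hurt performance). The cache key is derived from the auth token so that it is unique per token. Keep the cache TTL no longer than the token's remaining lifetime (`exp`); otherwise an expired token will keep being accepted for as long as its entry stays cached.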

import { HttpException, HttpStatus, Injectable, CACHE_MANAGER, Logger, Inject } from "@nestjs/common";
import { ConfigService} from "@nestjs/config";
import { IncomingHttpHeaders } from "http";
import { JwksClient } from "src/auth/jwks/jwks.client";
import { JsonWebKey } from "src/models/json-web-key.model";
import { JwtPayload } from 'jsonwebtoken';
import * as jwt from 'jsonwebtoken';
import * as crypto from "crypto";
import { Cache } from 'cache-manager';

/**
 * Verifies bearer JWTs against the public keys published by the JWKS endpoint.
 *
 * Successful verifications are cached (key prefix `caches.jwks.key`, TTL `caches.jwks.ttl`
 * seconds, capped at the token's remaining lifetime) so repeat requests skip signature checks.
 */
@Injectable()
export class JwksService {

    private readonly logger: Logger = new Logger(JwksService.name);
    private readonly CACHE_KEY: string;
    private readonly CACHE_TTL: number;

    /**
     * Signature algorithms accepted by `jwt.verify`.
     *
     * Only asymmetric algorithms belong here. Allowing an HMAC algorithm (e.g. HS256) while
     * verifying with a *public* key lets an attacker forge tokens by using that public key as the
     * HMAC secret ("algorithm confusion"). Adjust this to the `alg` your JWKS keys actually use
     * (e.g. 'ES256' for EC keys).
     */
    private readonly ALGORITHMS: jwt.Algorithm[] = ['RS256'];

    constructor(private configService: ConfigService, private readonly jwksClient: JwksClient, @Inject(CACHE_MANAGER) private cacheManager: Cache) {
        // Config is read here rather than in field initialisers. With `useDefineForClassFields`
        // (TS target ES2022+), field initialisers run before the `configService` parameter
        // property is assigned.
        this.CACHE_KEY = this.configService.get<string>('caches.jwks.key');
        this.CACHE_TTL = parseInt(this.configService.get<string>('caches.jwks.ttl'), 10);
    }

    /**
     * Verifies the bearer token carried by `request`.
     *
     * @param request incoming request; only `request.headers.authorization` is read.
     * @returns `true` when a cached successful verification exists or the signature verifies.
     *          Returns `false` when the token is not a decodable JWT or no matching public key
     *          is found.
     * @throws HttpException 400 when the Authorization header or bearer token is missing, and
     *         403 when `jwt.verify` rejects the token (bad signature, expired, …). Errors from
     *         `JwksClient` propagate unchanged.
     */
    async verify(request: any): Promise<boolean> {

        const token = this.getAuthorizationTokenFromHeader(request.headers);

        // Key the cache on a SHA-256 digest of the token so raw bearer credentials are never
        // written to the cache store.
        const tokenDigest = crypto.createHash('sha256').update(token).digest('hex');
        const jwksKey = `${this.CACHE_KEY}:${tokenDigest}`;

        const cachedVerificationResult: boolean = await this.getCachedVerificationResult(jwksKey);

        if (cachedVerificationResult) {
            this.logger.debug("Found cached verification result");
            return cachedVerificationResult;
        }

        if (!this.hasTokenWithValidClaims(token)) {
            this.logger.error("Token with invalid claims was provided")
            return false;
        }

        // Get all web keys from JWKS endpoint
        const jsonWebKeys: Array<JsonWebKey> = await this.jwksClient.getJsonWebKeySet();

        // Find the public key with matching kid
        const publicKey: string | Buffer = this.findPublicKey(token, jsonWebKeys);

        if (!publicKey) {
            this.logger.error("No public key was found for the bearer token provided")
            return false;
        }

        let payload: string | JwtPayload;
        try {
            // Also enforces `exp` / `nbf` by default.
            payload = jwt.verify(token, publicKey, { algorithms: this.ALGORITHMS });
        } catch (e) {
            this.logger.error("An error occurred verifying the bearer token with the associated public key");
            this.logger.error(e.stack);
            throw new HttpException(e.message, HttpStatus.FORBIDDEN);
        }

        await this.cacheVerificationResult(jwksKey, payload);

        this.logger.debug("Successfully verified bearer token with the associated public key")

        return true;
    }

    /**
     * Checks that `token` is a decodable JWT.
     *
     * @returns `false` when `jwt.decode` cannot parse the token (it returns `null` for anything
     *          that is not a structurally valid JWT), otherwise `true`.
     */
    private hasTokenWithValidClaims(token: string): boolean {

        const decoded = jwt.decode(token, { complete: true });

        if (!decoded) {
            return false;
        }

        // TODO: Add validation for claims, e.g. pass `audience` / `issuer` options to jwt.verify.

        return true;
    }

    /**
     * Selects the JWK whose `kid` matches the token header and converts its leaf certificate
     * into a PEM-encoded SPKI public key.
     *
     * @returns the PEM public key, or `null` when the token cannot be decoded, no key matches,
     *          or the matching key has no `x5c` certificate chain.
     */
    private findPublicKey(token: string, jsonWebKeys: Array<JsonWebKey>): string | Buffer {

        const decoded = jwt.decode(token, { complete: true });
        const kid = decoded?.header.kid;

        const key = jsonWebKeys.find((jsonWebKey) => jsonWebKey.kid === kid);

        if (!key) {
            return null;
        }

        this.logger.debug(`Found json web key for kid ${kid}`);

        // x5c is optional in a JWK (RFC 7517); without it there is no certificate to read.
        if (!Array.isArray(key.x5c) || key.x5c.length === 0) {
            this.logger.error(`Json web key for kid ${kid} has no x5c certificate chain`);
            return null;
        }

        // Extract the x509 certificate from the certificate chain. x5c entries are bare base64
        // DER, so wrap the first one in PEM armour for crypto.createPublicKey.
        const x509Certificate = `\n-----BEGIN CERTIFICATE-----\n${key.x5c[0]}\n-----END CERTIFICATE-----`;

        // Create the public key from the x509 certificate
        return crypto.createPublicKey(x509Certificate).export({ type: 'spki', format: 'pem' });
    }

    /**
     * Reads the bearer token from the Authorization header.
     *
     * A case-insensitive `Bearer` scheme prefix is stripped; any other header value is returned
     * as-is (trimmed).
     *
     * @throws HttpException 400 when the header is missing or carries no token.
     */
    private getAuthorizationTokenFromHeader(headers: IncomingHttpHeaders): string {

        if (!headers || !headers.authorization) {
            throw new HttpException("Authorization header is missing", HttpStatus.BAD_REQUEST);
        }

        // The auth-scheme token is case-insensitive (RFC 7235), and there may be more than one
        // space before the credentials.
        const bearerPrefix = /^bearer\s+/i;
        const token = headers.authorization.replace(bearerPrefix, '').trim();

        if (!token) {
            throw new HttpException("Bearer token is missing", HttpStatus.BAD_REQUEST);
        }

        return token;
    }

    /**
     * Looks up a cached successful verification.
     *
     * @returns `true` on a cache hit, otherwise `null`. A failing cache read is treated as a miss
     *          so that a cache outage does not reject otherwise valid requests.
     */
    private async getCachedVerificationResult(jwksKey: string): Promise<boolean> {
        try {
            const response = await this.cacheManager.get(jwksKey);
            return response === true ? true : null;
        } catch (e) {
            this.logger.warn(`Failed to read cached Jwks verification result: ${e instanceof Error ? e.message : e}`);
            return null;
        }
    }

    /**
     * Caches a successful verification for at most `CACHE_TTL` seconds, and never past the
     * token's `exp`. Otherwise an expired token would keep being accepted from the cache.
     *
     * Best-effort: a failing cache write is logged instead of failing the request.
     */
    private async cacheVerificationResult(jwksKey: string, payload: string | JwtPayload): Promise<void> {
        let ttl = this.CACHE_TTL;

        const exp = typeof payload === 'object' && payload !== null ? payload.exp : undefined;
        if (typeof exp === 'number') {
            const secondsUntilExpiry = exp - Math.floor(Date.now() / 1000);
            if (secondsUntilExpiry <= 0) {
                // The token is at its expiry boundary; there is nothing worth caching.
                return;
            }
            ttl = ttl > 0 ? Math.min(ttl, secondsUntilExpiry) : secondsUntilExpiry;
        }

        try {
            await this.cacheManager.set(jwksKey, true, { ttl });
        } catch (e) {
            this.logger.warn(`Failed to cache Jwks verification result: ${e instanceof Error ? e.message : e}`);
        }
    }
}

Guard that verifies the JWT:

// src/auth/jwks/jwt-auth.guard.ts

import { Injectable, CanActivate, ExecutionContext, Logger } from '@nestjs/common';
import { JwksService } from 'src/auth/jwks/jwks.service';

/**
 * Route guard that delegates to `JwksService.verify` for the current HTTP request.
 *
 * When the service resolves `false`, Nest rejects the request (403 Forbidden by default).
 * HttpExceptions thrown by the service propagate unchanged.
 */
@Injectable()
export class JwtAuthGuard implements CanActivate {

    private readonly logger: Logger = new Logger(JwtAuthGuard.name);

    constructor(private jwksService: JwksService) {}

    async canActivate(context: ExecutionContext): Promise<boolean> {
        const httpContext = context.switchToHttp();
        return this.jwksService.verify(httpContext.getRequest());
    }
}

Module containing the JWKS configuration:

// src/auth/jwks/jwt-auth.module.ts

import { Module } from '@nestjs/common';
import { ConfigModule } from '@nestjs/config';
import { HttpModule } from '@nestjs/axios';
import configuration from '../../../config/configuration';
import { JwksClient } from 'src/auth/jwks/jwks.client';
import { JwksService } from 'src/auth/jwks/jwks.service';

/** Providers contributed by this module (HTTP client for the JWKS endpoint + verifier). */
const jwksProviders = [JwksClient, JwksService];

/**
 * Feature module wiring up JWKS-based JWT verification and exporting it to importers.
 *
 * NOTE(review): `JwksService` injects CACHE_MANAGER, which this module does not provide. It
 * presumably comes from the global cache module (RedisCacheModule) — confirm in your app.
 */
@Module({
  imports: [ConfigModule.forRoot({ load: [configuration] }), HttpModule],
  providers: jwksProviders,
  exports: [JwksService, JwksClient],
})
export class JwtAuthModule {}

Redis caching module containing the Redis cache configuration:

// src/caching/redis-cache.module.ts
import {  CacheModule, Module } from '@nestjs/common';
import { ConfigModule, ConfigService } from '@nestjs/config';
import configuration from '../../config/configuration';
import { RedisClientOptions } from 'redis';
import * as redisStore from 'cache-manager-redis-store';

/**
 * Registers a global (`isGlobal: true`) cache backed by `cache-manager-redis-store`.
 *
 * Connection settings: host from the REDIS_URL environment variable; port, password and tls
 * from the `redis.*` configuration keys.
 *
 * NOTE(review): the host is read from `process.env` directly, while every other setting goes
 * through ConfigService. Also, `get<number>` / `get<boolean>` do not convert values, so an
 * env-sourced "false" for `redis.tls` would be truthy — confirm how these keys are populated.
 */
@Module({
  imports: [
    ConfigModule.forRoot({ load: [configuration] }),
    CacheModule.registerAsync<RedisClientOptions>({
      isGlobal: true,
      imports: [ConfigModule],
      inject: [ConfigService],
      useFactory: async (config: ConfigService) => {
        const host = process.env.REDIS_URL;
        const port = config.get<number>('redis.port');
        const password = config.get<string>('redis.password');
        const tls = config.get<boolean>('redis.tls');

        return { store: redisStore, host, port, password, tls };
      },
    }),
  ],
  controllers: [],
  providers: [],
})
export class RedisCacheModule {}

Controller that uses the `JwtAuthGuard`:

// src/my.controller.ts
import { Controller, Get, Headers, Logger, Param, UseGuards } from '@nestjs/common';
import { JwtAuthGuard } from 'src/auth/jwks/jwt-auth.guard';

/**
 * Example controller; every route in it is protected by `JwtAuthGuard`.
 *
 * NOTE(review): `Customer` and `customer` are placeholders that this sample never defines.
 * Supply your own model and lookup.
 */
@Controller()
@UseGuards(JwtAuthGuard)
export class MyController {
    private readonly logger: Logger = new Logger(MyController.name);

    /**
     * Handles `GET /:id`, logging the requested id before returning the customer.
     */
    @Get('/:id')
    async getCustomerDetails(@Headers() headers, @Param('id') id: string): Promise<Customer> {
        const acceptedMessage = `Accepted incoming request with id: ${id}`;
        this.logger.log(acceptedMessage);

        // TODO: do the real processing that produces `customer` for this id.

        return customer;
    }
}

Module containing the configuration for the whole app:

// src/app.module.ts

import { Module } from '@nestjs/common';
import { ConfigModule } from '@nestjs/config';
import { HttpModule } from '@nestjs/axios';
import configuration from '../config/configuration';
import { JwtAuthModule } from 'src/auth/jwks/jwt-auth.module';
import { RedisCacheModule } from 'src/caching/redis-cache.module';

/** Modules composed into the root application module. */
const rootImports = [
  ConfigModule.forRoot({ load: [configuration] }),
  HttpModule,
  JwtAuthModule,
  RedisCacheModule,
];

/**
 * Root application module.
 *
 * NOTE(review): this snippet references `MyController` without importing it. Add
 * `import { MyController } from 'src/my.controller';` (path per the file tree above).
 */
@Module({
  imports: rootImports,
  controllers: [MyController],
  providers: [],
})
export class AppModule {}