Mocking AWS services and Lambda best practices


I'm working on a simple AWS Lambda function that is triggered by DynamoDB Streams events and forwards all records except REMOVE events to an SQS queue. The function works as expected, no surprises there.

I want to write a unit test that verifies nothing is sent to SQS for a REMOVE event. I first tried this using aws-sdk-mock. As you can see in the function code, I try to comply with Lambda best practices by initializing the SQS client outside of the handler. Apparently this prevents aws-sdk-mock from mocking the SQS service (there is a GitHub issue about this: https://github.com/dwyl/aws-sdk-mock/issues/206).
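A simplified model of why the placement of the client matters, assuming the mock works by replacing the exported client constructor after the module under test has already been loaded. This is an illustration, not a description of aws-sdk-mock's internals; the `sdk` object and its `Client` class are hypothetical.

```typescript
// Hypothetical "SDK module" whose exported constructor a mocking library could replace.
const sdk = {
  Client: class {
    send(): string {
      return "real";
    }
  },
};

// Module-level construction, as in index.ts: happens once, when the module is first loaded.
const moduleLevelClient = new sdk.Client();

// A test later swaps the constructor for a mock...
sdk.Client = class {
  send(): string {
    return "mock";
  }
};

// ...but the instance created at load time keeps the original implementation,
// while anything constructed after the swap gets the mock.
const lateClient = new sdk.Client();
```

In this model `moduleLevelClient.send()` still returns `"real"`, while `lateClient.send()` returns `"mock"`.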

I then tried to mock SQS with jest, which required more code to get right, but I ran into the same problem: I had to move the initialization of SQS inside the handler function, which violates Lambda best practices.

How can I write a unit test for this function while keeping the initialization of the SQS client (const sqs: SQS = new SQS()) outside the handler? Am I mocking the service the wrong way, or does the structure of the handler have to change to make it easier to test?

I'm aware that this Lambda function is pretty straightforward and unit tests might be unnecessary, but I will have to write further Lambdas with more complex logic, and I think this one is well suited to demonstrate the problem.

index.ts

import {DynamoDBStreamEvent, DynamoDBStreamHandler} from "aws-lambda";
import SQS = require("aws-sdk/clients/sqs");
import DynamoDB = require("aws-sdk/clients/dynamodb");

const sqs: SQS = new SQS()

export const handleDynamoDbEvent: DynamoDBStreamHandler = async (event: DynamoDBStreamEvent, context, callback) => {
    const QUEUE_URL = process.env.TARGET_QUEUE_URL
    if (!QUEUE_URL) {
        throw new Error('TARGET_QUEUE_URL not set or empty')
    }
    await Promise.all(
        event.Records
            .filter(_ => _.eventName !== "REMOVE")
            .map((record) => {
                const unmarshalled = DynamoDB.Converter.unmarshall(record.dynamodb.NewImage);
                let request: SQS.SendMessageRequest = {
                    MessageAttributes: {
                        "EVENT_NAME": {
                            DataType: "String",
                            StringValue: record.eventName
                        }
                    },
                    MessageBody: JSON.stringify(unmarshalled),
                    QueueUrl: QUEUE_URL,
                }
                return sqs.sendMessage(request).promise()
            })
    );
}

index.spec.ts

import {DynamoDBRecord, DynamoDBStreamEvent, StreamRecord} from "aws-lambda";
import {AttributeValue} from "aws-lambda/trigger/dynamodb-stream";
import {handleDynamoDbEvent} from "./index";
import {AWSError} from "aws-sdk/lib/error";
import {PromiseResult, Request} from "aws-sdk/lib/request";
import * as SQS from "aws-sdk/clients/sqs";
import {mocked} from "ts-jest/utils";
import DynamoDB = require("aws-sdk/clients/dynamodb");


jest.mock('aws-sdk/clients/sqs', () => {
    return jest.fn().mockImplementation(() => {
        return {
            sendMessage: (params: SQS.Types.SendMessageRequest, callback?: (err: AWSError, data: SQS.Types.SendMessageResult) => void): Request<SQS.Types.SendMessageResult, AWSError> => {
                // @ts-ignore
                const Mock = jest.fn<Request<SQS.Types.SendMessageResult, AWSError>>(()=>{
                    return {
                        promise: (): Promise<PromiseResult<SQS.Types.SendMessageResult, AWSError>> => {
                            return new Promise<PromiseResult<SQS.SendMessageResult, AWSError>>(resolve => {
                                resolve(null)
                            })
                        }
                    }
                })
                return new Mock()
            }
        }
    })
});


describe.only('Handler test', () => {

    const mockedSqs = mocked(SQS, true)

    process.env.TARGET_QUEUE_URL = 'test'
    const OLD_ENV = process.env;

    beforeEach(() => {
        mockedSqs.mockClear()
        jest.resetModules();
        process.env = {...OLD_ENV};
    });

    it('should write INSERT events to SQS', async () => {
        console.log('Starting test')
        await handleDynamoDbEvent(createEvent(), null, null)
        expect(mockedSqs).toHaveBeenCalledTimes(1)
    });
})

There are 2 answers

Phuong Nguyen (best answer)

Just a rough idea of how I would approach this:

  • Instead of doing the actual SQS sending inside the main function, I would create an interface for a message client. Something like this:
interface QueueClient {
    send(eventName: string, body: string): Promise<any>;
}
  • Then create a concrete class that implements that interface and interacts with SQS:
class SQSQueueClient implements QueueClient {
    queueUrl: string
    sqs: SQS

    constructor() {
        const queueUrl = process.env.TARGET_QUEUE_URL;
        // Catches both unset and empty; .length on undefined would throw a TypeError instead
        if (!queueUrl) {
            throw new Error('TARGET_QUEUE_URL not set or empty')
        }
        this.queueUrl = queueUrl;
        this.sqs = new SQS();
    }

    send(eventName: string, body: string): Promise<any> {
        const request: SQS.SendMessageRequest = {
            MessageAttributes: {
                "EVENT_NAME": {
                    DataType: "String",
                    StringValue: eventName
                }
            },
            MessageBody: body,
            QueueUrl: this.queueUrl,
        }
        // Pass the request and convert the returned AWS.Request into a Promise
        return this.sqs.sendMessage(request).promise()
    }
}

This class knows the details of how to translate data into the SQS format.

  • I would then split the main function in two. The entry point just creates an actual instance of the SQS queue client (which reads and validates the queue URL) and calls processEvent(). The main logic lives in processEvent():
// Created once per container, outside the handler. Because this runs at import time,
// TARGET_QUEUE_URL must be set before a test imports this module.
const queueClient = new SQSQueueClient();

export const handleDynamoDbEvent: DynamoDBStreamHandler = async (event: DynamoDBStreamEvent, context, callback) => {
    await processEvent(queueClient, event);
}

// Named processEvent rather than process so it does not shadow Node's global
// process object (SQSQueueClient reads process.env).
export const processEvent = async (queueClient: QueueClient, event: DynamoDBStreamEvent) => {
    return Promise.all(
        event.Records
            .filter(record => record.eventName !== "REMOVE")
            .map((record) => {
                const unmarshalled = DynamoDB.Converter.unmarshall(record.dynamodb.NewImage);
                return queueClient.send(record.eventName, JSON.stringify(unmarshalled));
            })
    );
}
  • Now it's much easier to test the main logic in processEvent(). You can provide a mock instance that implements the QueueClient interface, either handwritten or created with whatever mocking framework you like.
  • For the SQSQueueClient class there's little benefit in unit testing it, so I would rely more on integration tests (e.g. using something like LocalStack).
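To make the testing point concrete, here is a minimal sketch of testing the main logic with a handwritten QueueClient fake. It assumes simplified stand-in types instead of the aws-lambda ones, and a hypothetical `unmarshallSimple` helper (handling only S and N attributes) in place of `DynamoDB.Converter.unmarshall`, so it runs without any AWS dependencies. `RecordingQueueClient` is likewise a hypothetical name.

```typescript
// Minimal stand-ins for the aws-lambda stream types so this sketch has no dependencies.
type AttributeValue = { S?: string; N?: string };
interface StreamRecordLike {
  eventName?: "INSERT" | "MODIFY" | "REMOVE";
  dynamodb?: { NewImage?: { [key: string]: AttributeValue } };
}
interface StreamEventLike {
  Records: StreamRecordLike[];
}

// Same shape as the QueueClient interface in the answer.
interface QueueClient {
  send(eventName: string, body: string): Promise<unknown>;
}

// Hypothetical stand-in for DynamoDB.Converter.unmarshall; only handles S and N attributes.
function unmarshallSimple(image: { [key: string]: AttributeValue } = {}): { [key: string]: unknown } {
  const result: { [key: string]: unknown } = {};
  for (const key of Object.keys(image)) {
    const value = image[key];
    result[key] = value.S !== undefined ? value.S : Number(value.N);
  }
  return result;
}

// The logic under test: forward every record except REMOVE events.
const processEvent = async (queueClient: QueueClient, event: StreamEventLike) =>
  Promise.all(
    event.Records
      .filter((record) => record.eventName !== "REMOVE")
      .map((record) =>
        queueClient.send(record.eventName ?? "", JSON.stringify(unmarshallSimple(record.dynamodb?.NewImage)))
      )
  );

// Handwritten fake: records each call instead of talking to SQS.
class RecordingQueueClient implements QueueClient {
  readonly sent: Array<{ eventName: string; body: string }> = [];

  async send(eventName: string, body: string): Promise<void> {
    this.sent.push({ eventName, body });
  }
}

// Usage: the REMOVE record is filtered out before it reaches the queue client.
const fake = new RecordingQueueClient();
void processEvent(fake, {
  Records: [
    { eventName: "INSERT", dynamodb: { NewImage: { id: { S: "1" } } } },
    { eventName: "REMOVE", dynamodb: {} },
  ],
});
// fake.sent is now [{ eventName: "INSERT", body: '{"id":"1"}' }]
```

Because the fake only implements the interface, the test never touches SQS, and no module-level client has to be mocked at all.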

I don't have an IDE at hand right now, so pardon any syntax errors here and there.

mheck

I added an initialization method that is called from inside the handler function. It returns immediately if it has already been called; otherwise it initializes the SQS client. It can easily be extended to initialize other clients as well.

This is in line with Lambda best practices, since the client is still created only once and reused across warm invocations, and it makes the test code work.

let sqs: SQS = null
let initialized = false

export const handleDynamoDbEvent: DynamoDBStreamHandler = async (event: DynamoDBStreamEvent, context, callback) => {
    init()
    const QUEUE_URL = process.env.TARGET_QUEUE_URL
    if (!QUEUE_URL) {
        throw new Error('TARGET_QUEUE_URL not set or empty')
    }
    await Promise.all(
        event.Records
            .filter(_ => _.eventName !== "REMOVE")
            .map((record) => {
                const unmarshalled = DynamoDB.Converter.unmarshall(record.dynamodb.NewImage);
                let request: SQS.SendMessageRequest = {
                    MessageAttributes: {
                        "EVENT_NAME": {
                            DataType: "String",
                            StringValue: record.eventName
                        }
                    },
                    MessageBody: JSON.stringify(unmarshalled),
                    QueueUrl: QUEUE_URL,
                }
                return sqs.sendMessage(request).promise()
            })
    );
}

function init() {
    if (initialized) {
        return
    }
    console.log('Initializing...')
    initialized = true
    sqs = new SQS()
}
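A minimal sketch of the same lazy-initialization idea, using the cached instance itself instead of a separate `initialized` flag. `FakeSqsClient`, `getSqsClient` and `handler` are hypothetical stand-ins (there is no aws-sdk import) chosen so that the number of constructions can be observed.

```typescript
// Hypothetical stand-in for the SQS client; counts constructions so the effect is visible.
class FakeSqsClient {
  static constructed = 0;

  constructor() {
    FakeSqsClient.constructed++;
  }

  sendMessage(body: string): string {
    return `queued:${body}`;
  }
}

// Module-level cache: survives across warm invocations, but starts empty at import time.
let sqsClient: FakeSqsClient | undefined;

// Create the client on first use only; later calls reuse the same instance.
function getSqsClient(): FakeSqsClient {
  if (sqsClient === undefined) {
    sqsClient = new FakeSqsClient();
  }
  return sqsClient;
}

async function handler(bodies: string[]): Promise<string[]> {
  const client = getSqsClient();
  return bodies.map((body) => client.sendMessage(body));
}
```

Importing this module constructs nothing; the first `handler` call creates the client and every later call reuses it. Because construction is deferred until the first invocation, a mock that replaces the constructor before that call is what gets constructed. One consequence is that the cached client lives as long as the module does, so tests that need a different mock per case would need a fresh module instance.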