Symmetric encryption (AES) in Apache Thrift

1.2k views Asked by At

I have two applications that interact using Thrift. They share the same secret key, and I need to encrypt their messages. It makes sense to use a symmetric algorithm (AES, for example), but I haven't found any library that does this. So I did some research and see the following options:

Use built-in SSL support

I can use the built-in SSL support, establish a secure connection and use my secret key just as an authentication token. This requires installing certificates in addition to the secret key both sides already have, but I don't need to implement anything except checking that the secret key received from the client is the same as the secret key stored locally.

Implement symmetric encryption

So far, I see the following options:

  1. Extend TSocket, override the write() and read() methods and encrypt / decrypt data in them. This increases traffic on small writes. For example, if TBinaryProtocol writes a 4-byte integer, it will take one full block (16 bytes) in encrypted form.
  2. Extend TSocket and wrap its InputStream and OutputStream with CipherInputStream and CipherOutputStream. CipherOutputStream does not encrypt small byte arrays immediately; it only updates the Cipher with them. Once enough data has accumulated, it is encrypted and written to the underlying OutputStream. So it will wait until you have written four 4-byte ints and then encrypt them together. This avoids wasting traffic, but it also causes a problem: if the last value does not fill the block, it will never be encrypted and written to the underlying stream. The stream expects me to write a number of bytes divisible by its block size (16 bytes), but I can't guarantee this using TBinaryProtocol.
  3. Re-implement TBinaryProtocol, caching all writes instead of writing them to the stream, and encrypt them in the writeMessageEnd() method. Implement decryption in readMessageBegin(). However, I think encryption should be performed on the transport layer, not the protocol layer.

Please share your thoughts with me.

UPDATE

Java Implementation on Top of TFramedTransport

TEncryptedFramedTransport.java

package tutorial;

import org.apache.thrift.TByteArrayOutputStream;
import org.apache.thrift.transport.TMemoryInputTransport;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.apache.thrift.transport.TTransportFactory;

import javax.crypto.Cipher;
import java.security.Key;
/**
 * TEncryptedFramedTransport is a buffered TTransport layered on top of another transport.
 * Writes are collected in memory; on {@link #flush()} the whole buffered message is encrypted
 * with the "AES/ECB/PKCS5Padding" symmetric algorithm and sent preceded by a 4-byte
 * big-endian frame size. Reads pull one complete frame from the wrapped transport, decrypt it
 * and serve subsequent reads from the decrypted bytes.
 *
 * <p>NOTE(review): ECB mode maps identical plaintext blocks to identical ciphertext blocks and
 * the frames carry no integrity check. The algorithm is part of the wire format shared with
 * the peer, so it is deliberately left unchanged here; consider an authenticated mode
 * (for example AES/GCM with a per-frame nonce) if both ends can be upgraded together.
 */
public class TEncryptedFramedTransport extends TTransport {
    /** Cipher transformation used for every frame, in both directions. */
    public static final String ALGORITHM = "AES/ECB/PKCS5Padding";

    private final Cipher encryptingCipher;
    private final Cipher decryptingCipher;

    /** Largest encrypted frame accepted when no explicit limit is given. */
    protected static final int DEFAULT_MAX_LENGTH = 0x7FFFFFFF;

    /** Upper bound (in bytes) on the encrypted frame size accepted by {@link #readFrame()}. */
    private final int maxLength_;

    /** The wrapped transport that carries the encrypted frames. */
    private final TTransport transport_;

    /** Plaintext written since the last flush. */
    private final TByteArrayOutputStream writeBuffer_ = new TByteArrayOutputStream(1024);

    /** Decrypted content of the most recently read frame. */
    private final TMemoryInputTransport readBuffer_ = new TMemoryInputTransport(new byte[0]);

    /** Scratch space for the 4-byte frame size header, shared by reads and writes. */
    private final byte[] i32buf = new byte[4];

    /**
     * Transport factory that wraps every base transport in a TEncryptedFramedTransport
     * using one shared key and maximum frame length.
     */
    public static class Factory extends TTransportFactory {
        private final int maxLength_;
        private final Key secretKey_;

        public Factory(Key secretKey) {
            this(secretKey, DEFAULT_MAX_LENGTH);
        }

        public Factory(Key secretKey, int maxLength) {
            maxLength_ = maxLength;
            secretKey_ = secretKey;
        }

        @Override
        public TTransport getTransport(TTransport base) {
            return new TEncryptedFramedTransport(base, secretKey_, maxLength_);
        }
    }

    /**
     * Wraps another transport.
     *
     * @param transport the underlying transport that carries the encrypted frames
     * @param secretKey the key used for both encryption and decryption
     * @param maxLength the largest encrypted frame size, in bytes, accepted when reading
     * @throws RuntimeException if the ciphers cannot be created or initialized with the key;
     *                          the underlying security exception is attached as the cause
     */
    public TEncryptedFramedTransport(TTransport transport, Key secretKey, int maxLength) {
        transport_ = transport;
        maxLength_ = maxLength;
        encryptingCipher = createCipher(Cipher.ENCRYPT_MODE, secretKey);
        decryptingCipher = createCipher(Cipher.DECRYPT_MODE, secretKey);
    }

    /**
     * Wraps another transport using {@link #DEFAULT_MAX_LENGTH} as the frame size limit.
     */
    public TEncryptedFramedTransport(TTransport transport, Key secretKey) {
        this(transport, secretKey, DEFAULT_MAX_LENGTH);
    }

    /** Creates a cipher for {@link #ALGORITHM} initialized in the given mode. */
    private static Cipher createCipher(int mode, Key secretKey) {
        try {
            Cipher cipher = Cipher.getInstance(ALGORITHM);
            cipher.init(mode, secretKey);
            return cipher;
        } catch (java.security.GeneralSecurityException e) {
            // Keep the cause: without it a bad key or a missing provider is undiagnosable.
            throw new RuntimeException("Unable to initialize ciphers.", e);
        }
    }

    @Override
    public void open() throws TTransportException {
        transport_.open();
    }

    @Override
    public boolean isOpen() {
        return transport_.isOpen();
    }

    @Override
    public void close() {
        transport_.close();
    }

    /**
     * Reads decrypted bytes. Data left over from the current frame is served first; only when
     * none is available is the next frame read from the wrapped transport and decrypted.
     */
    @Override
    public int read(byte[] buf, int off, int len) throws TTransportException {
        int got = readBuffer_.read(buf, off, len);
        if (got > 0) {
            return got;
        }

        // Current frame is exhausted: read another frame of data.
        readFrame();

        return readBuffer_.read(buf, off, len);
    }

    @Override
    public byte[] getBuffer() {
        return readBuffer_.getBuffer();
    }

    @Override
    public int getBufferPosition() {
        return readBuffer_.getBufferPosition();
    }

    @Override
    public int getBytesRemainingInBuffer() {
        return readBuffer_.getBytesRemainingInBuffer();
    }

    @Override
    public void consumeBuffer(int len) {
        readBuffer_.consumeBuffer(len);
    }

    /**
     * Reads one frame (4-byte size followed by that many encrypted bytes) from the wrapped
     * transport, decrypts it and makes the plaintext available through the read buffer.
     *
     * @throws TTransportException if the announced size is negative or exceeds the configured
     *                             maximum, if the wrapped transport fails, or if decryption fails
     */
    private void readFrame() throws TTransportException {
        transport_.readAll(i32buf, 0, 4);
        int size = decodeFrameSize(i32buf);

        if (size < 0) {
            throw new TTransportException("Read a negative frame size (" + size + ")!");
        }

        // Checked before allocating so an oversized header cannot force a huge allocation.
        if (size > maxLength_) {
            throw new TTransportException("Frame size (" + size + ") larger than max length (" + maxLength_ + ")!");
        }

        byte[] buff = new byte[size];
        transport_.readAll(buff, 0, size);

        try {
            buff = decryptingCipher.doFinal(buff);
        } catch (Exception e) {
            // Broad on purpose: every decryption failure surfaces as a TTransportException.
            throw new TTransportException(TTransportException.UNKNOWN, e);
        }

        readBuffer_.reset(buff);
    }

    /** Buffers plaintext in memory; nothing reaches the wrapped transport until {@link #flush()}. */
    @Override
    public void write(byte[] buf, int off, int len) throws TTransportException {
        writeBuffer_.write(buf, off, len);
    }

    /**
     * Encrypts everything written since the previous flush and sends it to the wrapped
     * transport as a single frame: the 4-byte encrypted length followed by the ciphertext.
     */
    @Override
    public void flush() throws TTransportException {
        byte[] buf = writeBuffer_.get();
        int len = writeBuffer_.len();
        // The buffer is reset before encrypting, so a failed flush does not resend this data.
        writeBuffer_.reset();

        try {
            buf = encryptingCipher.doFinal(buf, 0, len);
        } catch (Exception e) {
            // Broad on purpose: every encryption failure surfaces as a TTransportException.
            throw new TTransportException(TTransportException.UNKNOWN, e);
        }

        encodeFrameSize(buf.length, i32buf);
        transport_.write(i32buf, 0, 4);
        transport_.write(buf);
        transport_.flush();
    }

    /**
     * Writes {@code frameSize} into the first four bytes of {@code buf}, most significant
     * byte first.
     */
    public static void encodeFrameSize(final int frameSize, final byte[] buf) {
        buf[0] = (byte) (0xff & (frameSize >> 24));
        buf[1] = (byte) (0xff & (frameSize >> 16));
        buf[2] = (byte) (0xff & (frameSize >> 8));
        buf[3] = (byte) (0xff & (frameSize));
    }

    /**
     * Reads a frame size from the first four bytes of {@code buf}, most significant byte
     * first; the inverse of {@link #encodeFrameSize(int, byte[])}.
     */
    public static int decodeFrameSize(final byte[] buf) {
        return
                ((buf[0] & 0xff) << 24) |
                        ((buf[1] & 0xff) << 16) |
                        ((buf[2] & 0xff) << 8) |
                        ((buf[3] & 0xff));
    }
}

MultiplicationServer.java

package tutorial;

import co.runit.prototype.CryptoTool;
import org.apache.thrift.server.TNonblockingServer;
import org.apache.thrift.server.TServer;
import org.apache.thrift.transport.TNonblockingServerSocket;
import org.apache.thrift.transport.TNonblockingServerTransport;

import java.security.Key;

/**
 * Demo server: exposes MultiplicationService through a TNonblockingServer on port 9090,
 * with every connection wrapped in a TEncryptedFramedTransport.
 */
public class MultiplicationServer {
    public static MultiplicationHandler handler;

    public static MultiplicationService.Processor processor;

    /**
     * Creates the handler and processor, then runs the server on a separate thread.
     */
    public static void main(String[] args) {
        try {
            handler = new MultiplicationHandler();
            processor = new MultiplicationService.Processor(handler);

            Thread serverThread = new Thread(() -> startServer(processor));
            serverThread.start();
        } catch (Exception x) {
            x.printStackTrace();
        }
    }

    /**
     * Builds the non-blocking server around the given processor and serves requests.
     * The call to {@code serve()} is the last statement, so this method does not return
     * until that call does. Any failure is printed and otherwise ignored.
     */
    public static void startServer(MultiplicationService.Processor processor) {
        try {
            // Shared secret; the client must be configured with the same Base64 value.
            final Key sharedKey = CryptoTool.decodeKeyBase64("1OUXS3MczVFp3SdfX41U0A==");
            final TNonblockingServerTransport listener = new TNonblockingServerSocket(9090);

            final TNonblockingServer.Args serverArgs = new TNonblockingServer.Args(listener);
            serverArgs.transportFactory(new TEncryptedFramedTransport.Factory(sharedKey));
            serverArgs.processor(processor);

            final TServer server = new TNonblockingServer(serverArgs);

            System.out.println("Starting the simple server...");
            server.serve();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

MultiplicationClient.java

package tutorial;

import co.runit.prototype.CryptoTool;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;

import java.security.Key;

/**
 * Demo client: connects to localhost:9090 through a TEncryptedFramedTransport and calls
 * MultiplicationService.multiply once.
 */
public class MultiplicationClient {
    public static void main(String[] args) {
        // Shared secret; must match the key configured on the server.
        Key key = CryptoTool.decodeKeyBase64("1OUXS3MczVFp3SdfX41U0A==");

        TTransport transport = null;
        try {
            TSocket baseTransport = new TSocket("localhost", 9090);
            transport = new TEncryptedFramedTransport(baseTransport, key);
            transport.open();

            TProtocol protocol = new TBinaryProtocol(transport);
            MultiplicationService.Client client = new MultiplicationService.Client(protocol);

            perform(client);
        } catch (TException x) {
            x.printStackTrace();
        } finally {
            // Close on every path so a failed open() or RPC does not leak the socket.
            if (transport != null) {
                transport.close();
            }
        }
    }

    /**
     * Calls {@code multiply(3, 5)} on the service and prints the result.
     *
     * @throws TException if the remote call fails
     */
    private static void perform(MultiplicationService.Client client) throws TException {
        int product = client.multiply(3, 5);
        System.out.println("3*5=" + product);
    }
}

Of course, the key must be the same on the client and the server. To generate a key and store it as Base64:

/**
 * Generates a new 128-bit AES key and returns it as a Base64 string.
 *
 * @return the Base64 encoding of the freshly generated key
 * @throws NoSuchAlgorithmException if no provider supplies an "AES" key generator
 */
public static String generateKey() throws NoSuchAlgorithmException, InvalidAlgorithmParameterException {
    final KeyGenerator aesGenerator = KeyGenerator.getInstance("AES");
    aesGenerator.init(128);
    return encodeKeyBase64(aesGenerator.generateKey());
}

/**
 * Returns the key's encoded bytes as a Base64 string.
 */
public static String encodeKeyBase64(Key key) {
    final byte[] rawKey = key.getEncoded();
    return Base64.getEncoder().encodeToString(rawKey);
}

/**
 * Rebuilds a secret key from its Base64 representation.
 *
 * @param encodedKey the Base64 text of the key bytes
 * @return a key wrapping the decoded bytes
 */
public static Key decodeKeyBase64(String encodedKey) {
    // NOTE(review): ALGORITHM is declared outside this snippet. It presumably holds the key
    // algorithm name ("AES", as used by generateKey) rather than a full cipher
    // transformation string -- confirm against the enclosing class.
    return new SecretKeySpec(Base64.getDecoder().decode(encodedKey), ALGORITHM);
}

UPDATE 2

Python Implementation on Top of TFramedTransport

TEncryptedTransport.py

from cStringIO import StringIO
from struct import pack, unpack
from Crypto.Cipher import AES

from thrift.transport.TTransport import TTransportBase, CReadableTransport

__author__ = 'Marboni'

# Padding granularity in bytes (the AES block size).
BLOCK_SIZE = 16


def pad(s):
    """Pad s up to a multiple of BLOCK_SIZE.

    Appends N characters, each with code N, where N is the distance to the next
    block boundary (1..BLOCK_SIZE). An already aligned string gets one full
    extra block, so the padding can always be removed again.
    """
    fill = BLOCK_SIZE - len(s) % BLOCK_SIZE
    return s + fill * chr(fill)


def unpad(s):
    """Strip the padding added by pad(); an empty string yields ''.

    The code of the last character says how many trailing characters to drop.
    No check is made that the padding is well formed.
    """
    if not s:
        return ''
    return s[:-ord(s[-1])]

class TEncryptedFramedTransportFactory:
    """Factory producing TEncryptedFramedTransport wrappers that share one key."""

    def __init__(self, key):
        # Key handed to every transport created by this factory.
        self.__key = key

    def getTransport(self, trans):
        """Return trans wrapped in a TEncryptedFramedTransport using the stored key."""
        wrapped = TEncryptedFramedTransport(trans, self.__key)
        return wrapped


class TEncryptedFramedTransport(TTransportBase, CReadableTransport):
    """Framed transport that AES-encrypts every frame.

    Writes are buffered until flush(), which pads and encrypts the buffered
    message and sends it to the wrapped transport as a 4-byte network-order
    length followed by the ciphertext. Reads fetch one whole frame, decrypt
    and unpad it, and serve data from the decrypted buffer.
    """

    def __init__(self, trans, key):
        """Wrap trans; key is the raw AES key used in both directions."""
        self.__trans = trans
        self.__rbuf = StringIO()
        self.__wbuf = StringIO()

        # Name the mode explicitly instead of relying on the library default:
        # PyCrypto defaults to ECB, while PyCryptodome requires the argument.
        # NOTE(review): ECB leaks repeated plaintext blocks and the frames are
        # not authenticated; it is kept because it is the wire format shared
        # with the peer.
        self.__cipher = AES.new(key, AES.MODE_ECB)

    def isOpen(self):
        return self.__trans.isOpen()

    def open(self):
        return self.__trans.open()

    def close(self):
        return self.__trans.close()

    def read(self, sz):
        """Return up to sz decrypted bytes, reading a new frame when the buffer is empty."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self.readFrame()
        return self.__rbuf.read(sz)

    def readFrame(self):
        """Read one frame from the wrapped transport and replace the read buffer with its plaintext."""
        buff = self.__trans.readAll(4)
        # Frame header: signed 32-bit length in network byte order.
        sz, = unpack('!i', buff)
        encrypted = self.__trans.readAll(sz)

        decrypted = unpad(self.__cipher.decrypt(encrypted))

        self.__rbuf = StringIO(decrypted)

    def write(self, buf):
        """Buffer buf; nothing is sent until flush()."""
        self.__wbuf.write(buf)

    def flush(self):
        """Encrypt the buffered message and send it as a single length-prefixed frame."""
        wout = self.__wbuf.getvalue()
        # Start a fresh write buffer before sending, so the next message starts empty.
        self.__wbuf = StringIO()

        encrypted = self.__cipher.encrypt(pad(wout))
        encrypted_len = len(encrypted)
        buf = pack("!i", encrypted_len) + encrypted
        self.__trans.write(buf)
        self.__trans.flush()

    # Implement the CReadableTransport interface.
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        """Read frames until at least reqlen bytes (including prefix) are buffered.

        The read buffer is replaced by prefix plus the newly read frames and returned.
        """
        # NOTE(review): assumes the caller only asks for a refill once the
        # current read buffer is exhausted, since each new frame's full
        # content is appended -- confirm against the accelerated protocol code.
        while len(prefix) < reqlen:
            self.readFrame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = StringIO(prefix)
        return self.__rbuf

MultiplicationClient.py

import base64
from thrift import Thrift
from thrift.transport import TSocket
from thrift.protocol import TBinaryProtocol

from tutorial import MultiplicationService, TEncryptedTransport

# Shared secret; must match the key configured on the server.
key = base64.b64decode("1OUXS3MczVFp3SdfX41U0A==")

try:
    transport = TSocket.TSocket('localhost', 9090)
    transport = TEncryptedTransport.TEncryptedFramedTransport(transport, key)

    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    client = MultiplicationService.Client(protocol)

    transport.open()
    try:
        product = client.multiply(4, 5, 'Echo!')
        print('4*5=%d' % product)
    finally:
        # Close the connection even when the remote call raises.
        transport.close()
except Thrift.TException as tx:
    print(tx.message)
1

There are 1 answers

1
codeSF On BEST ANSWER

As stated by JensG, sending an externally encrypted binary or supplying a layered cipher transport are the two best options. If you need a template, take a look at the TFramedTransport. It is a simple layered transport and could easily be used as a starting point for creating a TCipherTransport.