create an efficient dataset to feed a keras siamese model


I need help building a data pipeline for a Siamese model in Keras and TensorFlow. I managed to write a Python generator, but it turned out to be too slow for my purposes. My idea was to use TensorFlow datasets (tf.data) instead, but I could not make it work.

My data are pairs of proteins that either interact (label = 1) or do not interact (label = 0). The data are redundant, meaning that a single protein can appear in pairs with many others, like this:

protein 1   protein 2   label
1           2           1
1           3           1
1           4           0
2           1           0
2           4           1

And so on. For this reason I thought of creating a tf.data dataset of indexes, manipulating those, and recovering the actual protein sequences only at the end. Ideally the dataset should produce shuffled, batched index pairs with labels like this (a small sketch of this index part follows the example):

id1   id2   label
12    15    0
10    9     1
25    12    1
9     7     0
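
Just to make the idea concrete, here is a minimal sketch of the index part; pair_index is a made-up name for the array of (id1, id2, label) rows:

import numpy as np
import tensorflow as tf

# Hypothetical array of (id1, id2, label) rows; "pair_index" is just an example name
pair_index = np.array([[12, 15, 0],
                       [10,  9, 1],
                       [25, 12, 1],
                       [ 9,  7, 0]], dtype=np.int32)

# Build a dataset of index triples, shuffle the rows and batch them
index_ds = (tf.data.Dataset.from_tensor_slices(pair_index)
            .shuffle(buffer_size=len(pair_index))
            .batch(2, drop_remainder=True))

for batch in index_ds:
    print(batch.numpy())   # each batch has shape (2, 3): id1, id2, label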

A final step should then recover the sequences from another indexed database and return the protein sequences to be fed to the Siamese model. Something like this:

If the database with the sequences is:

1      seq(1)
2      seq(2)
. . .
20000  seq(20000)

then the function should return a batch like this (an eager-mode sketch of the lookup follows the example):

[seq(p12), seq(p15), 0]
[seq(p10), seq(p9), 1]
[seq(p25), seq(p12), 1]
[seq(p9), seq(p7), 0]
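
In eager code the lookup itself would just be a batched row gather, roughly like the sketch below; seq_db here is a made-up (n_proteins, seq_len) array of encoded sequences, and the ids are assumed to be 0-based row indexes into it:

import numpy as np
import tensorflow as tf

# Made-up encoded sequence database: one row per protein id (0-based for simplicity)
seq_db = np.random.rand(20000, 10).astype(np.float32)

# One batch of index triples, shape (batch, 3): columns are id1, id2, label
batch = tf.constant([[12, 15, 0],
                     [10,  9, 1]], dtype=tf.int32)

seq1 = tf.gather(seq_db, batch[:, 0])   # sequences of the first protein in each pair
seq2 = tf.gather(seq_db, batch[:, 1])   # sequences of the second protein in each pair
labels = batch[:, 2]

print(seq1.shape, seq2.shape, labels.numpy())   # (2, 10) (2, 10) [0 1]

The problem is doing the same thing inside a tf.data pipeline.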

I created a dataset of indexes, shuffled and batched it, but I could not find a way to recover the real sequences with dataset.map. I have tried everything I could think of, but I suspect I am missing something basic. The example code is:

import tensorflow as tf
import numpy as np

print(tf.__version__)

# Create a toy sequence db, TF version (10 proteins, 10 features each)
seq_db_tf = tf.random.uniform(
    shape=[10, 10],
    minval=0,
    maxval=1,
    dtype=tf.dtypes.float32,
)

# Create a toy sequence db, NumPy version
seq_db_np = np.random.rand(10, 10)

# check indexing
print(seq_db_np[1:4])

# Create an array of indexes and labels (the labels are not 0/1 here,
# just random integers for this example)
index_db = tf.random.uniform(
    shape=[4, 3],
    minval=0,
    maxval=10,
    dtype=tf.dtypes.int32,
)

# check indexing
print(index_db[1:4])


def to_numpy(t):
    tn = t.numpy()
    return tn


# The mapping function I tried: pull the index columns out of a batch, convert
# them to NumPy and use them to slice the NumPy sequence db. This is the part
# that fails: .numpy() is not available inside dataset.map, because the mapped
# function is traced in graph mode and only sees symbolic tensors.
def create_couples(batch, seq_db_np):
    p1 = batch[:, 0].numpy()
    p2 = batch[:, 1].numpy()
    label = batch[:, 2].numpy()
    seq1 = seq_db_np[p1, :]
    seq2 = seq_db_np[p2, :]
    return [[seq1, seq2], label]


dataset = tf.data.Dataset.from_tensor_slices(index_db)
dataset = dataset.batch(2, drop_remainder=True)
dataset1 = dataset.map(lambda x: create_couples(x, seq_db_np))

1 Answer

Davide:

I ended up with this solution:

def create_couples(it, seq_db):
    # it: one batch of index rows; columns 0 and 1 hold the two protein ids.
    # tf.gather pulls the corresponding rows (sequences) out of seq_db.
    p1 = it[:, 0:1]
    p2 = it[:, 1:2]
    ii1 = tf.squeeze(p1)
    ii2 = tf.squeeze(p2)
    X1 = tf.gather(seq_db, ii1)
    X2 = tf.gather(seq_db, ii2)
    return [X1, X2]


def extract_labels(i, seq_db, to_cat):
    labels = i[0]
    if to_cat == 1:
        # optionally one-hot encode the labels
        labels = tf.keras.utils.to_categorical(labels, num_classes=2, dtype='int32')
    return labels


#################################### create dataset

trainDS = tf.data.Dataset.from_tensor_slices(train_s_t)
trainDS = (trainDS
    .batch(batch_size, drop_remainder=True)
    .map(lambda t, seq_db_np=esmnp: tf.py_function(create_couples,
                                                   inp=[t, seq_db_np],
                                                   Tout=[tf.float32, tf.float32]),
         num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE))

trainY = tf.data.Dataset.from_tensor_slices(trainY_t)
trainY = (trainY
    .batch(batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
    .map(lambda t, seq_db_np=esmnp, to_cat=to_cat: tf.py_function(extract_labels,
                                                                  inp=[t, seq_db_np, to_cat],
                                                                  Tout=[tf.int32]),
         num_parallel_calls=tf.data.AUTOTUNE))

trainDSY = tf.data.Dataset.zip((trainDS, trainY))

Now it works, but I had to batch before mapping, because create_couples only works on whole batches of indexes (tf.gather is applied to a batch of ids at a time). I also had to zip the sequence and label datasets, because the Siamese model would not accept a single dataset in the format [[X1, X2], labels] or ((X1, X2), labels). I wanted a more elegant solution, but this seems to work.
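
For reference, here is a possibly more elegant variant that avoids tf.py_function altogether: if the sequence database is turned into a constant tensor first, tf.gather can run directly inside a graph-mode map, and the map can emit ((X1, X2), label) tuples so no zip is needed. This is only a sketch, and it assumes the pair ids and labels are stacked into a single integer array (called pair_index here, a made-up name) whose ids are valid 0-based row indexes into esmnp:

seq_db_t = tf.constant(esmnp, dtype=tf.float32)   # sequence db as a constant tensor

def lookup_pairs(batch):
    # batch: shape (batch_size, 3) with columns (id1, id2, label)
    x1 = tf.gather(seq_db_t, batch[:, 0])   # sequences of the first proteins
    x2 = tf.gather(seq_db_t, batch[:, 1])   # sequences of the second proteins
    y = batch[:, 2]
    return (x1, x2), y

trainDSY = (tf.data.Dataset.from_tensor_slices(pair_index)   # pair_index: assumed (n, 3) int array
            .shuffle(buffer_size=len(pair_index))
            .batch(batch_size, drop_remainder=True)
            .map(lookup_pairs, num_parallel_calls=tf.data.AUTOTUNE)
            .prefetch(tf.data.AUTOTUNE))

The resulting dataset yields ((X1, X2), label) batches and can be passed directly to model.fit.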