LSTM with changing batch size while training


I'm trying to build an LSTM on app-log data from different users. I have one big dataframe consisting of the stacked app records of all users; for example, the first 1500 rows belong to user 1, the following 500 to user 2, and so on. I'm now wondering whether it is possible to train the LSTM in such a way that the weights are updated after each user, which would mean changing the batch size between updates. For a better understanding: I want the LSTM to first take all 1500 records of user 1 and update the weights after processing them, then take the 500 rows of user 2 and update the weights after processing those, and so on.

I'm building the LSTM with Keras.

Is there a way to do this?

Thanks!


1 Answer

Answered by ki-ljl:

I don't know your specific application scenario, but I'm assuming it's time series forecasting. Note that the code below is PyTorch rather than Keras; the idea (one weight update per user-sized batch) carries over, and a Keras sketch follows at the end.

Build the LSTM model:

import torch
from torch import nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size
        self.num_directions = 1  # unidirectional LSTM
        self.batch_size = batch_size
        self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
        self.linear = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input_seq):
        # randomly initialised hidden and cell states (zeros are another common choice)
        h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
        c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
        seq_len = input_seq.shape[1]
        # input(batch_size, seq_len, input_size)
        input_seq = input_seq.view(self.batch_size, seq_len, self.input_size)
        # output(batch_size, seq_len, num_directions * hidden_size)
        output, _ = self.lstm(input_seq, (h_0, c_0))
        # flatten to (batch_size * seq_len, hidden_size) before the linear layer
        output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size)
        pred = self.linear(output)  # (batch_size * seq_len, output_size)
        # reshape back and keep only the prediction for the last time step
        pred = pred.view(self.batch_size, seq_len, -1)
        pred = pred[:, -1, :]
        return pred
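
For example, a quick forward pass to check the shapes (the sizes here are made up purely for illustration):

# smoke test with made-up sizes: 5 sequences of length 30 with 8 features each
model = LSTM(input_size=8, hidden_size=64, num_layers=1, output_size=1, batch_size=5).to(device)
x = torch.randn(5, 30, 8).to(device)
print(model(x).shape)  # torch.Size([5, 1]) -- one prediction per sequence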

You can wrap each user's data in a Dataset and build a separate DataLoader per user, each with its own batch size.

Like this:

from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, item):
        return self.data[item]

    def __len__(self):
        return len(self.data)

# train/test are lists of (seq, label) pairs and B is this user's batch size.
# drop_last=True avoids a final batch smaller than B, which would break the
# fixed self.batch_size used in the model's forward().
Dtr = DataLoader(dataset=train, batch_size=B, shuffle=False, num_workers=0, drop_last=True)
Dte = DataLoader(dataset=test, batch_size=B, shuffle=False, num_workers=0, drop_last=True)
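
A sketch of how the stacked dataframe from the question could be split into one loader per user. The column name `user_id` and the helper `make_samples` are assumptions about your data, not part of the original code:

# hypothetical: split the stacked dataframe by a user_id column and build
# one DataLoader per user, whose batch size is that user's sample count,
# so that one batch = one weight update = one user
Dtrs, batchsizes = [], []
for _, user_df in df.groupby('user_id', sort=False):
    samples = make_samples(user_df)  # your own windowing into (seq, label) pairs
    b = len(samples)                 # one batch covers all of this user's data
    batchsizes.append(b)
    Dtrs.append(DataLoader(MyDataset(samples), batch_size=b, shuffle=False))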

Then we start training. Because the model hard-codes its batch size, a fresh model (and optimizer) is created for each user and the weights learned on the previous users are loaded back in:

for t in range(len(Dtrs)):  # one pass per user
    # change batch size for this user
    b = batchsizes[t]   # batch_size of each user, stored in batchsizes
    model = LSTM(input_size, hidden_size, num_layers, output_size, batch_size=b).to(device)
    # the original snippet left these undefined; Adam and MSE are placeholder choices
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_function = nn.MSELoss()
    if t != 0:
        # continue from the weights learned on the previous users
        state = torch.load(LSTM_PATH)
        model.load_state_dict(state['model'])
        optimizer.load_state_dict(state['optimizer'])
    model.train()
    Dtr = Dtrs[t]  # train data of each user, stored in Dtrs
    for i in range(epochs):
        cnt = 0
        for (seq, label) in Dtr:
            cnt += 1
            seq = seq.to(device)
            label = label.to(device)
            y_pred = model(seq)
            loss = loss_function(y_pred, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if cnt % 100 == 0:
                print('epoch', i, ':', cnt - 100, '~', cnt, loss.item())
    # save the current user's model (and optimizer) after training
    state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
    torch.save(state, LSTM_PATH)

I'm sorry that the above code won't run as-is; since I don't know your data situation, I've only provided a general framework.
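
For completeness, since the question mentions Keras: the same idea can be expressed there with train_on_batch, which runs exactly one gradient update on whatever batch it is given, so you can feed each user's full data as one batch. A minimal sketch, where the layer sizes and the user_batches iterable of (x, y) arrays are placeholders for your data:

from tensorflow import keras

# minimal sketch: one weight update per user via train_on_batch
model = keras.Sequential([
    keras.layers.LSTM(64, input_shape=(seq_len, n_features)),
    keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

for user_x, user_y in user_batches:  # e.g. (1500, seq_len, n_features), then (500, ...)
    loss = model.train_on_batch(user_x, user_y)  # a single gradient update on this batch
    print('user batch loss:', loss)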