Training RNNs for learning sequential probability distributions

49 views Asked by At

I am interested in learning probability distributions with an RNN. I have attached some handwritten sketches to help explain.

  1. Imagine a circuit, and I can measure certain objects from this circuit. The circuit is specified by elements in a vector b.

  2. The outputs of the circuit can be measured sequentially (ie: first object 1, object2,...). There are 6 possible measurements that these objects can take which I will call a. In the code below these are encoded first one hot encoding "0": [1,0,0,0,0,0] etc...

  3. I want to develop an RNN that can predict the probability distribution of P(a1 | b). Then the next would be P(a2 | a1, b) etc. Please refer to my RNN sketch. the output vectors of the RNN for the decoding part would represent probability distributions.

I have implemented code for this strategy below, and am trying to get it to memorize a very simple dataset where object 1 is 50% 0 and 50% 1 and object 2 is always 1. I anticipate this should be able to converge to 0, but if you run it, you find it keeps getting stuck with a loss that is not zero. I am unable to find the structural problem in my code.

Are there any conceptual issues or major implementation errors within this code?

import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import TensorDataset, DataLoader


# Data (example with 4 measurements)
# these lets say represent the "parameters" of the circuit we are testing
circuit_params = np.array([[0.2, 0.4, 0.7, 0.6] for i in range(4)])

measurement_samples = np.array([[[1,0,0,0,0,0], [0,1,0,0,0,0]],
                                [[0,1,0,0,0,0], [0,1,0,0,0,0]],
                                [[1,0,0,0,0,0], [0,1,0,0,0,0]],
                                [[0,1,0,0,0,0], [0,1,0,0,0,0]]])

# Convert numpy arrays to PyTorch tensors
circuit_params_tensor = torch.tensor(circuit_params, dtype=torch.float32)
measurement_samples_tensor = torch.tensor(measurement_samples, dtype=torch.float32)

# RNN Model
class EncoderDecoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(EncoderDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.encoder_rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True)
        self.decoder_rnn = nn.RNN(input_size=output_size, hidden_size=hidden_size, batch_first=True)
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, encoder_input, decoder_input):
        _, encoder_hidden = self.encoder_rnn(encoder_input)
        encoder_hidden = encoder_hidden.repeat(1, encoder_input.size(0), 1)
        decoder_output, _ = self.decoder_rnn(decoder_input, encoder_hidden)
        output = self.output_layer(decoder_output)
        return output

# Hyperparameters
input_size = 4 # how many used to describe average
hidden_size = 32 # hidden states internel to RNN
output_size = 6 # Number of possible measurement outputs (1 - 6) of qubits
learning_rate = 0.01 # learning rate of the RNN
num_epochs = 50 # repeats over the dataset
batch_size = 2  # Number of data points (one is: [circuit description] + [data element])

# Instantiate the model
model = EncoderDecoderRNN(input_size, hidden_size, output_size)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Preparing inputs and outputs for the model
encoder_input = circuit_params_tensor

# Number of measurement samples can vary, we dynamically adjust the decoder input
num_measurements = measurement_samples_tensor.size(1)
#print("num_measurements", num_measurements)

decoder_input = torch.zeros(encoder_input.size(0), num_measurements, output_size)  # Adjusted shape

# Create a TensorDataset to hold your inputs and targets
# this groups the dataset like: (tensor([0.2000, 0.4000, 0.7000, 0.6000]), tensor([[0., 1., 0., 0., 0., 0.], [1., 0., 0., 0., 0., 0.]]))
dataset = TensorDataset(circuit_params_tensor, measurement_samples_tensor)

# Create a DataLoader to handle batching
# this batches the data into the RNN
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop
for epoch in range(num_epochs):
    for batch in dataloader:
        # Unpack the batch data
        encoder_batch_input, decoder_batch_target = batch
        
        # Prepare decoder input (adjusting the shape for batch and measurements)
        decoder_batch_input = torch.zeros_like(decoder_batch_target, dtype=torch.float32)
        
        # Forward pass
        output = model(encoder_batch_input, decoder_batch_input)
        
        # Calculate the loss
        loss = criterion(output, decoder_batch_target)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    if (epoch+1) % 1 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

For visualizing the output (repeatedly measuring to get a distribution), you can visualize it with this code to see that it is not working.

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

column_averages = np.mean(measurement_samples, axis=0)
num_qubits, num_states = column_averages.shape

single_circuit_param = torch.tensor([[0.2000, 0.4000, 0.7000, 0.6000]], dtype=torch.float32)

def take_multiple_measurements(model, circuit_param, num_qubits, num_measurements):
    measurements = torch.zeros((num_measurements, num_qubits))
    for i in range(num_measurements):
        measurements[i] = take_measurement(model, circuit_param, num_qubits)
    return measurements

def take_measurement(model, circuit_param, num_qubits):
    with torch.no_grad():
        # Initialize the input for the first qubit
        qubit_input = torch.zeros(1, num_qubits, output_size)
        
        # This will store the measurements for all qubits
        measurement = torch.zeros(num_qubits, dtype=torch.long)
        
        for qubit_idx in range(num_qubits):
            # Get the probability distribution of the current qubit given the circuit_param and previous qubits
            qubit_probs = model(circuit_param, qubit_input).squeeze(0)[qubit_idx]
            qubit_probs_softmax = nn.functional.softmax(qubit_probs, dim=-1)
            
            # Sample a value for the current qubit based on the probability distribution
            qubit_state = torch.multinomial(qubit_probs_softmax, 1).item()
            measurement[qubit_idx] = qubit_state
            
            # Update the input for the next qubit
            qubit_input[0, qubit_idx, :] = 0  # Reset the previous state
            qubit_input[0, qubit_idx, qubit_state] = 1.0  # Set the new state
            
        return measurement

# Function to get the distributions for each qubit after multiple measurements
def get_qubit_state_distributions(measurements, num_states):
    num_measurements, num_qubits = measurements.shape
    distributions = torch.zeros((num_qubits, num_states))
    
    for qubit_idx in range(num_qubits):
        for state in range(num_states):
            distributions[qubit_idx, state] = torch.sum(measurements[:, qubit_idx] == state) / num_measurements
    
    return distributions.numpy()

# Take multiple measurements
num_measurements = 100
num_states = 6  # Assuming there are 6 possible states (0 through 5)
measurements = take_multiple_measurements(model, single_circuit_param, num_qubits, num_measurements)

# Calculate the distributions
distributions = get_qubit_state_distributions(measurements, num_states)

x = np.arange(num_states)  # x-coordinates for the bar chart

# Plotting
fig, axes = plt.subplots(num_qubits, 1, figsize=(10, 8))

for i in range(num_qubits):
    axes[i].bar(x - 0.2, distributions[i], width=0.4, label='Measured Distribution')
    axes[i].bar(x + 0.2, column_averages[i], width=0.4, label='Average Prediction')
    axes[i].set_title(f'Qubit {i+1}')
    axes[i].set_xticks(x)
    axes[i].set_xticklabels([f'State {j}' for j in range(num_states)])
    axes[i].legend()

plt.tight_layout()
plt.show()

enter image description here

enter image description here

0

There are 0 answers