This is a much simplified network from a real problem that, to me, has a surprising INability to learn a simple task via backprop, ie, it can't overfit or learn at all. This simple version has come at the cost of many gray hairs, and much simplification, and I truly ignored the option that something this simple would fail to learn, until finally I tested it, and, sure enough it fails.
(Runnable version at bottom)
It is a FFNN tasked with turning input vectors into output vectors, all drawn from randn, and all with the same dimension.
inputvectors are compared via cosine similarity with a model parameter calledpredicate.if
predicateapproaches1.0the network should output it's model parameter vectortrue. Otherwise, output parameterfalse.lossis defined as1 - cosine_similarity(output, expected)
Note, if you cheat and set the model's internal vectors to the expected values, the model performs with expectedly great accuracy and low loss, and learning doesn't degrade (too much), so there is a stable fixed point in the loss landscape.
Here's the model:
class Sim(nn.Module):
def __init__(self, ):
super(Sim, self).__init__()
self.predicate = nn.Parameter(torch.randn(VEC_SIZE) * 1e-2)
self.true = nn.Parameter(torch.randn(VEC_SIZE) * 1e-2)
self.false = nn.Parameter(torch.randn(VEC_SIZE) * 1e-2)
def forward(self, input):
# input : [batch_size, vec_size]
# output : [batch_size, vec_size]
batch_size = input.size(0)
predicate = self.predicate.unsqueeze(0)
matched = torch.cosine_similarity(predicate, input, dim=1)
return (
einsum('v, b -> bv', self.true, matched) +
einsum('v, b -> bv', self.false, 1 - matched)
)
The self.predicate should eventually approximate predicate_vec, and the same for self.true and self.false. If it learns these expected values, it should have the lowest loss. Here's the data:
predicate_vec = torch.randn(VEC_SIZE)
true_vec = torch.randn(VEC_SIZE)
false_vec = torch.randn(VEC_SIZE)
dataset = (
# positives
[(predicate_vec, true_vec) for _ in range(N_DATASET_POS)] +
# negatives
[(torch.randn(VEC_SIZE), false_vec) for _ in range(N_DATASET_POS)]
)
dataset_loader = torch.utils.data.DataLoader(dataset,
batch_size=BATCH_SIZE,
shuffle=True)
When this runs, the output looks something like this, where the loss does decrease some, but accuracy at the task doesn't improve at all.
Epoch 260, Training Loss: 0.005365743
Epoch 270, Training Loss: 0.005237889
Epoch 280, Training Loss: 0.005211671
Epoch 290, Training Loss: 0.005140129
Epoch 300, Training Loss: 0.005135684
SIMILARITY OF LEARNED VECS: p=-0.352 t=-0.244 f=0.266
If I "cheat" and set the model's internal params to the known-good values, the loss plummets, and the params are robust to the training procedure. Notice, the cos-sim of predicate, true, and false are all around 1.0 as expected:
Epoch 570, Training Loss: 0.003478402
Epoch 580, Training Loss: 0.003488328
Epoch 590, Training Loss: 0.003480982
Epoch 600, Training Loss: 0.003456787
SIMILARITY OF LEARNED VECS: p=0.990 t=0.994 f=0.992
Runnable version:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import f1_score
from torch import einsum
import pdb
import torch
import numpy as np
import random
import string
from datasets import Dataset
torch.set_printoptions(precision=3)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
DEVICE = 'cuda'
##########
# Params
NUM_EPOCHS = 1000
BATCH_SIZE = 10
GRAD_CLIP = 10.0
LR = 1e-2
WD = 0
N_DATASET_POS = 100
N_DATASET_NEG = 100
VEC_SIZE = 128
##########
# Data
# We'll try and find these same vectors being learned within the network.
predicate_vec = torch.randn(VEC_SIZE)
true_vec = torch.randn(VEC_SIZE)
false_vec = torch.randn(VEC_SIZE)
dataset = (
# positives
[(predicate_vec, true_vec) for _ in range(N_DATASET_POS)] +
# negatives
[(torch.randn(VEC_SIZE), false_vec) for _ in range(N_DATASET_POS)]
)
dataset_loader = torch.utils.data.DataLoader(dataset,
batch_size=BATCH_SIZE,
shuffle=True)
##########
# Model
class Sim(nn.Module):
def __init__(self, ):
super(Sim, self).__init__()
self.predicate = nn.Parameter(torch.randn(VEC_SIZE) * 1e-2)
self.true = nn.Parameter(torch.randn(VEC_SIZE) * 1e-2)
self.false = nn.Parameter(torch.randn(VEC_SIZE) * 1e-2)
def forward(self, input):
# input : [batch_size, vec_size]
# output : [batch_size, vec_size]
batch_size = input.size(0)
predicate = self.predicate.unsqueeze(0)
matched = torch.cosine_similarity(predicate, input, dim=1)
return (
einsum('v, b -> bv', self.true, matched) +
einsum('v, b -> bv', self.false, 1 - matched)
)
def run_epoch(data_loader, model, optimizer):
model.train()
total_loss = 0
all_predictions = []
all_true_values = []
for batch in data_loader:
input_tensor, target_tensor = batch
input_tensor = input_tensor.to(DEVICE)
target_tensor = target_tensor.to(DEVICE)
model.zero_grad()
output = model(input_tensor)
loss = (1 - torch.cosine_similarity(target_tensor, output.unsqueeze(1))).mean()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=GRAD_CLIP)
optimizer.step()
total_loss += loss.item() / target_tensor.size(1)
return total_loss / len(data_loader)
##########
# Go
model = Sim()
model = model.to(DEVICE)
##########
# Cheat and set to expected value
# with torch.no_grad():
# model.predicate[:] = predicate_vec
# model.true[:] = true_vec
# model.false[:] = false_vec
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
def check(model):
''' Check how much internal vectors are aligning to known vectors. '''
p = torch.cosine_similarity(model.predicate.to(DEVICE), predicate_vec.to(DEVICE), dim=0)
t = torch.cosine_similarity(model.true.to(DEVICE), true_vec.to(DEVICE), dim=0)
f = torch.cosine_similarity(model.false.to(DEVICE), false_vec.to(DEVICE), dim=0)
print(f'SIMILARITY OF LEARNED VECS: p={p:>.3f} t={t:>.3f} f={f:>.3f}')
for epoch in range(NUM_EPOCHS):
loss = run_epoch(dataset_loader, model, optimizer)
if epoch % 10 == 0:
print(f'Epoch {epoch}, Training Loss: {loss:>.9f}')
if epoch % 100 == 0:
check(model)