I trained an EfficientNet-B0 model with two added multi-head attention layers. When I train the model, I get the following warning:
Epoch: 1 | train_loss: 2.0100 | train_acc: 0.2708 | validation_loss: 1.7110 | validation_acc: 0.4147
/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py:1084: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Tensor-likes are not close!
Mismatched elements: 3 / 320 (0.9%)
Greatest absolute difference: 0.00018024444580078125 at index (21, 2) (up to 1e-05 allowed)
Greatest relative difference: 7.841616138727894e-05 at index (20, 9) (up to 1e-05 allowed)
_check_trace(
This is my attention layer. How do I fix this warning and improve the training accuracy?
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    A 4-D input of shape ``(B, C, H, W)`` is treated as a sequence of
    ``H * W`` tokens, where each token is the ``C``-dimensional channel vector
    at one spatial position. A 3-D input of shape ``(B, T, C)`` is used as a
    token sequence directly. The output always has the same shape as the
    input. No residual connection or normalization is applied here.

    Args:
        input_dim: Channel / embedding dimension ``C``; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Attention-weight dropout probability, applied only while the
            module is in training mode.
        is_causal: Apply a causal (lower-triangular) mask. Spatial tokens have
            no natural order, so this defaults to ``False``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``input_dim``.
    """

    def __init__(self, input_dim, num_heads, dropout=0.0, is_causal=False):
        super().__init__()
        if num_heads <= 0 or input_dim % num_heads != 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads
        self.dropout = dropout
        self.is_causal = is_causal
        # Query, key and value projections for all heads in one matmul.
        self.c_attn = nn.Linear(input_dim, 3 * input_dim, bias=False)
        # Output projection.
        self.c_proj = nn.Linear(input_dim, input_dim, bias=False)

    def forward(self, x):
        """Apply self-attention and return a tensor shaped like ``x``.

        Args:
            x: ``(B, C, H, W)`` feature map or ``(B, T, C)`` token sequence,
                with ``C == input_dim``.

        Raises:
            ValueError: If ``x`` is neither 3-D nor 4-D.

        NOTE(review): small output mismatches reported by
        ``torch.jit.trace``'s trace check (e.g. when a graph is logged via a
        tracer) are presumably floating-point nondeterminism of the fused
        attention kernels rather than a defect here; confirm, and pass
        ``check_trace=False`` to the tracer if the warning is unwanted.
        """
        is_map = x.dim() == 4
        if is_map:
            B, C, H, W = x.shape
            # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
            # A plain reshape to (B, -1, C) would only reinterpret memory and
            # mix channel and spatial values inside each token.
            tokens = x.flatten(2).transpose(1, 2)
        elif x.dim() == 3:
            tokens = x
        else:
            raise ValueError(
                f"expected a 3-D (B, T, C) or 4-D (B, C, H, W) input, "
                f"got shape {tuple(x.shape)}"
            )
        B, T, C = tokens.shape

        # Project to q, k, v and split heads: (B, T, C) -> (B, nh, T, hs).
        q, k, v = self.c_attn(tokens).split(self.input_dim, dim=2)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            # Dropout must be disabled at evaluation time.
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=self.is_causal,
        )
        # Re-assemble head outputs side by side: (B, nh, T, hs) -> (B, T, C).
        y = y.transpose(1, 2).reshape(B, T, C)
        y = self.c_proj(y)

        if is_map:
            # Undo the flatten/transpose above: (B, H*W, C) -> (B, C, H, W).
            y = y.transpose(1, 2).reshape(B, C, H, W)
        return y