Translation invariance of Rotary Embedding

RoPE (rotary position embedding), the positional encoding used in Llama, is a relative position encoding: the attention score between two tokens depends only on the distance between their positions. Therefore, shifting all position indices in the Llama model by the same constant should leave its outputs unchanged (shift invariance).

But when I add the same number (e.g. 1000) to all position indices, the model's performance is actually affected, as shown by a performance difference on a downstream task. Why? Shouldn't it be shift invariant?

Answer by tepsijash:

In theory it is shift invariant, but in practice the result depends heavily on numerical precision, because the computation involves several nonlinear operations (sines and cosines of the rotation angles) whose rounding errors do not cancel exactly.
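Why it holds in theory, as a short derivation for one pair of dimensions. The notation ($R$ for a 2×2 rotation, $\theta_i$ for the frequency of pair $i$, $d$ for the head dimension, $m$ and $n$ for the query and key positions) is introduced here for illustration; it follows the pairing and frequencies used in the NumPy code of this answer:

$$
R(\alpha) = \begin{pmatrix} \cos\alpha & -\sin\alpha \\ \sin\alpha & \cos\alpha \end{pmatrix},
\qquad \theta_i = 10000^{-2i/d}, \quad i = 0, \dots, d/2 - 1
$$

$$
\big(R(m\theta_i)\,q\big)^{\top}\big(R(n\theta_i)\,k\big)
= q^{\top} R(m\theta_i)^{\top} R(n\theta_i)\,k
= q^{\top} R\big((n-m)\theta_i\big)\,k
$$

Replacing $m, n$ by $m + c, n + c$ leaves $n - m$ unchanged. The identities used, $R(\alpha)^{\top} = R(-\alpha)$ and $R(\alpha)R(\beta) = R(\alpha + \beta)$, are exact for real numbers, but the code evaluates $\sin(m\theta_i)$ and $\cos(m\theta_i)$ separately for each absolute position and rounds them, so in floating point the cancellation of $m$ and $n$ is only approximate.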

Let's look at a standard RoPE implementation in NumPy (unbatched for simplicity):

```python
import numpy as np

sequence_length = 3

positions = np.arange(sequence_length)

# Construct some dummy query and key arrays of shape [S, H].
hidden_size = 4
q = np.arange(hidden_size)[None].repeat(sequence_length, axis=0)
k = np.arange(hidden_size)[None].repeat(sequence_length, axis=0)

# Rotation angle for position m and dimension pair i: m / 10000^(2i / H).
frequencies = 1e4 ** (np.arange(0, hidden_size, 2.0) / hidden_size)
inputs = positions[:, None] / frequencies[None, :]
sin, cos = np.sin(inputs), np.cos(inputs)

# Rotate each (first half, second half) pair of query dimensions by its angle.
q1, q2 = np.split(q, 2, axis=-1)
q_rot = np.concatenate([q1 * cos - q2 * sin, q1 * sin + q2 * cos], axis=-1)

# Same rotation for the keys.
k1, k2 = np.split(k, 2, axis=-1)
k_rot = np.concatenate([k1 * cos - k2 * sin, k1 * sin + k2 * cos], axis=-1)

# Attention logits: entry [t, T] is the dot product of query t and key T.
np.einsum('td,Td->tT', q_rot, k_rot)
```

Output:

```
array([[14.        , 12.16070923,  8.33341272],
       [12.16070923, 14.        , 12.16070923],
       [ 8.33341272, 12.16070923, 14.        ]])
```

The above code runs in double precision (NumPy's default), and shifting the positions (e.g. `positions = 100 + np.arange(sequence_length)`) gives the same result to the printed precision.

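After reducing the precision to float32, the differences are still negligible (within 1e-7), but in float16 they become relatively large (~1e-2). Larger shifts tend to make this worse, because the rotation angles are computed from absolute positions, and the absolute rounding error of a floating-point number grows with its magnitude.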
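A minimal sketch of that comparison, using the same toy queries, keys and frequencies as the example above and a shift of 100. `rope_scores` and `shift_error` are hypothetical helper names made up for this illustration. One design choice matters: in the snippet above `q` and `k` are integer arrays, and NumPy promotes int64 × float32 to float64, so casting only `sin`/`cos` would not really run the arithmetic in float32. Here every array is created directly in the target dtype instead. The exact differences printed may vary with NumPy version and platform, so treat the magnitudes quoted above as indicative.

```python
import numpy as np


def rope_scores(positions, dtype):
    """Toy RoPE attention logits with every step computed in `dtype` (hypothetical helper)."""
    hidden_size = 4
    # Create q/k directly in `dtype` so int * float promotion cannot upcast to float64.
    q = np.arange(hidden_size, dtype=dtype)[None].repeat(len(positions), axis=0)
    k = q.copy()

    frequencies = (1e4 ** (np.arange(0, hidden_size, 2.0) / hidden_size)).astype(dtype)
    inputs = positions.astype(dtype)[:, None] / frequencies[None, :]
    sin, cos = np.sin(inputs), np.cos(inputs)

    q1, q2 = np.split(q, 2, axis=-1)
    q_rot = np.concatenate([q1 * cos - q2 * sin, q1 * sin + q2 * cos], axis=-1)
    k1, k2 = np.split(k, 2, axis=-1)
    k_rot = np.concatenate([k1 * cos - k2 * sin, k1 * sin + k2 * cos], axis=-1)
    # Upcast only the final logits so the comparison itself is done in float64.
    return np.einsum('td,Td->tT', q_rot, k_rot).astype(np.float64)


def shift_error(shift, dtype, sequence_length=3):
    """Largest change in the logits when every position is shifted by `shift`."""
    positions = np.arange(sequence_length)
    base = rope_scores(positions, dtype)
    shifted = rope_scores(positions + shift, dtype)
    return np.abs(base - shifted).max()


for dtype in (np.float64, np.float32, np.float16):
    print(dtype.__name__, shift_error(100, dtype))
```

With these toy vectors the exact logit for a distance Δ is 4·cos(Δ) + 10·cos(Δ/100), which is where the 12.16070923 and 8.33341272 in the output above come from.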

Since large models such as LLaMA typically run in mixed precision (float16 or bfloat16, the latter with even fewer mantissa bits), this would explain your observation.
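Low precision hurts more at large positions because the gap between adjacent representable floats grows with magnitude. The sketch below prints the float16 gap at a few magnitudes, then rounds a rotation angle to float16 at position 1 and at position 1001 (a shift of 1000, as in the question). The frequency is an assumption chosen to be Llama-like: `theta` is taken as 10000^(-2/128), i.e. the second dimension pair for a head dimension of 128 with base 10000. bfloat16 is not available in plain NumPy, so it is not shown.

```python
import numpy as np

# Gap between adjacent float16 values near x; rounding to float16 can move
# a number near x by up to half of this gap.
for x in (1.0, 10.0, 100.0, 1000.0):
    print(f"float16 spacing near {x:g}: {float(np.spacing(np.float16(x)))}")

# Assumed Llama-like frequency: base 10000, head_dim 128, dimension pair i = 1.
theta = 10000.0 ** (-2 * 1 / 128)
for position in (1, 1001):
    angle = position * theta            # float64 reference angle
    rounded = float(np.float16(angle))  # same angle stored in float16
    print(f"position {position}: angle {angle:.4f} rad, "
          f"float16 error {abs(rounded - angle):.2e} rad")
```

Near 1000 the float16 gap is 0.5, so an angle around 867 rad can be off by up to 0.25 rad, while near 1 the gap is only 2^-10.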