TTT-E2E
Context: TTT and Fast Weights
Early work by Mikolov et al. explored updating a recurrent language model's weights on recent context at test time. The idea of "fast" weights, or fast weight programmers (FWPs), updated on relevant data dates back even earlier. Prior versions of Test-Time Training focused on key-value binding, compressing context into weights by reconstructing values from keys.
TTT-E2E: End-to-End Test-Time Training for Long Context
TTT-E2E removes these layer-wise reconstruction losses in favor of a single end-to-end objective.
Instead of designing complex new layers, the authors utilize a standard Transformer with Sliding-Window Attention (SWA) and treat the long context as a training set for the model's weights.
End-to-End Objective: The model performs next-token prediction on its own context at test time, effectively "compressing" the context into its MLP weights.
Meta-Learning (Outer Loop): To ensure the model is "ready" to learn at test time, the initial weights are optimized via meta-learning at training time.
Hybrid Efficiency: By combining SWA for short-term memory and TTT-updated weights for long-term memory, the model achieves computational complexity linear in context length.
Implementation Details
Only MLP layers are updated. Embedding and attention layers are frozen.
Only the last 1/4 of the blocks are updated to balance compute cost and storage capacity. There are two MLPs per block: one static MLP to preserve pre-trained knowledge and one dynamic MLP for TTT.
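The parameter-selection rule above can be sketched as follows. The naming scheme (`dynamic_mlp`, `static_mlp`, etc.) is hypothetical and not from the paper; only the "last 1/4 of blocks, dynamic MLP only" rule is from the text:

```python
# Sketch: decide which parameters TTT may update at test time.
# Assumption: blocks are indexed 0..n_blocks-1 and the TTT-updated MLP's
# parameter names contain "dynamic_mlp" (hypothetical naming).

def ttt_trainable(param_name: str, block_idx: int, n_blocks: int) -> bool:
    """Only the dynamic MLP in the last 1/4 of blocks is updated;
    the static MLP, attention, and embeddings stay frozen."""
    in_last_quarter = block_idx >= n_blocks - n_blocks // 4
    return in_last_quarter and "dynamic_mlp" in param_name

n_blocks = 8
names = ["attn.qkv", "static_mlp.w_in", "dynamic_mlp.w_in"]
trainable = [(b, n) for b in range(n_blocks) for n in names
             if ttt_trainable(n, b, n_blocks)]
# With 8 blocks, only the dynamic MLPs of blocks 6 and 7 are trainable.
```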
Mathematics: The Test-Time Objective
The process is divided into an Inner Loop (test-time) and an Outer Loop (training-time meta-learning).
Inner Loop (Test-Time Update): For each token step t, the next-token prediction loss ℓ_t is computed using the current weights W_{t−1}:

ℓ_t(W) = CE(f(x_{1:t−1}; W), x_t)

The weights are then updated via gradient descent:

W_t = W_{t−1} − η ∇ℓ_t(W_{t−1})

However, to improve parallelism and stability, gradient descent is actually performed in mini-batches. Given a context sequence x_1, …, x_T, the sequence is partitioned into T/b batches. For each batch i = 1, …, T/b, the weights W are updated as follows:

W_i = W_{i−1} − (η/b) · Σ_{t=(i−1)b+1}^{ib} ∇ℓ_t(W_{i−1})

The update uses the average gradient of the next-token prediction loss ℓ_t over all tokens in the current mini-batch. Crucially for performance, all gradients within a mini-batch are computed using the same weights W_{i−1} from the previous batch, rather than updating them token-by-token.
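The mini-batch inner loop can be sketched in a toy setting. This is a hedged illustration, not the paper's implementation: a linear map with squared-error loss stands in for the Transformer and its cross-entropy loss, and the data is random:

```python
import numpy as np

rng = np.random.default_rng(0)
d, T, b, eta = 4, 16, 4, 0.1
x = rng.normal(size=(T + 1, d))   # stand-in token embeddings x_1..x_{T+1}
W = np.zeros((d, d))              # fast weights W_0 (meta-learned in the paper)

def grad_loss(W, x_prev, x_next):
    # Gradient of the toy loss 0.5 * ||W @ x_prev - x_next||^2,
    # standing in for ∇ℓ_t of the cross-entropy loss.
    return np.outer(W @ x_prev - x_next, x_prev)

# Mini-batch TTT: every gradient in batch i is taken at the same W_{i-1},
# so the b gradients can be computed in parallel before one averaged step.
for i in range(T // b):
    g = sum(grad_loss(W, x[t], x[t + 1]) for t in range(i * b, (i + 1) * b))
    W = W - (eta / b) * g
```

The key structural point is that `W` changes only between batches, never inside one, which is what makes the b gradient computations parallelizable.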
Interaction with SWA: Because the weights are static while a mini-batch is processed, the model would behave like a bigram model for tokens inside that batch if it relied solely on the MLP for memory. The architecture therefore includes an SWA layer with window size k, where k ≥ b. This ensures the model can use its short-term (attention) memory to "remember" the context within the current mini-batch before the TTT update has a chance to compress that information into the weights.

Outer Loop (Meta-Learning): During training, the initialization
W_0 is optimized to minimize the total loss accumulated over the inner-loop updates:

L(W_0; X) = (1/T) · Σ_{t=1}^{T} ℓ_t(W_{t−1})

With mini-batches, the loss becomes:

L(W_0; X) = (1/T) · Σ_{i=1}^{T/b} Σ_{t=(i−1)b+1}^{ib} ℓ_t(W_{i−1})
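The outer-loop objective can be evaluated by replaying the inner loop and scoring each token with the weights in effect before its batch's update. A toy sketch under the same simplifying assumptions as above (linear model, squared-error loss standing in for CE); in the paper the outer loop differentiates through these updates with standard autodiff, whereas this sketch only computes the objective:

```python
import numpy as np

def loss(W, x_prev, x_next):
    return 0.5 * np.sum((W @ x_prev - x_next) ** 2)  # toy stand-in for CE

def grad(W, x_prev, x_next):
    return np.outer(W @ x_prev - x_next, x_prev)

def meta_loss(W0, x, b=4, eta=0.1):
    """L(W_0; X): each token's loss is evaluated at W_{i-1}, the weights
    in effect *before* the current mini-batch's inner-loop update."""
    T = len(x) - 1
    W, total = W0.copy(), 0.0
    for i in range(T // b):
        batch = range(i * b, (i + 1) * b)
        total += sum(loss(W, x[t], x[t + 1]) for t in batch)  # uses W_{i-1}
        g = sum(grad(W, x[t], x[t + 1]) for t in batch)
        W = W - (eta / b) * g                                 # -> W_i
    return total / T
```

In this toy, if the sequence really is generated by a fixed linear map A (x_{t+1} = A·x_t), then meta_loss(A, x) = 0 while a poorly chosen W_0 pays a positive loss; minimizing this quantity over W_0 is exactly what the outer loop does.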
Impact: Performance and Scaling
The paper focuses on scaling properties:
Context Scaling: While RNNs like Mamba2 degrade in performance as context grows, TTT-E2E maintains a constant advantage over full attention, effectively scaling in the same manner as the full Transformer.
Constant Latency: Unlike full attention, which has O(T) per-token decoding complexity, TTT-E2E has O(1) decoding latency. At 128K context, it is 2.7× faster than full attention.
Compression Trade-off: While TTT-E2E excels at language modeling, it underperforms full attention on Needle-in-a-Haystack tasks, as compression naturally discards details that full attention caches explicitly.
Discussions
Differences from the Titans/MIRAS Papers
In the Titans architecture and the broader MIRAS framework, learning at test time is driven by "surprisal," an information-theoretic signal that determines whether something is worth committing to memory. This treats the model's parameters as a dynamic storage system.
Example: Suppose the model has seen an entire encyclopedia during pre-training and has learned to memorize it in its weights (persistent memory). If we then show the same encyclopedia to the model at prefill, the gradient is very small and the memory receives almost no update.
Inner-Loop Objective:
Titans uses associative memory: it commits information to memory by optimizing an associative mapping of keys k_t to values v_t. The loss ℓ at time t measures how well the current memory state M_{t−1} can reconstruct the value from the key:

ℓ(M_{t−1}; x_t) = ‖M_{t−1}(k_t) − v_t‖₂²

TTT-E2E uses next-token prediction: it updates weights by directly optimizing the main language modeling task (cross-entropy, CE) on the current context:

ℓ_t(W) = CE(f(x_{1:t−1}; W), x_t)
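The two inner-loop losses can be written side by side. This is a toy illustration under stated assumptions: the real models derive keys, values, and logits from learned projections, and the memory here is just a linear map:

```python
import numpy as np

def titans_loss(M, k, v):
    # Associative reconstruction: how well memory M maps key k_t to value v_t.
    return float(np.sum((M @ k - v) ** 2))

def ttt_e2e_loss(logits, next_token):
    # Cross-entropy of the model's next-token distribution.
    z = logits - logits.max()            # numerically stable log-softmax
    log_probs = z - np.log(np.exp(z).sum())
    return float(-log_probs[next_token])

# A memory that already binds k -> v pays zero Titans loss...
k = np.array([1.0, 0.0])
assert titans_loss(np.eye(2), k, k) == 0.0
# ...while a uniform next-token distribution over V tokens pays log(V).
assert abs(ttt_e2e_loss(np.zeros(8), 3) - np.log(8)) < 1e-9
```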
Update Dynamics:
Titans refines the update using momentum S_t to capture context flow and adaptive weight decay α_t to manage capacity. This allows the model to consider "past surprise" alongside "momentary surprise":

M_t = (1 − α_t) · M_{t−1} + S_t
S_t = η_t · S_{t−1} − θ_t · ∇ℓ(M_{t−1}; x_t)

TTT-E2E uses standard gradient descent in mini-batches of size b to improve hardware utilization (see the Mathematics section).
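A toy numeric sketch of the Titans-style update. Assumptions: η_t, θ_t, α_t are held constant (the real model adapts them per token), keys and values are random, and the gradient uses ½‖·‖² for simplicity:

```python
import numpy as np

rng = np.random.default_rng(1)
d = 3
M = np.zeros((d, d))   # memory state M_t (a linear key -> value map)
S = np.zeros((d, d))   # momentum S_t: accumulated "past surprise"
eta, theta, alpha = 0.8, 0.05, 0.01  # momentum decay, step size, weight decay

for _ in range(100):
    k, v = rng.normal(size=d), rng.normal(size=d)
    surprise = np.outer(M @ k - v, k)  # ∇ℓ(M_{t-1}; x_t), "momentary surprise"
    S = eta * S - theta * surprise     # S_t = η_t S_{t-1} − θ_t ∇ℓ
    M = (1 - alpha) * M + S            # M_t = (1 − α_t) M_{t-1} + S_t
```

Compared with the plain averaged-gradient step in TTT-E2E, the momentum term lets one surprising token keep influencing updates for several steps, and the decay (1 − α_t) gradually frees capacity.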
Analogy:
Titans is closer to a "noise-cancelling" memory that filters out what it already knows, whereas TTT-E2E is like an "endless learner" that keeps training on everything it reads.
Test-Time Training on Decoding
What about test-time training on self-generated tokens (which may not be the default for this method)?
How much effort is put into the meta-learning to make it good for both prefill and decoding? During prefill, the gradient descent / cross-entropy loss is based on x_{t+1}, which is "generated from nature"; during decoding, that token would be self-generated.