# Model Architecture: A Deep Dive
The Hierarchical Reasoning Model (HRM) is a novel recurrent architecture designed for complex, sequential reasoning. Its design is inspired by the hierarchical and multi-timescale processing observed in the human brain.
## Overview
At its core, HRM consists of two interconnected recurrent modules that operate at different conceptual and temporal scales:

- High-Level Module (`H_level`): This module acts as a planner. It operates more slowly, processing information to form abstract, high-level plans or sub-goals.
- Low-Level Module (`L_level`): This module acts as an executor. It operates more quickly, performing detailed, fine-grained computations based on the current plan from the `H_level` and the raw input.

The entire model is implemented in `models/hrm/hrm_act_v1.py`.
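Schematically, one update of each module can be written as follows, where $x$ is the input embedding and $f_L$, $f_H$ are the two module update functions (this notation is ours, introduced only for illustration):

$$
z_L \leftarrow f_L(z_L,\; z_H,\; x), \qquad z_H \leftarrow f_H(z_H,\; z_L)
$$

The low-level update runs many times for each high-level update, which is what gives the two modules their different effective timescales.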
## Key Architectural Components
### 1. Hierarchical Recurrence
The interaction between the two modules forms a nested loop. In each forward step of the model, the `H_level` and `L_level` update their internal states (`z_H` and `z_L`, respectively) over multiple cycles:

- The `H_level` state `z_H` is passed to the `L_level` to guide its computation.
- The `L_level` state `z_L` is passed back to the `H_level` to update the plan based on the results of the detailed computation.

This process is repeated for a configurable number of cycles (`H_cycles` and `L_cycles`), allowing the model to achieve significant computational depth within a single logical step.
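The control flow can be summarized with a short sketch. The function name and the `(state, context) -> state` module signature below are our simplification, not the exact interface in `models/hrm/hrm_act_v1.py`:

```python
import torch
import torch.nn as nn


def hierarchical_step(
    H_level: nn.Module,
    L_level: nn.Module,
    z_H: torch.Tensor,
    z_L: torch.Tensor,
    x: torch.Tensor,
    H_cycles: int,
    L_cycles: int,
):
    """One logical forward step of the hierarchical recurrence (sketch only).

    `H_level` and `L_level` stand in for the two recurrent module stacks;
    each is assumed to map (state, context) -> new state.
    """
    for _ in range(H_cycles):
        # Fast loop: the executor refines z_L under the current plan z_H,
        # with the input embeddings x injected at every step.
        for _ in range(L_cycles):
            z_L = L_level(z_L, z_H + x)
        # Slow loop: the planner updates its plan z_H from the executor's result.
        z_H = H_level(z_H, z_L)
    return z_H, z_L
```

Note how the inner loop runs `L_cycles` times for every single `H_level` update: the total effective depth of one logical step is `H_cycles * L_cycles` module applications.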
### 2. Adaptive Computation Time (ACT)
HRM uses Adaptive Computation Time (ACT) to dynamically determine how many recurrent steps are needed to solve a given problem. Instead of running for a fixed number of iterations, the model learns a halting policy.
- Q-Heads: A dedicated linear layer (`q_head`) predicts two values from the `H_level`'s state: `q_halt_logits` and `q_continue_logits`.
- Halting Condition: The model halts if `q_halt_logits` is greater than `q_continue_logits`, or if it reaches a maximum number of steps (`halt_max_steps`).
- Training: The halting policy is trained using Q-learning. The `ACTLossHead` in `models/losses.py` computes a loss that encourages the model to halt if its prediction is correct and to continue otherwise. This allows the model to allocate more computation to harder problems.
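A minimal sketch of the halting decision follows. The two-logit `q_head` and the comparison are as described above; the tensor shapes, the illustrative `hidden_size`, and reading the Q-values from the first state position are our assumptions:

```python
import torch
import torch.nn as nn

hidden_size = 512                      # illustrative width, not the repo's config
q_head = nn.Linear(hidden_size, 2)    # stand-in for the model's q_head


def should_halt(z_H: torch.Tensor, step: int, halt_max_steps: int) -> torch.Tensor:
    """Per-example boolean halting decision (sketch, not the repo's exact code).

    z_H: (batch, seq_len, hidden_size) high-level state.
    """
    # Read the two Q-values from the high-level state (here, its first position).
    q_halt_logits, q_continue_logits = q_head(z_H[:, 0]).unbind(-1)
    # Halt when Q(halt) exceeds Q(continue), or when the step budget runs out.
    return (q_halt_logits > q_continue_logits) | (step >= halt_max_steps)
```

During training, Q-learning supervises these logits so that halting correlates with the model's answer already being correct.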
### 3. Per-Puzzle Embeddings
To allow for task-specific adaptation, HRM can learn a unique embedding vector for each puzzle.
- `CastedSparseEmbedding`: This custom layer, defined in `models/sparse_embedding.py`, stores an embedding for every unique puzzle identifier in the dataset.
- Input Injection: During the forward pass, the embedding for the current puzzle is retrieved and prepended to the sequence of input token embeddings. This provides the model with a global context about the specific puzzle it is solving, allowing it to activate a specialized reasoning strategy.
- Custom Optimizer: These sparse puzzle embeddings are trained with a custom `CastedSparseEmbeddingSignSGD_Distributed` optimizer, which is efficient for updating highly sparse parameters.
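The input-injection step amounts to a single concatenation. In this sketch, a plain `nn.Embedding` stands in for `CastedSparseEmbedding`, and the sizes are placeholders:

```python
import torch
import torch.nn as nn

num_puzzles, vocab_size, hidden_size = 1000, 32, 512   # illustrative sizes
puzzle_emb = nn.Embedding(num_puzzles, hidden_size)    # one vector per puzzle id
token_emb = nn.Embedding(vocab_size, hidden_size)      # puzzle token vocabulary


def embed_inputs(tokens: torch.Tensor, puzzle_ids: torch.Tensor) -> torch.Tensor:
    """Prepend the per-puzzle embedding to the token embeddings.

    tokens:     (batch, seq_len) integer token ids
    puzzle_ids: (batch,) integer puzzle identifiers
    returns:    (batch, seq_len + 1, hidden_size)
    """
    x = token_emb(tokens)                    # (B, L, H) token embeddings
    p = puzzle_emb(puzzle_ids).unsqueeze(1)  # (B, 1, H) per-puzzle context vector
    return torch.cat([p, x], dim=1)          # puzzle context comes first
```

Because each training example touches only one row of the puzzle embedding table, the gradient is extremely sparse, which is why a specialized sign-SGD optimizer is used rather than a dense update.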
### 4. Transformer Blocks
Both the `H_level` and `L_level` modules are composed of a stack of standard Transformer decoder-style blocks. Each block, defined in `models/hrm/hrm_act_v1.py` as `HierarchicalReasoningModel_ACTV1Block`, consists of:

- Self-Attention: Using an efficient implementation from FlashAttention (`models/layers.py`).
- SwiGLU Feed-Forward Network: A modern variant of the FFN.
- RMSNorm: For layer normalization.
- Rotary Position Embeddings (RoPE): To inject positional information.
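A rough sketch of how these pieces compose is below. Standard PyTorch modules stand in for the repo's FlashAttention kernels, RoPE is omitted for brevity, and `nn.RMSNorm` assumes PyTorch 2.4+; see `HierarchicalReasoningModel_ACTV1Block` for the actual layer ordering:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLU(nn.Module):
    """SwiGLU feed-forward network: a gated variant of the standard FFN."""

    def __init__(self, dim: int, ffn_dim: int):
        super().__init__()
        self.gate_up = nn.Linear(dim, 2 * ffn_dim, bias=False)
        self.down = nn.Linear(ffn_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up(x).chunk(2, dim=-1)
        return self.down(F.silu(gate) * up)


class Block(nn.Module):
    """Simplified Transformer block: self-attention + SwiGLU with RMSNorm."""

    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        # nn.MultiheadAttention stands in for the FlashAttention implementation.
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.mlp = SwiGLU(dim, 4 * dim)
        self.norm1 = nn.RMSNorm(dim)
        self.norm2 = nn.RMSNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connections around each sub-layer, normalized afterwards.
        x = self.norm1(x + self.attn(x, x, x, need_weights=False)[0])
        x = self.norm2(x + self.mlp(x))
        return x
```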
This modular design combines the power of established Transformer components with the novel hierarchical recurrent structure.