# Model Architecture: A Deep Dive
The Hierarchical Reasoning Model (HRM) is a novel recurrent architecture designed for complex, sequential reasoning. Its design is inspired by the hierarchical and multi-timescale processing observed in the human brain.
## Overview
At its core, HRM consists of two interconnected recurrent modules that operate at different conceptual and temporal scales:
- **High-Level Module (`H_level`)**: This module acts as a planner. It operates more slowly, processing information to form abstract, high-level plans or sub-goals.
- **Low-Level Module (`L_level`)**: This module acts as an executor. It operates more quickly, performing detailed, fine-grained computations based on the current plan from the `H_level` and the raw input.
The entire model is implemented in `models/hrm/hrm_act_v1.py`.
## Key Architectural Components
### 1. Hierarchical Recurrence
The interaction between the two modules forms a nested loop. In each forward step of the model, the `H_level` and `L_level` update their internal states (`z_H` and `z_L`, respectively) over multiple cycles.
- The `H_level` state `z_H` is passed to the `L_level` to guide its computation.
- The `L_level` state `z_L` is passed back to the `H_level` to update the plan based on the results of the detailed computation.
This process is repeated for a configurable number of cycles (`H_cycles` and `L_cycles`), allowing the model to achieve significant computational depth within a single logical step.
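The nested loop can be sketched as follows. This is a minimal illustration, not the repository's implementation: `HRMSketch` is a hypothetical name, and `nn.GRUCell` stands in for the actual `H_level` and `L_level` Transformer stacks.

```python
import torch
import torch.nn as nn

class HRMSketch(nn.Module):
    """Sketch of HRM's nested recurrence (hypothetical module; GRU cells
    stand in for the real H_level / L_level Transformer stacks)."""
    def __init__(self, hidden: int, H_cycles: int = 2, L_cycles: int = 2):
        super().__init__()
        self.H_cycles, self.L_cycles = H_cycles, L_cycles
        self.H_level = nn.GRUCell(hidden, hidden)
        self.L_level = nn.GRUCell(hidden, hidden)

    def forward(self, x, z_H, z_L):
        for _ in range(self.H_cycles):
            for _ in range(self.L_cycles):
                # The L_level refines its state under the current plan z_H.
                z_L = self.L_level(x + z_H, z_L)
            # The H_level updates the plan from the L_level's result.
            z_H = self.H_level(z_L, z_H)
        return z_H, z_L

model = HRMSketch(hidden=16)
x = torch.randn(4, 16)
z_H = torch.zeros(4, 16)
z_L = torch.zeros(4, 16)
z_H, z_L = model(x, z_H, z_L)
```

Note how a single forward call performs `H_cycles * L_cycles` low-level updates, which is where the extra computational depth comes from.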
### 2. Adaptive Computation Time (ACT)
HRM uses Adaptive Computation Time (ACT) to dynamically determine how many recurrent steps are needed to solve a given problem. Instead of running for a fixed number of iterations, the model learns a halting policy.
- **Q-Heads**: A dedicated linear layer (`q_head`) predicts two values from the `H_level`'s state: `q_halt_logits` and `q_continue_logits`.
- **Halting Condition**: The model halts if `q_halt_logits` is greater than `q_continue_logits`, or if it reaches a maximum number of steps (`halt_max_steps`).
- **Training**: The halting policy is trained using Q-learning. The `ACTLossHead` in `models/losses.py` computes a loss that encourages the model to halt if its prediction is correct and continue otherwise. This allows the model to allocate more computation to harder problems.
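The halting rule above can be sketched in a few lines. This is an illustrative loop under assumed shapes, not the repository's code: `run_with_act` is a hypothetical helper, and the step function is a placeholder for one full H/L recurrence.

```python
import torch
import torch.nn as nn

hidden = 16
# Hypothetical q_head: a single linear layer emitting (halt, continue) logits.
q_head = nn.Linear(hidden, 2)

def run_with_act(step_fn, z_H, halt_max_steps=8):
    """Iterate step_fn until q_halt_logits > q_continue_logits
    or halt_max_steps is reached (the ACT rule described above)."""
    for step in range(1, halt_max_steps + 1):
        z_H = step_fn(z_H)
        q_halt_logits, q_continue_logits = q_head(z_H).unbind(-1)
        if (q_halt_logits > q_continue_logits).all() or step == halt_max_steps:
            return z_H, step

z_H, steps = run_with_act(lambda z: torch.tanh(z + 0.1), torch.zeros(hidden))
```

During training, the Q-learning loss shapes `q_head` so that `steps` grows with problem difficulty rather than staying fixed.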
### 3. Per-Puzzle Embeddings
To allow for task-specific adaptation, HRM can learn a unique embedding vector for each puzzle.
- **CastedSparseEmbedding**: This custom layer, defined in `models/sparse_embedding.py`, stores an embedding for every unique puzzle identifier in the dataset.
- **Input Injection**: During the forward pass, the embedding for the current puzzle is retrieved and prepended to the sequence of input token embeddings. This provides the model with global context about the specific puzzle it is solving, allowing it to activate a specialized reasoning strategy.
- **Custom Optimizer**: These sparse puzzle embeddings are trained with a custom `CastedSparseEmbeddingSignSGD_Distributed` optimizer, which is efficient for updating highly sparse parameters.
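The input-injection step amounts to a simple concatenation. In this sketch, a plain `nn.Embedding` stands in for `CastedSparseEmbedding`, and all shapes are assumed for illustration.

```python
import torch
import torch.nn as nn

# Assumed sizes for illustration only.
num_puzzles, seq_len, hidden = 100, 10, 16
puzzle_emb = nn.Embedding(num_puzzles, hidden)   # stand-in for CastedSparseEmbedding
token_emb = torch.randn(2, seq_len, hidden)      # a batch of token embeddings

puzzle_ids = torch.tensor([3, 42])
# Look up each puzzle's vector and prepend it as an extra leading position.
prefix = puzzle_emb(puzzle_ids).unsqueeze(1)     # (batch, 1, hidden)
inputs = torch.cat([prefix, token_emb], dim=1)   # (batch, seq_len + 1, hidden)
```

The sequence grows by one position, and every layer downstream can attend to the puzzle-specific vector.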
### 4. Transformer Blocks
Both the `H_level` and `L_level` modules are composed of a stack of standard Transformer decoder-style blocks. Each block, defined in `models/hrm/hrm_act_v1.py` as `HierarchicalReasoningModel_ACTV1Block`, consists of:
- **Self-Attention**: Using an efficient implementation from FlashAttention (`models/layers.py`).
- **SwiGLU Feed-Forward Network**: A modern gated variant of the feed-forward network.
- **RMSNorm**: For layer normalization.
- **Rotary Position Embeddings (RoPE)**: To inject positional information.
This modular design combines the power of established Transformer components with the novel hierarchical recurrent structure.