API Reference: Model

This section provides details on the core model class, HierarchicalReasoningModel_ACTV1, and its configuration.

HierarchicalReasoningModel_ACTV1Config

This Pydantic model defines all architectural hyperparameters for the model.

Source: models/hrm/hrm_act_v1.py

class HierarchicalReasoningModel_ACTV1Config(BaseModel):
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int
    L_cycles: int

    H_layers: int
    L_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float

    forward_dtype: str = "bfloat16"

Parameters:

  • batch_size: The number of samples per GPU.
  • seq_len: The length of the input sequence (e.g., 900 for a 30x30 grid).
  • puzzle_emb_ndim: The dimensionality of the learnable per-puzzle embeddings; the default of 0 disables them.
  • num_puzzle_identifiers: The total number of unique puzzles in the dataset.
  • vocab_size: The size of the token vocabulary.
  • H_cycles, L_cycles: The number of recurrent updates for the high-level and low-level modules per step.
  • H_layers, L_layers: The number of Transformer blocks in each module.
  • hidden_size, expansion, num_heads: Standard Transformer dimensions: the model width, the feed-forward expansion factor, and the number of attention heads.
  • pos_encodings: Type of positional encoding to use (rope or learned).
  • rms_norm_eps: Epsilon added inside RMSNorm for numerical stability.
  • rope_theta: Base frequency for rotary position embeddings; relevant when pos_encodings is rope.
  • halt_max_steps: Maximum number of computation steps for the ACT mechanism.
  • halt_exploration_prob: Probability of taking a random number of steps during training to encourage exploration.
  • forward_dtype: The data type for forward pass computations (e.g., bfloat16).
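
For reference, a configuration dictionary satisfying this schema might look as follows. The values below are illustrative placeholders, not recommended settings; fields with defaults (rms_norm_eps, rope_theta, forward_dtype) are omitted.

config_dict = {
    "batch_size": 32,
    "seq_len": 900,                  # e.g., a flattened 30x30 grid
    "puzzle_emb_ndim": 512,
    "num_puzzle_identifiers": 1000,
    "vocab_size": 12,

    "H_cycles": 2,
    "L_cycles": 2,

    "H_layers": 4,
    "L_layers": 4,

    "hidden_size": 512,
    "expansion": 4.0,
    "num_heads": 8,
    "pos_encodings": "rope",

    "halt_max_steps": 16,
    "halt_exploration_prob": 0.1,
}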

HierarchicalReasoningModel_ACTV1

This is the main nn.Module that wraps the core reasoning logic with the Adaptive Computation Time (ACT) mechanism.

Source: models/hrm/hrm_act_v1.py

class HierarchicalReasoningModel_ACTV1(nn.Module):
    def __init__(self, config_dict: dict):
        # ...

    def initial_carry(self, batch: Dict[str, torch.Tensor]) -> HierarchicalReasoningModel_ACTV1Carry:
        # ...

    def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        # ...

Key Methods:

  • __init__(self, config_dict: dict)

    • Initializes the model. Takes a plain dictionary of hyperparameters, which is parsed into a HierarchicalReasoningModel_ACTV1Config (see the example configuration above).
  • initial_carry(self, batch)

    • Creates the initial state (or "carry") for the recurrent model. The carry holds the internal states of the H and L modules (z_H, z_L), the current step count, and the halt status for each item in the batch.
    • Arguments:
      • batch: A dictionary of input tensors, used to determine the batch size.
    • Returns: An initial HierarchicalReasoningModel_ACTV1Carry object.
  • forward(self, carry, batch)

    • Performs one step of the recurrent computation; a usage sketch follows this list.
    • Arguments:
      • carry: The HierarchicalReasoningModel_ACTV1Carry from the previous step.
      • batch: A dictionary containing the input data (inputs, puzzle_identifiers, etc.).
    • Returns: A tuple containing:
      1. new_carry: The updated carry object for the next step.
      2. outputs: A dictionary of output tensors, including:
        • logits: The raw prediction logits for the output sequence.
        • q_halt_logits: The model's learned Q-value for halting.
        • q_continue_logits: The model's learned Q-value for continuing.
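
Putting these methods together, a minimal usage sketch is shown below. It reuses the illustrative config_dict from above; the random batch tensors are placeholders, and the loop simply unrolls for halt_max_steps steps (in practice, training code would also stop early once every item in the batch has halted).

import torch

from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1

model = HierarchicalReasoningModel_ACTV1(config_dict)

# Placeholder inputs: random token ids and one puzzle identifier per sample.
batch = {
    "inputs": torch.randint(
        0, config_dict["vocab_size"],
        (config_dict["batch_size"], config_dict["seq_len"]),
    ),
    "puzzle_identifiers": torch.zeros(config_dict["batch_size"], dtype=torch.long),
}

# Build the initial carry, then unroll the recurrent ACT steps.
carry = model.initial_carry(batch)
for _ in range(config_dict["halt_max_steps"]):
    carry, outputs = model(carry, batch)

logits = outputs["logits"]              # prediction logits for the output sequence
q_halt = outputs["q_halt_logits"]       # learned Q-value for halting
q_cont = outputs["q_continue_logits"]   # learned Q-value for continuing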