# API Reference: Model

This section provides details on the core model class, `HierarchicalReasoningModel_ACTV1`, and its configuration.
## HierarchicalReasoningModel_ACTV1Config

This Pydantic model defines all architectural hyperparameters for the model.

**Source:** `models/hrm/hrm_act_v1.py`
```python
class HierarchicalReasoningModel_ACTV1Config(BaseModel):
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int
    L_cycles: int

    H_layers: int
    L_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float

    forward_dtype: str = "bfloat16"
```
**Parameters:**

- `batch_size`: The number of samples per GPU.
- `seq_len`: The length of the input sequence (e.g., 900 for a 30x30 grid).
- `puzzle_emb_ndim`: The dimensionality of the learnable per-puzzle embeddings.
- `num_puzzle_identifiers`: The total number of unique puzzles in the dataset.
- `vocab_size`: The size of the token vocabulary.
- `H_cycles`, `L_cycles`: The number of recurrent updates for the high-level and low-level modules per step.
- `H_layers`, `L_layers`: The number of Transformer blocks in each module.
- `hidden_size`, `expansion`, `num_heads`: Standard Transformer layer dimensions.
- `pos_encodings`: The type of positional encoding to use (`rope` or `learned`).
- `halt_max_steps`: The maximum number of computation steps for the ACT mechanism.
- `halt_exploration_prob`: The probability of taking a random number of steps during training to encourage exploration.
- `forward_dtype`: The data type for forward-pass computations (e.g., `bfloat16`).
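As a quick illustration, the snippet below constructs a configuration directly with the Pydantic model. The field values are placeholders chosen for this sketch, not recommended settings.

```python
from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1Config

# Illustrative placeholder values -- not recommended settings.
config = HierarchicalReasoningModel_ACTV1Config(
    batch_size=32,
    seq_len=900,                  # e.g., a flattened 30x30 grid
    puzzle_emb_ndim=512,
    num_puzzle_identifiers=1000,
    vocab_size=12,
    H_cycles=2,
    L_cycles=2,
    H_layers=4,
    L_layers=4,
    hidden_size=512,
    expansion=4.0,
    num_heads=8,
    pos_encodings="rope",
    halt_max_steps=16,
    halt_exploration_prob=0.1,
)

print(config.rms_norm_eps)   # 1e-5 (default)
print(config.forward_dtype)  # "bfloat16" (default)
```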
## HierarchicalReasoningModel_ACTV1

This is the main `nn.Module` that wraps the core reasoning logic with the Adaptive Computation Time (ACT) mechanism.

**Source:** `models/hrm/hrm_act_v1.py`
```python
class HierarchicalReasoningModel_ACTV1(nn.Module):
    def __init__(self, config_dict: dict):
        # ...

    def initial_carry(self, batch: Dict[str, torch.Tensor]) -> HierarchicalReasoningModel_ACTV1Carry:
        # ...

    def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        # ...
```
**Key Methods:**

- `__init__(self, config_dict: dict)`
  - Initializes the model. Takes a dictionary that is parsed into a `HierarchicalReasoningModel_ACTV1Config`.

- `initial_carry(self, batch)`
  - Creates the initial state (or "carry") for the recurrent model. The carry holds the internal states of the H and L modules (`z_H`, `z_L`), the current step count, and the halt status for each item in the batch.
  - Arguments:
    - `batch`: A dictionary of input tensors, used to determine the batch size.
  - Returns: An initial `HierarchicalReasoningModel_ACTV1Carry` object.

- `forward(self, carry, batch)`
  - Performs one step of the recurrent computation; see the usage sketch after this list.
  - Arguments:
    - `carry`: The `HierarchicalReasoningModel_ACTV1Carry` from the previous step.
    - `batch`: A dictionary containing the input data (`inputs`, `puzzle_identifiers`, etc.).
  - Returns: A tuple containing:
    - `new_carry`: The updated carry object for the next step.
    - `outputs`: A dictionary of output tensors, including:
      - `logits`: The raw prediction logits for the output sequence.
      - `q_halt_logits`: The model's learned Q-value for halting.
      - `q_continue_logits`: The model's learned Q-value for continuing.
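The sketch below shows one way these pieces might fit together in an ACT rollout: build a carry with `initial_carry`, then call `forward` repeatedly for up to `halt_max_steps` steps. The configuration values, dummy batch, and loop structure are assumptions for illustration and may differ from the actual training loop.

```python
import torch
from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1

# A plain dict with the fields described above (illustrative values only).
config_dict = dict(
    batch_size=32, seq_len=900, puzzle_emb_ndim=512, num_puzzle_identifiers=1000,
    vocab_size=12, H_cycles=2, L_cycles=2, H_layers=4, L_layers=4,
    hidden_size=512, expansion=4.0, num_heads=8, pos_encodings="rope",
    halt_max_steps=16, halt_exploration_prob=0.1,
)

model = HierarchicalReasoningModel_ACTV1(config_dict)

# Dummy batch for illustration; real inputs come from the dataset loader.
batch = {
    "inputs": torch.randint(
        0, config_dict["vocab_size"],
        (config_dict["batch_size"], config_dict["seq_len"]),
    ),
    "puzzle_identifiers": torch.zeros(config_dict["batch_size"], dtype=torch.long),
}

carry = model.initial_carry(batch)

# Unroll the recurrence for at most halt_max_steps ACT steps.
# A real loop would also stop early using the halt status stored in the carry.
for _ in range(config_dict["halt_max_steps"]):
    carry, outputs = model(carry, batch)

print(outputs["logits"].shape)       # prediction logits for the output sequence
print(outputs["q_halt_logits"])      # learned Q-value for halting
print(outputs["q_continue_logits"])  # learned Q-value for continuing
```

Keeping the recurrent state in an explicit carry object makes each `forward` call a single self-contained step, which is what allows the ACT mechanism to track step counts and halting per batch item across calls.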