API Reference: Model
This section provides details on the core model class, `HierarchicalReasoningModel_ACTV1`, and its configuration.
HierarchicalReasoningModel_ACTV1Config
This Pydantic model defines all architectural hyperparameters for the model.
Source: `models/hrm/hrm_act_v1.py`
```python
class HierarchicalReasoningModel_ACTV1Config(BaseModel):
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int
    L_cycles: int

    H_layers: int
    L_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float

    forward_dtype: str = "bfloat16"
```
Parameters:

- `batch_size`: The number of samples per GPU.
- `seq_len`: The length of the input sequence (e.g., 900 for a 30x30 grid).
- `puzzle_emb_ndim`: The dimensionality of the learnable per-puzzle embeddings.
- `num_puzzle_identifiers`: The total number of unique puzzles in the dataset.
- `vocab_size`: The size of the token vocabulary.
- `H_cycles`, `L_cycles`: The number of recurrent updates for the high-level and low-level modules per step.
- `H_layers`, `L_layers`: The number of Transformer blocks in each module.
- `hidden_size`, `expansion`, `num_heads`: Standard Transformer layer dimensions.
- `pos_encodings`: Type of positional encoding to use (`rope` or `learned`).
- `halt_max_steps`: Maximum number of computation steps for the ACT mechanism.
- `halt_exploration_prob`: Probability of taking a random number of steps during training to encourage exploration.
- `forward_dtype`: The data type for forward pass computations (e.g., `bfloat16`).
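To show how these fields fit together, here is a minimal sketch of constructing the config directly. The values are illustrative placeholders, not recommended defaults; adjust them to your dataset and hardware.

```python
from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1Config

# Illustrative values for a 30x30 grid task (seq_len = 30 * 30 = 900).
config = HierarchicalReasoningModel_ACTV1Config(
    batch_size=32,
    seq_len=900,
    puzzle_emb_ndim=512,        # 0 disables per-puzzle embeddings
    num_puzzle_identifiers=1000,
    vocab_size=12,
    H_cycles=2,
    L_cycles=2,
    H_layers=4,
    L_layers=4,
    hidden_size=512,
    expansion=4.0,
    num_heads=8,
    pos_encodings="rope",       # or "learned"
    halt_max_steps=16,
    halt_exploration_prob=0.1,
)
```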
HierarchicalReasoningModel_ACTV1
This is the main `nn.Module` that wraps the core reasoning logic with the Adaptive Computation Time (ACT) mechanism.
Source: `models/hrm/hrm_act_v1.py`
```python
class HierarchicalReasoningModel_ACTV1(nn.Module):
    def __init__(self, config_dict: dict):
        # ...

    def initial_carry(self, batch: Dict[str, torch.Tensor]) -> HierarchicalReasoningModel_ACTV1Carry:
        # ...

    def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        # ...
```
Key Methods:

- `__init__(self, config_dict: dict)`
  - Initializes the model. Takes a dictionary that will be parsed by `HierarchicalReasoningModel_ACTV1Config`.
- `initial_carry(self, batch)`
  - Creates the initial state (or "carry") for the recurrent model. The carry holds the internal states of the H and L modules (`z_H`, `z_L`), the current step count, and the halt status for each item in the batch.
  - Arguments:
    - `batch`: A dictionary of input tensors, used to determine the batch size.
  - Returns: An initial `HierarchicalReasoningModel_ACTV1Carry` object.
- `forward(self, carry, batch)`
  - Performs one step of the recurrent computation. A usage sketch combining `initial_carry` and `forward` follows this list.
  - Arguments:
    - `carry`: The `HierarchicalReasoningModel_ACTV1Carry` from the previous step.
    - `batch`: A dictionary containing the input data (`inputs`, `puzzle_identifiers`, etc.).
  - Returns: A tuple containing:
    - `new_carry`: The updated carry object for the next step.
    - `outputs`: A dictionary of output tensors, including:
      - `logits`: The raw prediction logits for the output sequence.
      - `q_halt_logits`: The model's learned Q-value for halting.
      - `q_continue_logits`: The model's learned Q-value for continuing.
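The sketch below drives the ACT loop by hand: create a carry, call the model once per step, and stop when every item has halted. It assumes the `config` object from the earlier example, a Pydantic v2 `model_dump()` (use `.dict()` on v1), and that the halt flags live in a `halted` field on the carry; check the `Carry` dataclass in `hrm_act_v1.py` for the exact field names. The dummy batch is for shape illustration only; real batches come from the dataloader.

```python
import torch

from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1

# `config` is the HierarchicalReasoningModel_ACTV1Config built above.
model = HierarchicalReasoningModel_ACTV1(config.model_dump())

# Dummy inputs matching the documented batch keys.
batch = {
    "inputs": torch.zeros(config.batch_size, config.seq_len, dtype=torch.long),
    "puzzle_identifiers": torch.zeros(config.batch_size, dtype=torch.long),
}

carry = model.initial_carry(batch)  # fresh z_H / z_L states, step counts, halt flags

with torch.no_grad():
    # Each call performs one ACT step; halt_max_steps bounds the loop.
    for _ in range(config.halt_max_steps):
        carry, outputs = model(carry, batch)
        if carry.halted.all():  # assumed flag field on the carry
            break

predictions = outputs["logits"].argmax(dim=-1)  # (batch_size, seq_len) token ids
```

During training, the halting behavior is additionally shaped by `halt_exploration_prob` and the Q-learning losses on `q_halt_logits` / `q_continue_logits`; the loop above only illustrates inference-time stepping.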