API Reference: Training & Configuration
This section covers the Pydantic configuration classes used by the main training script, `pretrain.py`.
PretrainConfig
This is the top-level configuration model that orchestrates the entire training run.
Source: pretrain.py
```python
from typing import List, Optional

import pydantic


class PretrainConfig(pydantic.BaseModel):
    # Config
    arch: ArchConfig
    # Data
    data_path: str
    # Hyperparams
    global_batch_size: int
    epochs: int
    lr: float
    lr_min_ratio: float
    lr_warmup_steps: int
    weight_decay: float
    beta1: float
    beta2: float
    # Puzzle embedding
    puzzle_emb_lr: float
    puzzle_emb_weight_decay: float
    # Names
    project_name: Optional[str] = None
    run_name: Optional[str] = None
    checkpoint_path: Optional[str] = None
    # Extras
    seed: int = 0
    checkpoint_every_eval: bool = False
    eval_interval: Optional[int] = None
    eval_save_outputs: List[str] = []
```
Parameters:

- `arch`: A nested `ArchConfig` object specifying the model architecture.
- `data_path`: The path to the processed dataset.
- `global_batch_size`: The total batch size across all devices.
- `epochs`: Total number of training epochs.
- `lr`, `lr_min_ratio`, `lr_warmup_steps`: Parameters for the cosine learning rate scheduler.
- `weight_decay`, `beta1`, `beta2`: Adam optimizer parameters.
- `puzzle_emb_lr`, `puzzle_emb_weight_decay`: Learning rate and weight decay specific to the sparse puzzle embeddings.
- `project_name`, `run_name`: Names for the Weights & Biases project and run.
- `checkpoint_path`: Directory to save model checkpoints.
- `eval_interval`: Frequency (in epochs) at which evaluation runs.
- `eval_save_outputs`: List of keys to save from the model output during evaluation.
ArchConfig
Defines the model architecture and its associated loss function.
Source: pretrain.py
```python
class ArchConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')

    name: str
    loss: LossConfig
```
Parameters:

- `name`: The identifier for the model class (e.g., `hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1`).
- `loss`: A nested `LossConfig` object.
- `extra='allow'`: Permits additional, unspecified fields, which are forwarded to the model's constructor.
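The following self-contained sketch shows what `extra='allow'` buys you: architecture hyperparameters not declared on `ArchConfig` (here, a hypothetical `hidden_size`) are retained on the model and can be collected for forwarding to the model constructor. How `pretrain.py` actually splits these fields out may differ.

```python
import pydantic

class LossConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')
    name: str

class ArchConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')
    name: str
    loss: LossConfig

cfg = ArchConfig(
    name="hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1",
    loss=LossConfig(name="losses@ACTLossHead"),
    hidden_size=512,  # undeclared field, kept because extra='allow'
)

# Extra fields survive validation and serialization, so they can be
# separated from the declared fields and passed to the model class:
extra_kwargs = {k: v for k, v in cfg.model_dump().items()
                if k not in ("name", "loss")}
```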
LossConfig
Defines the loss head and its parameters.
Source: pretrain.py
```python
class LossConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')

    name: str
```
Parameters:

- `name`: The identifier for the loss head class (e.g., `losses@ACTLossHead`).
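Both `ArchConfig.name` and `LossConfig.name` use a `module.path@ClassName` identifier. A plausible resolution helper looks like the sketch below; `load_class` is a hypothetical name, and the actual loader in `pretrain.py` may handle errors or caching differently.

```python
import importlib

def load_class(identifier: str):
    """Resolve a 'module.path@ClassName' identifier to a class object.

    E.g. 'losses@ACTLossHead' imports the module 'losses' and returns
    its attribute 'ACTLossHead'.
    """
    module_path, _, class_name = identifier.partition("@")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
```

For example, `load_class("collections@OrderedDict")` returns the standard-library `OrderedDict` class.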