API Reference: Dataset

This section details the classes responsible for loading and managing the datasets.

PuzzleDatasetMetadata

This Pydantic model stores metadata about a processed dataset. It is saved as dataset.json in each split's directory.

Source: dataset/common.py

from typing import List, Optional

import pydantic


class PuzzleDatasetMetadata(pydantic.BaseModel):
    pad_id: int
    ignore_label_id: Optional[int]
    blank_identifier_id: int

    vocab_size: int
    seq_len: int
    num_puzzle_identifiers: int

    total_groups: int
    mean_puzzle_examples: float

    sets: List[str]

Parameters:

  • pad_id: The integer ID used for padding.
  • ignore_label_id: The integer ID for labels that should be ignored in the loss calculation.
  • blank_identifier_id: The puzzle-identifier ID reserved for padding entries.
  • vocab_size: The total size of the vocabulary.
  • seq_len: The length of each sequence.
  • num_puzzle_identifiers: The total number of unique puzzle IDs.
  • total_groups: The number of puzzle groups (an original puzzle and its augmentations form a group).
  • mean_puzzle_examples: The average number of training examples (input/output pairs) per puzzle.
  • sets: A list of available subsets in this split (e.g., ["all"]).
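For illustration, metadata with these fields round-trips through JSON as a plain dictionary. This is only a sketch using the standard json module with made-up values; the actual code serializes the Pydantic model itself when writing dataset.json:

```python
import json

# Hypothetical metadata values; field names mirror PuzzleDatasetMetadata,
# but every number below is illustrative only.
metadata = {
    "pad_id": 0,
    "ignore_label_id": -100,
    "blank_identifier_id": 0,
    "vocab_size": 11,
    "seq_len": 900,
    "num_puzzle_identifiers": 1000,
    "total_groups": 400,
    "mean_puzzle_examples": 3.5,
    "sets": ["all"],
}

# Round-trip through JSON, as the dataset builder would when saving
# dataset.json into a split's directory.
blob = json.dumps(metadata)
restored = json.loads(blob)
```

Because every field is a JSON-native type (int, float, str, list, or None), the model survives serialization without custom encoders.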

PuzzleDatasetConfig

This Pydantic model holds the configuration for instantiating a PuzzleDataset.

Source: puzzle_dataset.py

import pydantic


class PuzzleDatasetConfig(pydantic.BaseModel):
    seed: int
    dataset_path: str
    global_batch_size: int
    test_set_mode: bool

    epochs_per_iter: int  # Batch X epochs in an iteration to reduce overhead.

    rank: int
    num_replicas: int
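As a sketch, a config for replica 1 of a two-replica run might look like the following. All values here are illustrative (in practice, rank and num_replicas would come from the distributed launcher, and dataset_path is a hypothetical example):

```python
# Illustrative config values; field names mirror PuzzleDatasetConfig.
config = {
    "seed": 0,
    "dataset_path": "data/my-dataset",  # hypothetical path
    "global_batch_size": 768,
    "test_set_mode": False,
    "epochs_per_iter": 1,   # batch several epochs per iteration
    "rank": 1,
    "num_replicas": 2,
}

# Each replica handles an equal share of the global batch.
local_batch_size = config["global_batch_size"] // config["num_replicas"]
```

Note that global_batch_size refers to the batch size summed across all replicas, so the per-replica batch size is derived from it as shown.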

PuzzleDataset

This is the PyTorch IterableDataset subclass used to load data for training and evaluation.

Source: puzzle_dataset.py

from torch.utils.data import IterableDataset


class PuzzleDataset(IterableDataset):
    def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):
        # ...

    def __iter__(self):
        # ...

Key Functionality:

  • Lazy Loading: The .npy data files are loaded only when the dataset is first iterated over, using memory-mapping for efficiency.
  • Distributed Sampling: It correctly handles data sharding for distributed training based on rank and num_replicas.
  • Training vs. Test Mode (test_set_mode):
    • In training mode, it shuffles puzzle groups and samples from them to create batches. This ensures that the model sees a variety of augmentations.
    • In test mode, it iterates through the dataset sequentially, ensuring a consistent evaluation order.
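The distributed sharding described above can be sketched as follows: each replica takes a contiguous, equally sized slice of every global batch based on its rank. This is a simplification of the actual iterator (which also memory-maps the .npy files lazily), and the function name is illustrative:

```python
def shard_batch(global_indices, rank, num_replicas):
    """Return this replica's slice of a global batch of example indices.

    Simplified sketch: each of the num_replicas workers takes a
    contiguous, equally sized chunk, so no example is seen twice.
    """
    local_size = len(global_indices) // num_replicas
    start = rank * local_size
    return global_indices[start:start + local_size]

batch = list(range(8))          # a global batch of 8 example indices
r0 = shard_batch(batch, 0, 2)   # replica 0 -> [0, 1, 2, 3]
r1 = shard_batch(batch, 1, 2)   # replica 1 -> [4, 5, 6, 7]
```

Together the replicas cover the whole batch with no overlap, which is the invariant the rank/num_replicas fields of the config exist to guarantee.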