# Source code for nami_toys.dataset
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
import torch
# [docs]
@dataclass
class ToyDataset:
    """Lightweight container for toy simulation data.

    Parameters
    ----------
    x : torch.Tensor
        Data points with shape ``(N, d)``.
    y : torch.Tensor | None
        Optional integer labels with shape ``(N,)``.
    meta : dict[str, Any]
        Metadata from the generation call.
    """

    x: torch.Tensor
    y: torch.Tensor | None = None
    meta: dict[str, Any] = field(default_factory=dict)

    def __len__(self) -> int:
        """Return the number of entries, i.e. the size of the first axis of ``x``."""
        return self.x.shape[0]

    def __repr__(self) -> str:
        """Return a compact summary of size, dimension, label set and metadata.

        The label set and the metadata are only shown when present / non-empty.
        """
        # Two-way unpacking: ``x`` must be two-dimensional here.
        n, d = self.x.shape
        labels = (
            f", labels={set(self.y.unique().tolist())}" if self.y is not None else ""
        )
        meta = f", meta={self.meta}" if self.meta else ""
        return f"ToyDataset(n={n}, d={d}{labels}{meta})"

    def subset(self, mask: torch.Tensor) -> ToyDataset:
        """Return a new dataset containing only entries where *mask* is True.

        ``x`` (and ``y``, when set) are indexed with *mask*; ``meta`` is
        shallow-copied so the new dataset gets its own dictionary.
        """
        y_sub = self.y[mask] if self.y is not None else None
        return ToyDataset(x=self.x[mask], y=y_sub, meta={**self.meta})

    def limit(self, n: int) -> ToyDataset:
        """Return a dataset containing at most *n* entries.

        The first ``n`` entries are kept. *n* is clamped to the range
        ``[0, len(self)]``, so a negative *n* yields an empty dataset and an
        oversized *n* yields all entries. ``meta`` is shallow-copied.
        """
        # Clamp from below as well as above: a negative bound would turn the
        # slice ``[:n]`` into "drop the last |n| entries" and could return
        # more than *n* items.
        n = max(0, min(n, len(self)))
        y_lim = self.y[:n] if self.y is not None else None
        return ToyDataset(x=self.x[:n], y=y_lim, meta={**self.meta})