Source code for nami_toys.checkerboard
from __future__ import annotations
from dataclasses import dataclass
import torch
from .dataset import ToyDataset
[docs]
@dataclass(frozen=True)
class Checkerboard:
    """Uniform density on the "on" cells of a 2-D checkerboard.

    The square ``[-bound, bound]`` x ``[-bound, bound]`` is divided into
    ``cells`` x ``cells`` equal cells. The cell with integer grid
    coordinates ``(i, j)`` (``i`` along axis 0, ``j`` along axis 1) is "on"
    when ``i + j`` is even, so the corner cell at ``(0, 0)`` is always filled.

    Parameters
    ----------
    cells : int
        Number of cells per side (total grid is *cells* x *cells*).
        Must be at least 1 for :meth:`generate` to have anything to sample.
    bound : float
        The grid spans [-bound, bound] along each axis.
    """

    cells: int = 4
    bound: float = 2.0

    def generate(
        self,
        n: int,
        *,
        generator: torch.Generator | None = None,
    ) -> ToyDataset:
        """Draw *n* samples uniformly from the filled squares.

        Parameters
        ----------
        n : int
            Number of samples to draw; must be non-negative.
        generator : torch.Generator, optional
            Source of randomness for both the cell choice and the in-cell
            offset; pass a seeded generator for reproducible samples.

        Returns
        -------
        ToyDataset
            ``x`` is a floating-point tensor of shape ``(n, 2)``; ``meta``
            records ``cells`` and ``bound``.

        Raises
        ------
        ValueError
            If ``self.cells < 1`` (there is no cell to sample from) or
            ``n < 0``.
        """
        # Validate up front so misconfiguration surfaces as a clear error
        # instead of an opaque failure inside torch.randint.
        if self.cells < 1:
            raise ValueError(f"cells must be >= 1, got {self.cells}")
        if n < 0:
            raise ValueError(f"n must be non-negative, got {n}")

        # Integer grid coordinates of every "on" cell, i.e. each cell's
        # minimum corner measured in units of cell_size. The enumeration
        # order decides which cell a random index maps to, so keep it stable
        # to preserve reproducibility with seeded generators.
        on_cells = [
            (i, j)
            for i in range(self.cells)
            for j in range(self.cells)
            if (i + j) % 2 == 0
        ]
        corners = torch.tensor(on_cells, dtype=torch.float)  # (K, 2)

        # All cells have equal area, so picking a cell uniformly and then a
        # uniform point inside it yields a uniform density over their union.
        idx = torch.randint(len(on_cells), (n,), generator=generator)
        offsets = torch.empty(n, 2).uniform_(0, 1, generator=generator)

        cell_size = 2.0 * self.bound / self.cells
        # Map grid units to data coordinates: shift by -bound, scale by size.
        x = -self.bound + (corners[idx] + offsets) * cell_size
        return ToyDataset(
            x=x,
            meta={"cells": self.cells, "bound": self.bound},
        )