Source code for nami_toys.moons
from __future__ import annotations
import math
from dataclasses import dataclass
import torch
from .dataset import ToyDataset
[docs]
@dataclass(frozen=True)
class TwoMoons:
    """Two interleaving crescents in 2-D.

    The upper crescent is the unit half-circle centred at the origin; the
    lower crescent is that half-circle rotated by 180 degrees and centred at
    ``(1, 0.5)``, so the two arcs interlock.

    Parameters
    ----------
    noise : float
        Std-dev of isotropic Gaussian noise added to the arc positions.
        Must be finite and non-negative (checked in :meth:`generate`).
    """

    noise: float = 0.1

    def generate(
        self,
        n: int,
        *,
        generator: torch.Generator | None = None,
    ) -> ToyDataset:
        """Draw *n* labelled samples from two crescents.

        Parameters
        ----------
        n : int
            Total number of samples. The upper moon (label 0) receives
            ``n // 2`` points and the lower moon (label 1) the remaining
            ``n - n // 2``, so the lower moon gets the extra point when *n*
            is odd. ``n == 0`` yields empty tensors.
        generator : torch.Generator, optional
            Source of randomness for the arc angles, the noise and the final
            shuffle; every random draw in this method uses it.

        Returns
        -------
        ToyDataset
            ``x`` of shape ``(n, 2)`` in torch's default floating dtype and
            ``y`` of shape ``(n,)`` with dtype ``torch.long``, both reordered
            by the same random permutation, plus ``meta={"noise": self.noise}``.

        Raises
        ------
        ValueError
            If *n* is negative, or if ``self.noise`` is negative or not finite.
        """
        # Validate up front, before touching the generator, so a bad call
        # does not consume random state and fails with a readable message
        # instead of an opaque torch RuntimeError.
        if n < 0:
            raise ValueError(f"n must be non-negative, got {n}")
        if not (math.isfinite(self.noise) and self.noise >= 0):
            raise ValueError(
                f"noise must be finite and non-negative, got {self.noise}"
            )

        n_upper = n // 2
        n_lower = n - n_upper

        # NOTE: the random draws below happen in a fixed order (upper angles,
        # lower angles, noise, permutation). Keep that order so outputs for a
        # given seeded generator stay stable across refactors.

        # upper moon: unit half-circle centred at the origin, t ~ U(0, pi)
        t_up = torch.empty(n_upper).uniform_(0, math.pi, generator=generator)
        upper = torch.stack([t_up.cos(), t_up.sin()], dim=1)

        # lower moon: the same half-circle rotated 180 degrees about the
        # origin and moved to centre (1, 0.5), i.e. (1 - cos t, 0.5 - sin t)
        t_lo = torch.empty(n_lower).uniform_(0, math.pi, generator=generator)
        lower = torch.stack([1.0 - t_lo.cos(), 0.5 - t_lo.sin()], dim=1)

        x = torch.cat([upper, lower], dim=0)
        x = x + torch.empty_like(x).normal_(0, self.noise, generator=generator)

        # labels follow the concatenation order: 0 = upper, 1 = lower
        y = torch.cat(
            [
                torch.zeros(n_upper, dtype=torch.long),
                torch.ones(n_lower, dtype=torch.long),
            ]
        )

        # shuffle points and labels together so the classes are not contiguous
        perm = torch.randperm(n, generator=generator)
        return ToyDataset(x=x[perm], y=y[perm], meta={"noise": self.noise})