Source code for nami_toys.spirals
from __future__ import annotations
import math
from dataclasses import dataclass
import torch
from .dataset import ToyDataset
[docs]
@dataclass(frozen=True)
class TwoSpirals:
    """Two interleaved Archimedean spirals, one per class.

    Both arms wind in the same angular direction (angle increases along each
    arm); the second arm (label 1) is the first arm (label 0) rotated by
    ``pi`` about the origin, so the two arms sit half a turn apart and
    interleave.  Along each arm the radius grows linearly with angle, by one
    unit per full turn, approaching ``n_turns`` at the outer end.

    Parameters
    ----------
    noise : float
        Std-dev of isotropic Gaussian noise added to every point.
    n_turns : float
        Number of full turns each spiral makes.
    """

    noise: float = 0.1
    n_turns: float = 1.5

    def generate(
        self,
        n: int,
        *,
        generator: torch.Generator | None = None,
    ) -> ToyDataset:
        """Draw *n* labelled samples from two spirals.

        Parameters
        ----------
        n : int
            Total number of samples; must be non-negative (torch raises an
            error for a negative tensor size otherwise).  Class 0 gets
            ``n // 2`` points and class 1 gets the rest, so for odd *n*
            class 1 has one extra point.
        generator : torch.Generator, optional
            Source of randomness for every draw (positions along the arms,
            the Gaussian noise, and the final shuffle).  Pass a seeded
            generator for reproducible samples.

        Returns
        -------
        ToyDataset
            ``x`` of shape ``(n, 2)`` in torch's default float dtype, ``y`` of
            shape ``(n,)`` holding ``torch.long`` labels (0 or 1), with rows
            shuffled jointly, and ``meta`` recording ``noise`` and ``n_turns``.
        """
        n_a = n // 2
        n_b = n - n_a
        max_angle = self.n_turns * 2 * math.pi

        def _arm(count: int, offset: float) -> torch.Tensor:
            # Arc length of an Archimedean spiral grows roughly as angle**2,
            # so taking the sqrt of a uniform variate gives approximately
            # uniform point density along the arm (instead of crowding the
            # centre, as uniform-in-angle sampling would).
            t = torch.empty(count).uniform_(0, 1, generator=generator).sqrt()
            theta = t * max_angle + offset
            # r == (theta - offset) / (2*pi) == t * n_turns: the radius grows
            # linearly with angle, one unit per full turn.
            r = t * max_angle / (2 * math.pi)
            pts = torch.stack([r * theta.cos(), r * theta.sin()], dim=1)
            return pts + torch.empty_like(pts).normal_(
                0, self.noise, generator=generator
            )

        # Arm 0 is drawn before arm 1 (offset pi); keep this order so a
        # seeded generator keeps producing the same samples.
        x = torch.cat([_arm(n_a, 0.0), _arm(n_b, math.pi)], dim=0)
        y = torch.cat(
            [
                torch.zeros(n_a, dtype=torch.long),
                torch.ones(n_b, dtype=torch.long),
            ]
        )
        # Without this shuffle all class-0 rows would precede all class-1 rows.
        perm = torch.randperm(n, generator=generator)
        return ToyDataset(
            x=x[perm],
            y=y[perm],
            meta={"noise": self.noise, "n_turns": self.n_turns},
        )