Source code for nami_toys.ring
from __future__ import annotations
import math
from dataclasses import dataclass
import torch
from .dataset import ToyDataset
[docs]
@dataclass(frozen=True)
class GaussianRing:
    """Isotropic Gaussian modes arranged in a circle.

    Mode ``k`` is centred at ``radius * (cos θ_k, sin θ_k)`` with
    ``θ_k = 2πk / n_modes``, so the first mode lies on the positive x-axis
    and the remaining modes are equally spaced counter-clockwise.

    Parameters
    ----------
    n_modes : int
        Number of equally-spaced modes. Must be at least 1.
    radius : float
        Distance of each mode centre from the origin.
    std : float
        Standard deviation of each isotropic Gaussian mode. Must be
        non-negative.

    Raises
    ------
    ValueError
        If ``n_modes < 1`` or ``std`` is negative or NaN.
    """

    n_modes: int = 8
    radius: float = 3.0
    std: float = 0.3

    def __post_init__(self) -> None:
        # Fail fast with a clear message instead of an opaque error raised
        # by ``torch.randint`` / ``Tensor.normal_`` deep inside ``generate``.
        if self.n_modes < 1:
            raise ValueError(f"n_modes must be >= 1, got {self.n_modes}")
        # Written as ``not >=`` so that NaN is rejected as well.
        if not self.std >= 0:
            raise ValueError(f"std must be non-negative, got {self.std}")

    def _centres(self) -> torch.Tensor:
        """Return the ``(n_modes, 2)`` tensor of mode centres."""
        # ``n_modes + 1`` points span both 0 and 2π; the last one coincides
        # with the first, so it is dropped.
        angles = torch.linspace(0, 2 * math.pi, self.n_modes + 1)[: self.n_modes]
        return torch.stack(
            [self.radius * angles.cos(), self.radius * angles.sin()], dim=1
        )

    def generate(
        self,
        n: int,
        *,
        generator: torch.Generator | None = None,
    ) -> ToyDataset:
        """Draw *n* samples from a ring of Gaussian modes.

        Parameters
        ----------
        n : int
            Number of samples to draw. Must be non-negative.
        generator : torch.Generator, optional
            Random number generator used for both the mode assignment and
            the Gaussian noise.

        Returns
        -------
        ToyDataset
            ``x`` is an ``(n, 2)`` tensor of samples, ``y`` holds the index of
            the mode each sample was drawn from, and ``meta`` records
            ``n_modes``, ``radius`` and ``std``.

        Raises
        ------
        ValueError
            If ``n`` is negative.
        """
        if n < 0:
            raise ValueError(f"n must be non-negative, got {n}")
        centres = self._centres()
        # Modes are chosen uniformly at random; the RNG is consumed in the
        # same order as before (mode indices first, then noise).
        mode_idx = torch.randint(self.n_modes, (n,), generator=generator)
        noise = torch.empty(n, 2).normal_(0, self.std, generator=generator)
        x = centres[mode_idx] + noise
        return ToyDataset(
            x=x,
            y=mode_idx,
            meta={"n_modes": self.n_modes, "radius": self.radius, "std": self.std},
        )