Source code for nami_toys.shell
from __future__ import annotations
import math
from dataclasses import dataclass
import torch
from torch.distributions import Binomial
from .dataset import ToyDataset
[docs]
@dataclass(frozen=True)
class GaussianShell:
    """2-D Gaussian shell (ring) signal with isotropic normal background.

    Parameters
    ----------
    radius : float
        Mean radius of the signal ring.
    width : float
        Standard deviation of the radial spread.
    bkg_scale : float
        Standard deviation of the isotropic background Gaussian.
    """

    radius: float = 2.5
    width: float = 0.25
    bkg_scale: float = 1.5

    def generate(
        self,
        n_expected: int,
        sig_frac: float,
        *,
        generator: torch.Generator | None = None,
    ) -> ToyDataset:
        """Draw a Poisson-fluctuated 2-D shell dataset.

        The total event count is drawn from ``Poisson(n_expected)``. Each
        event is then independently labelled signal with probability
        ``sig_frac``, so the signal count is ``Binomial(n_total, sig_frac)``.
        Every random draw uses *generator*, so a seeded generator makes the
        whole dataset (counts, points and ordering) reproducible.

        Parameters
        ----------
        n_expected : int
            Mean of the Poisson-distributed total number of events. Must be
            finite and non-negative.
        sig_frac : float
            Probability that an event is signal. Must lie in ``[0, 1]``.
        generator : torch.Generator, optional
            Source of randomness for every draw. ``None`` falls back to
            torch's default generator.

        Returns
        -------
        ToyDataset
            ``x`` holds the points with shape ``(n_total, 2)``. ``y`` holds
            ``torch.long`` labels (1 = signal, 0 = background). Rows are
            randomly shuffled. ``meta`` records the generating parameters.

        Raises
        ------
        ValueError
            If ``n_expected`` is negative or non-finite, or if ``sig_frac``
            lies outside ``[0, 1]`` (including NaN).
        """
        if not (math.isfinite(n_expected) and n_expected >= 0):
            raise ValueError(
                f"n_expected must be finite and >= 0, got {n_expected!r}"
            )
        # Written as a negated range check so NaN is rejected too.
        if not 0.0 <= sig_frac <= 1.0:
            raise ValueError(f"sig_frac must lie in [0, 1], got {sig_frac!r}")

        # float64 keeps large rates exact; passing the generator makes the
        # total count reproducible.
        rate = torch.tensor(float(n_expected), dtype=torch.float64)
        n_total = int(torch.poisson(rate, generator=generator).item())
        # Binomial(n_total, sig_frac) drawn as per-event Bernoulli trials.
        # torch.distributions samplers accept no generator, which would
        # break reproducibility.
        is_signal = torch.rand(n_total, generator=generator) < sig_frac
        n_sig = int(is_signal.sum().item())
        n_bkg = n_total - n_sig

        # signal: points on a noisy ring
        if n_sig > 0:
            angles = torch.empty(n_sig).uniform_(
                0.0, 2.0 * math.pi, generator=generator
            )
            # Negative radial draws are clamped to 0, i.e. placed at the origin.
            radii = (
                torch.empty(n_sig)
                .normal_(self.radius, self.width, generator=generator)
                .clamp(min=0.0)
            )
            signal = torch.stack(
                [radii * angles.cos(), radii * angles.sin()], dim=1
            )
        else:
            signal = torch.empty(0, 2)

        # background: isotropic 2-D Gaussian
        if n_bkg > 0:
            background = torch.empty(n_bkg, 2).normal_(
                0.0, self.bkg_scale, generator=generator
            )
        else:
            background = torch.empty(0, 2)

        x = torch.cat([signal, background], dim=0)
        y = torch.cat(
            [
                torch.ones(n_sig, dtype=torch.long),
                torch.zeros(n_bkg, dtype=torch.long),
            ]
        )
        # Shuffle so signal and background rows are interleaved.
        perm = torch.randperm(n_total, generator=generator)
        return ToyDataset(
            x=x[perm],
            y=y[perm],
            meta={
                "n_expected": n_expected,
                "sig_frac": sig_frac,
                "radius": self.radius,
                "width": self.width,
                "bkg_scale": self.bkg_scale,
            },
        )