# Source code for nami_toys.gaussian

from __future__ import annotations

from dataclasses import dataclass, field
from functools import cached_property

import torch
from torch.distributions import Binomial, MultivariateNormal

from .dataset import ToyDataset

# Default two-dimensional signal component: mean and covariance matrix.
_DEFAULT_SIG_LOC = torch.tensor((1.0, 0.0))
_DEFAULT_SIG_COV = torch.tensor(
    (
        (1.0, 0.3),
        (0.3, 1.0),
    )
)
# Default two-dimensional background component: mean and covariance matrix.
_DEFAULT_BKG_LOC = torch.tensor((0.0, 0.0))
_DEFAULT_BKG_COV = torch.tensor(
    (
        (2.0, -0.2),
        (-0.2, 2.0),
    )
)


@dataclass(frozen=True)
class GaussianMixture:
    """N-dimensional Gaussian signal + background simulator.

    Parameters
    ----------
    sig_loc, sig_cov : torch.Tensor
        Mean ``(d,)`` and covariance ``(d, d)`` of the signal component.
    bkg_loc, bkg_cov : torch.Tensor
        Mean ``(d,)`` and covariance ``(d, d)`` of the background component.
    """

    # Each default is cloned so instances never share (or mutate) the
    # module-level default tensors.
    sig_loc: torch.Tensor = field(default_factory=lambda: _DEFAULT_SIG_LOC.clone())
    sig_cov: torch.Tensor = field(default_factory=lambda: _DEFAULT_SIG_COV.clone())
    bkg_loc: torch.Tensor = field(default_factory=lambda: _DEFAULT_BKG_LOC.clone())
    bkg_cov: torch.Tensor = field(default_factory=lambda: _DEFAULT_BKG_COV.clone())

    @cached_property
    def sig(self) -> MultivariateNormal:
        """Signal distribution, built once from ``sig_loc`` / ``sig_cov``."""
        return MultivariateNormal(self.sig_loc, self.sig_cov)

    @cached_property
    def bkg(self) -> MultivariateNormal:
        """Background distribution, built once from ``bkg_loc`` / ``bkg_cov``."""
        return MultivariateNormal(self.bkg_loc, self.bkg_cov)

    @property
    def d(self) -> int:
        """Dimensionality, taken from the leading axis of ``sig_loc``."""
        return self.sig_loc.shape[0]

    @staticmethod
    def _draw_counts(
        n_expected: int,
        sig_frac: float,
        generator: torch.Generator | None,
    ) -> tuple[int, int]:
        """Return ``(n_sig, n_bkg)``: a Poisson total split binomially."""
        # ``Binomial(...)`` used to perform this check; calling
        # ``torch.binomial`` directly skips it, so keep the ValueError here.
        if not 0.0 <= sig_frac <= 1.0:
            raise ValueError(f"sig_frac must lie in [0, 1], got {sig_frac!r}")

        rate = torch.tensor(float(n_expected))
        n_total = int(torch.poisson(rate, generator=generator).item())

        # ``Binomial.sample`` offers no ``generator`` argument, so call the
        # underlying op directly to keep this draw on the caller's stream.
        n_sig = int(
            torch.binomial(
                torch.tensor(float(n_total)),
                torch.tensor(float(sig_frac)),
                generator=generator,
            ).item()
        )
        return n_sig, n_total - n_sig

    @staticmethod
    def _sample_component(
        dist: MultivariateNormal,
        n: int,
        generator: torch.Generator | None,
    ) -> torch.Tensor:
        """Draw ``n`` samples from ``dist`` using ``generator`` for the noise.

        ``MultivariateNormal.sample`` cannot take a generator, so the draw is
        written out as ``loc + scale_tril @ eps`` with standard-normal ``eps``.
        """
        loc = dist.loc
        shape = (n, *dist.batch_shape, *dist.event_shape)
        # A generator can only fill tensors on its own device; draw the noise
        # there and move it next to the distribution parameters afterwards.
        noise_device = loc.device if generator is None else generator.device
        with torch.no_grad():
            eps = torch.empty(shape, dtype=loc.dtype, device=noise_device)
            eps = eps.normal_(generator=generator).to(loc.device)
            return loc + torch.matmul(dist.scale_tril, eps.unsqueeze(-1)).squeeze(-1)

    def generate(
        self,
        n_expected: int,
        sig_frac: float,
        *,
        generator: torch.Generator | None = None,
    ) -> ToyDataset:
        """Draw a Poisson-fluctuated dataset of signal + background events.

        Parameters
        ----------
        n_expected
            Poisson mean of the total number of events.
        sig_frac
            Probability that any single event is signal; must be in [0, 1].
        generator
            Optional RNG used for *every* random draw (event count,
            signal/background split, feature values and shuffling), so a
            seeded generator reproduces the same dataset. ``None`` uses
            torch's global RNG.

        Returns
        -------
        ToyDataset
            Shuffled features ``x`` and labels ``y`` (1 = signal,
            0 = background), with ``n_expected`` and ``sig_frac`` in ``meta``.

        Raises
        ------
        ValueError
            If ``sig_frac`` is outside [0, 1].
        """
        n_sig, n_bkg = self._draw_counts(n_expected, sig_frac, generator)

        sig_data = self._sample_component(self.sig, n_sig, generator)
        bkg_data = self._sample_component(self.bkg, n_bkg, generator)

        x = torch.cat([sig_data, bkg_data], dim=0)
        y = torch.cat(
            [
                torch.ones(n_sig, dtype=torch.long),
                torch.zeros(n_bkg, dtype=torch.long),
            ]
        )

        # Shuffle so signal and background rows are interleaved.
        perm = torch.randperm(n_sig + n_bkg, generator=generator)
        return ToyDataset(
            x=x[perm],
            y=y[perm],
            meta={"n_expected": n_expected, "sig_frac": sig_frac},
        )