Source code for nami_toys.parameterised
from __future__ import annotations
from dataclasses import dataclass, field
from functools import cached_property
import torch
from torch.distributions import Binomial, MultivariateNormal
from .dataset import ToyDataset
_DEFAULT_SIG_LOC = torch.tensor([0.0, 0.0])
_DEFAULT_SIG_COV = torch.tensor([[1.0, 0.3], [0.3, 1.0]])
_DEFAULT_BKG_LOC = torch.tensor([0.0, 0.0])
_DEFAULT_BKG_COV = torch.tensor([[2.0, -0.2], [-0.2, 2.0]])
[docs]
# NOTE(review): the dataclass-generated ``__eq__`` compares the tensor fields
# as a tuple; for two distinct instances holding different multi-element
# tensor objects this calls ``bool()`` on an elementwise result and raises.
# Consider ``eq=False`` if instances are ever compared.
@dataclass(frozen=True)
class ParameterisedGaussian:
    r"""Gaussian mixture whose signal location depends on a parameter :math:`\theta`.

    The background distribution is fixed while the signal mean is set to
    :math:`\theta` along ``param_dim``, keeping the base mean elsewhere.
    The signal covariance is also fixed.

    Parameters
    ----------
    sig_loc : torch.Tensor
        Base signal mean ``(d,)``; entry at *param_dim* is replaced by theta.
    sig_cov : torch.Tensor
        Signal covariance ``(d, d)`` (fixed).
    bkg_loc, bkg_cov : torch.Tensor
        Background mean and covariance (fixed).
    sig_frac : float
        Expected signal fraction; must lie in ``[0, 1]``.
    param_dim : int
        Dimension of the mean vector that theta controls.

    Raises
    ------
    ValueError
        If *sig_frac* is outside ``[0, 1]`` (or NaN).
    """

    sig_loc: torch.Tensor = field(default_factory=lambda: _DEFAULT_SIG_LOC.clone())
    sig_cov: torch.Tensor = field(default_factory=lambda: _DEFAULT_SIG_COV.clone())
    bkg_loc: torch.Tensor = field(default_factory=lambda: _DEFAULT_BKG_LOC.clone())
    bkg_cov: torch.Tensor = field(default_factory=lambda: _DEFAULT_BKG_COV.clone())
    sig_frac: float = 0.3
    param_dim: int = 0

    def __post_init__(self) -> None:
        # A fraction outside [0, 1] gives negative mixture weights (NaN
        # log-probabilities) and an impossible signal/background split.
        if not 0.0 <= self.sig_frac <= 1.0:
            raise ValueError(f"sig_frac must lie in [0, 1], got {self.sig_frac!r}")

    @cached_property
    def bkg(self) -> MultivariateNormal:
        """Fixed background distribution (built once per instance)."""
        return MultivariateNormal(self.bkg_loc, self.bkg_cov)

    @property
    def d(self) -> int:
        """Event-space dimensionality, i.e. the length of ``sig_loc``."""
        return self.sig_loc.shape[0]

    def sig_at(self, theta: float) -> MultivariateNormal:
        """Return the signal distribution at a given *theta*.

        The mean is a copy of ``sig_loc`` with entry ``param_dim`` set to
        *theta*; the covariance is ``sig_cov``.
        """
        mu = self.sig_loc.clone()  # clone so the stored base mean is never mutated
        mu[self.param_dim] = theta
        return MultivariateNormal(mu, self.sig_cov)

    def log_prob(self, x: torch.Tensor, theta: float) -> torch.Tensor:
        r"""Mixture log-probability :math:`\log p(x \mid \theta)`.

        Evaluated as
        ``logaddexp(log f + log p_sig(x), log(1 - f) + log p_bkg(x))`` with
        ``f = sig_frac``.  This stays finite for events far from both
        components, where exponentiating the component log-densities would
        underflow to zero and yield ``-inf``.

        Parameters
        ----------
        x : torch.Tensor
            Events ``(N, d)`` or ``(d,)``.
        theta : float
            Parameter value.

        Returns
        -------
        torch.Tensor
            Log-probabilities, ``(N,)`` for batched input or a scalar.
        """
        lp_sig = self.sig_at(theta).log_prob(x)
        lp_bkg = self.bkg.log_prob(x)
        frac = torch.as_tensor(self.sig_frac, dtype=lp_sig.dtype, device=lp_sig.device)
        # log(0) = -inf is fine here: logaddexp(-inf, y) == y, so the
        # sig_frac = 0 / 1 edge cases reduce to the single component.
        return torch.logaddexp(frac.log() + lp_sig, torch.log1p(-frac) + lp_bkg)

    def log_likelihood_ratio(self, x: torch.Tensor, theta: float) -> torch.Tensor:
        r"""Per-event log-likelihood ratio :math:`\log p(x \mid \text{sig}, \theta) - \log p(x \mid \text{bkg})`.

        Compares the pure signal and background components; ``sig_frac``
        does not enter.
        """
        return self.sig_at(theta).log_prob(x) - self.bkg.log_prob(x)

    def generate(
        self,
        theta: float,
        n_expected: int,
        *,
        generator: torch.Generator | None = None,
    ) -> ToyDataset:
        """Draw a Poisson-fluctuated dataset at the given *theta*.

        The total count is drawn from ``Poisson(n_expected)``.  Each event is
        independently signal with probability ``sig_frac``.  Signal and
        background events are sampled and then shuffled together.  Every
        random draw uses *generator*, so a seeded generator makes the dataset
        reproducible.

        Returns
        -------
        ToyDataset
            ``x`` of shape ``(N, d)``, ``y`` of shape ``(N,)`` with long labels
            (1 = signal, 0 = background), and ``meta`` recording
            ``n_expected``, ``sig_frac`` and ``theta``.
        """
        rate = torch.tensor(float(n_expected))
        n_total = int(torch.poisson(rate, generator=generator).item())
        # Exact Binomial(n_total, sig_frac) draw; unlike Binomial.sample()
        # this honours *generator*.
        is_sig = torch.rand(n_total, generator=generator) < self.sig_frac
        n_sig = int(is_sig.sum().item())
        n_bkg = n_total - n_sig

        x = torch.cat(
            [
                self._draw(self.sig_at(theta), n_sig, generator),
                self._draw(self.bkg, n_bkg, generator),
            ],
            dim=0,
        )
        y = torch.cat(
            [
                torch.ones(n_sig, dtype=torch.long),
                torch.zeros(n_bkg, dtype=torch.long),
            ]
        )
        perm = torch.randperm(n_total, generator=generator)
        return ToyDataset(
            x=x[perm],
            y=y[perm],
            meta={"n_expected": n_expected, "sig_frac": self.sig_frac, "theta": theta},
        )

    @staticmethod
    def _draw(
        dist: MultivariateNormal,
        n: int,
        generator: torch.Generator | None,
    ) -> torch.Tensor:
        """Draw *n* samples ``(n, d)`` from *dist* using *generator*.

        ``MultivariateNormal.sample`` takes no generator, so this uses the
        equivalent ``loc + L @ eps`` form (``L`` the Cholesky factor,
        ``eps`` standard normal).  It runs under ``no_grad`` like ``sample``.
        """
        loc = dist.loc
        with torch.no_grad():
            eps = torch.randn(
                (n, loc.shape[-1]),
                generator=generator,
                dtype=loc.dtype,
                device=loc.device,
            )
            # Row-vector form of L @ eps for each of the n samples.
            return loc + eps @ dist.scale_tril.transpose(-1, -2)