# Source code for nami.core.specs
from __future__ import annotations
import math
from collections.abc import Iterable
from dataclasses import dataclass
import torch
# [docs]
def as_tuple(x: Iterable[int] | int | None) -> tuple[int, ...]:
"""normaliser to take flexible input and return always a tuple for convenience."""
if x is None:
return ()
if isinstance(x, tuple):
return x
if isinstance(x, list):
return tuple(int(v) for v in x)
return (int(x),)
# [docs]
def event_numel(event_shape: Iterable[int] | None) -> int:
    """Return the total element count implied by ``event_shape``.

    A scalar event (``None`` or an empty shape) counts as one element.
    """
    dims = as_tuple(event_shape)
    return int(math.prod(dims)) if dims else 1
# [docs]
def split_event(
    x: torch.Tensor, event_ndim: int
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Partition ``x.shape`` into ``(leading_shape, event_shape)``.

    The trailing ``event_ndim`` dimensions form the event shape; everything
    before them is the leading (batch) shape.

    Raises:
        ValueError: if ``event_ndim`` is negative or larger than ``x.ndim``.
    """
    if event_ndim < 0:
        raise ValueError("event_ndim must be >= 0")
    if event_ndim > x.ndim:
        raise ValueError("event_ndim exceeds x.ndim")
    shape = tuple(x.shape)
    # Computing the split point once makes event_ndim == 0 fall out
    # naturally: cut == len(shape), so the event part is ().
    cut = len(shape) - event_ndim
    return shape[:cut], shape[cut:]
# [docs]
def flatten_event(x: torch.Tensor, event_ndim: int) -> torch.Tensor:
    """Merge the trailing ``event_ndim`` dimensions of ``x`` into one.

    With ``event_ndim == 0`` the tensor is returned unchanged.

    Raises:
        ValueError: if ``event_ndim`` is negative or exceeds ``x.ndim``.
    """
    if event_ndim < 0:
        raise ValueError("event_ndim must be >= 0")
    if event_ndim > x.ndim:
        raise ValueError("event_ndim exceeds x.ndim")
    if event_ndim == 0:
        return x
    lead = tuple(x.shape[: x.ndim - event_ndim])
    return x.reshape(*lead, -1)
# [docs]
def unflatten_event(x: torch.Tensor, event_shape: tuple[int, ...]) -> torch.Tensor:
    """Expand the last dimension of ``x`` back into ``event_shape``.

    Inverse of ``flatten_event``; an empty ``event_shape`` is a no-op.
    """
    if event_shape:
        lead = tuple(x.shape[:-1])
        return x.reshape(lead + tuple(event_shape))
    return x
# [docs]
def validate_shapes(
tensor: torch.Tensor,
event_ndim: int,
expected_event_shape: tuple[int, ...] | None = None,
batch_shape: tuple[int, ...] | None = None,
) -> None:
"""Runtime assertion helper to enforce explicit shapes and
prevent silent broadcasting
"""
if event_ndim < 0:
msg = "event_ndim must be >= 0"
raise ValueError(msg)
if event_ndim > tensor.ndim:
msg = "event_ndim exceeds tensor.ndim"
raise ValueError(msg)
if expected_event_shape is not None:
actual_event_shape = tuple(tensor.shape[-event_ndim:] if event_ndim > 0 else ())
if actual_event_shape != expected_event_shape:
msg = f"event_shape mismatch: expected {expected_event_shape}, got {actual_event_shape}"
raise ValueError(msg)
if batch_shape is not None:
actual_batch_shape = tuple(
tensor.shape[:-event_ndim] if event_ndim > 0 else tensor.shape
)
if actual_batch_shape != batch_shape:
msg = f"batch_shape mismatch: expected {batch_shape}, got {actual_batch_shape}"
raise ValueError(msg)
# [docs]
@dataclass(frozen=True)
class TensorSpec:
"""Minimal tensor specification for models, samplers, and distributions.
Attributes:
----------
event_shape (tuple[int, ...]): The shape of a single event (sample, vector, matrix, etc).
dtype (torch.dtype | None): The expected data type of the tensor.
"""
event_shape: tuple[int, ...]
dtype: torch.dtype | None = None
@property
def event_ndim(self) -> int:
return len(self.event_shape)
@property
def numel(self) -> int:
return event_numel(self.event_shape)