nami.masking
Mask convention
Masks are binary: 1 = real object, 0 = padding. A mask has shape (..., N), where
N is the object dimension, i.e. the first event dimension.
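A minimal sketch of building a mask that follows this convention from per-event object counts. `make_mask` is a hypothetical helper, not part of nami.masking. It also assumes that real objects occupy the leading positions of each row. The convention itself only fixes the meaning of 1 and 0, not where padding sits.

```python
def make_mask(counts, n_max):
    """Hypothetical helper: one row of length n_max per event.

    counts[i] real objects fill the first positions of row i (assumed
    left-aligned layout); the remaining positions are padding (0).
    """
    if any(c < 0 or c > n_max for c in counts):
        raise ValueError("each count must lie in [0, n_max]")
    return [[1 if j < c else 0 for j in range(n_max)] for c in counts]


# Three events with 2, 4 and 0 real objects, padded to N = 4 -> shape (3, 4).
mask = make_mask([2, 4, 0], n_max=4)
print(mask)  # [[1, 1, 0, 0], [1, 1, 1, 1], [0, 0, 0, 0]]
```

With PyTorch, the same layout is commonly produced by broadcasting a comparison, for example `(torch.arange(N) < counts[:, None]).float()` with `counts` a 1-D integer tensor.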
Functions
masked_fm_loss – Flow matching loss computed only over real (unmasked) objects.
masked_sample – Sample from a flow matching model with variable-cardinality masking.
- nami.masking.masked_fm_loss(field, x_target, x_source, mask, t=None, c=None, *, path=None, reduction='mean')
Flow matching loss computed only over real (unmasked) objects.
Like fm_loss(), but padded objects (positions where mask == 0) are excluded from the loss.
- Parameters:
field (nn.Module) – Velocity field. Must expose an event_ndim attribute >= 2.
x_target (Tensor) – Target tensor of shape lead + event_shape.
x_source (Tensor) – Source tensor of shape lead + event_shape.
mask (Tensor) – Binary mask of shape lead + (N,), where N is the first event dimension. 1 = real, 0 = padding.
t (Tensor | None) – Per-sample time values of shape lead. Sampled uniformly at random if None.
c (Tensor | None) – Conditioning context forwarded to the field.
path (nami.paths.base.ProbabilityPath, optional) – Probability path. Defaults to LinearPath.
reduction (str) – Reduction mode. Defaults to 'mean'; 'none' returns per-sample losses.
- Returns:
Scalar loss, or per-sample losses when reduction='none'.
- Return type:
Tensor
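A pure-Python sketch of what "computed only over real objects" means for a squared-error flow matching loss. This is not the nami implementation, and `masked_fm_loss_sketch` is hypothetical. The sketch assumes three things. First, it uses the linear path x_t = t * x_source + (1 - t) * x_target, with t = 1 at the noise end and t = 0 at the data end, matching the masked_sample defaults. Second, the target velocity is x_source - x_target. Third, each sample's loss is averaged over its real objects and features before the batch mean. nami's LinearPath and normalization may differ.

```python
def masked_fm_loss_sketch(field, x_target, x_source, mask, t, reduction="mean"):
    """Hypothetical masked flow matching loss on nested lists.

    x_target, x_source: [B][N][D]; mask: [B][N]; t: [B].
    field(x_t, t) must return a [B][N][D] velocity.
    """
    # ASSUMED linear path: x_t = t * x_source + (1 - t) * x_target.
    x_t = [
        [[tb * s + (1.0 - tb) * y for s, y in zip(src_n, tgt_n)]
         for src_n, tgt_n in zip(src_b, tgt_b)]
        for tb, src_b, tgt_b in zip(t, x_source, x_target)
    ]
    v_pred = field(x_t, t)

    per_sample = []
    for b in range(len(mask)):
        sq_err, n_terms = 0.0, 0
        for n, m in enumerate(mask[b]):
            if m == 0:
                continue  # padded object: excluded from the loss
            for d in range(len(x_target[b][n])):
                # Target velocity of the assumed linear path.
                u = x_source[b][n][d] - x_target[b][n][d]
                sq_err += (v_pred[b][n][d] - u) ** 2
                n_terms += 1
        # ASSUMPTION: a sample with no real objects contributes zero loss.
        per_sample.append(sq_err / n_terms if n_terms else 0.0)

    if reduction == "none":
        return per_sample
    if reduction == "mean":
        return sum(per_sample) / len(per_sample)
    raise ValueError(f"unsupported reduction: {reduction!r}")


def zero_field(x_t, t):
    # Toy field that always predicts zero velocity.
    return [[[0.0 for _ in obj] for obj in event] for event in x_t]


x_target = [[[1.0, 1.0], [5.0, 5.0]]]  # B=1, N=2, D=2
x_source = [[[0.0, 0.0], [0.0, 0.0]]]
mask = [[1, 0]]                        # second object is padding
print(masked_fm_loss_sketch(zero_field, x_target, x_source, mask, t=[0.5]))  # 1.0
```

Padded slots are skipped before any arithmetic. So whatever values the padding holds (here the large second object) have no effect on the result. With mask [[1, 1]], the same call would average in the padded object's error as well.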
- nami.masking.masked_sample(field, base, solver, mask, *, sample_shape=(), c=None, t0=1.0, t1=0.0)
Sample from a flow matching model with variable-cardinality masking.
Draws noise from base and sets padded positions to zero.
At every solver step, the velocity output is masked so that padded positions receive zero velocity and therefore remain exactly zero throughout integration.
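A fixed-step Euler sketch of the masking scheme described above, for a single event and without sample_shape or conditioning. `masked_euler_sample_sketch` is hypothetical and stands in for the solver.integrate call. It only illustrates the two masking points: padded noise starts at zero, and the velocity is masked at every step.

```python
import random


def masked_euler_sample_sketch(velocity, mask, n_dim, t0=1.0, t1=0.0, n_steps=10, seed=0):
    """Hypothetical fixed-step Euler sampler for one event.

    mask: list of N entries (1 = real, 0 = padding).
    velocity(x, t): returns an N x n_dim nested list, same layout as x.
    """
    rng = random.Random(seed)
    # Standard-normal noise for real objects; padded positions start at exactly 0.0.
    x = [[rng.gauss(0.0, 1.0) if m else 0.0 for _ in range(n_dim)] for m in mask]
    dt = (t1 - t0) / n_steps  # negative: integrate from the noise end to the data end
    t = t0
    for _ in range(n_steps):
        v = velocity(x, t)
        # Mask the velocity: padded positions receive zero velocity.
        v = [[vi if m else 0.0 for vi in v_n] for v_n, m in zip(v, mask)]
        x = [[xi + dt * vi for xi, vi in zip(x_n, v_n)] for x_n, v_n in zip(x, v)]
        t += dt
    return x


def const_field(x, t):
    # Toy field that pushes every slot, padded or not, with velocity 1.0.
    return [[1.0 for _ in x_n] for x_n in x]


out = masked_euler_sample_sketch(const_field, mask=[1, 1, 0], n_dim=2)
print(out[2])  # [0.0, 0.0]: the padded object never moves
```

Masking the velocity, rather than re-zeroing x after each step, keeps the padded slots fixed for any solver that only advances x along the returned velocity.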
- Parameters:
field (nn.Module) – Velocity field with forward(x, t, c) and an event_ndim attribute.
base (Distribution) – Base (source) distribution.
solver – ODE solver exposing integrate(f, x0, *, t0, t1, ...).
mask (Tensor) – Binary mask of shape (batch..., N). 1 = real, 0 = padding.
sample_shape (tuple[int, ...]) – Independent sample dimensions prepended to the output.
c (Tensor | None) – Conditioning context forwarded to the field.
t0 (float) – Integration start (default 1.0, the noise end).
t1 (float) – Integration end (default 0.0, the data end).
- Returns:
Samples with shape sample_shape + batch + event_shape. Padded positions are exactly zero.
- Return type:
Tensor
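A sketch of the shape bookkeeping in this contract, using plain tuples. The helper names are hypothetical. The broadcasting rule shown is the standard NumPy/PyTorch one: align shapes from the right, and require each pair of sizes to be equal or 1. The same relation between mask and data holds for masked_fm_loss, where x_target has shape lead + event_shape and mask has shape lead + (N,).

```python
def sample_output_shape(sample_shape, mask_shape, event_shape):
    """Hypothetical: output shape for mask (batch..., N) and event_shape (N, ...)."""
    *batch, n = mask_shape
    if event_shape[0] != n:
        raise ValueError("mask's last dim must equal the first event dim N")
    return tuple(sample_shape) + tuple(batch) + tuple(event_shape)


def mask_broadcast_shape(mask_shape, event_shape):
    # Append singleton dims for the event dims after N (e.g. features).
    return tuple(mask_shape) + (1,) * (len(event_shape) - 1)


def broadcast_shape(a, b):
    # NumPy/PyTorch rule: align from the right; sizes must match or be 1.
    out = []
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"incompatible shapes {a} and {b}")
        out.append(db if da == 1 else da)
    return tuple(reversed(out))


shape = sample_output_shape((4,), (2, 5), (5, 3))
print(shape)                           # (4, 2, 5, 3)
mask_b = mask_broadcast_shape((2, 5), (5, 3))
print(mask_b)                          # (2, 5, 1)
print(broadcast_shape(mask_b, shape))  # (4, 2, 5, 3)
```

Only trailing singleton dims need to be added to the mask. The leading sample_shape dims are filled in by right-aligned broadcasting.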