# Source code for nami.fields.velocity

from __future__ import annotations

import torch

from ..components import MLPBackbone, ScalarTimeEmbedding
from ..core.specs import (
    event_numel,
    flatten_event,
    unflatten_event,
    validate_shapes,
)
from ._common import normalise_event_shape, validate_context
from .base import VectorField


class VelocityField(VectorField):
    """MLP-parameterised velocity field for flow matching.

    Works both unconditionally and conditionally. With a non-zero
    ``condition_dim`` the field expects a context vector ``c`` to be
    appended to its input; with ``condition_dim == 0`` callers pass
    ``c=None``.

    The process layer owns conditioning via lazy binding — this field
    just consumes whatever context the process forwards to it.

    Args:
        dim: Data dimensionality or event shape.
        condition_dim: Size of the conditioning vector (0 means
            unconditional).
        hidden: Hidden layer width.
        layers: Number of hidden layers.
        activation: Activation function ('silu', 'relu', 'gelu', 'tanh').
        dropout: Dropout probability (0 disables).
        layer_norm: Whether to apply layer normalisation in hidden layers.
    """

    def __init__(
        self,
        dim: int | tuple[int, ...],
        *,
        condition_dim: int = 0,
        hidden: int = 256,
        layers: int = 3,
        activation: str = "silu",
        dropout: float = 0.0,
        layer_norm: bool = False,
    ):
        super().__init__()
        if condition_dim < 0:
            raise ValueError(
                f"condition_dim must be non-negative, got {condition_dim}"
            )
        self.event_shape = normalise_event_shape(dim)
        self.condition_dim = int(condition_dim)
        self.flat_dim = event_numel(self.event_shape)
        self.time_embedding = ScalarTimeEmbedding()
        # Input = flattened event + 1 scalar time feature + optional context.
        in_features = self.flat_dim + 1 + self.condition_dim
        self.backbone = MLPBackbone(
            in_features,
            self.flat_dim,
            hidden=hidden,
            layers=layers,
            activation=activation,
            dropout=dropout,
            layer_norm=layer_norm,
        )

    @property
    def event_ndim(self) -> int:
        """Number of trailing event dimensions handled by this field."""
        return len(self.event_shape)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        c: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Evaluate the velocity at state ``x`` and time ``t``.

        Args:
            x: State tensor whose trailing dims match ``event_shape``.
            t: Time value(s) broadcast over the leading (batch) shape.
            c: Optional conditioning vector; required shape is checked
                by ``validate_context``.

        Returns:
            Velocity tensor with the same event shape as ``x``.
        """
        validate_shapes(x, self.event_ndim, expected_event_shape=self.event_shape)
        flat = flatten_event(x, self.event_ndim)
        batch_shape = tuple(flat.shape[:-1])
        t_emb = self.time_embedding(
            t,
            leading_shape=batch_shape,
            device=x.device,
            dtype=x.dtype,
        )
        validate_context(c, self.condition_dim, batch_shape)
        # One concatenation of all pieces (vs. chained cats) — same result.
        pieces = [flat, t_emb] if c is None else [flat, t_emb, c]
        out = self.backbone(torch.cat(pieces, dim=-1))
        return unflatten_event(out, self.event_shape)