nami.fields.transformer

Classes

TransformerVelocityField(dim, *[, ...])

Transformer velocity field over flattened event tokens.

class nami.fields.transformer.TransformerVelocityField(dim, *, model_dim=128, depth=4, num_heads=4, time_dim=32, condition_dim=0, mlp_ratio=4.0, dropout=0.0, activation='gelu')

Bases: VectorField

Transformer velocity field over flattened event tokens.

Each scalar feature in the flattened event is treated as one token. Time is embedded once per sample and then broadcast across all tokens. Optional context is projected to a single cross-attention token. Because every scalar feature is its own token, attention cost scales quadratically with the flattened event size.
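The tokenization described above can be sketched in NumPy as pure shape bookkeeping. This is a minimal illustration under stated assumptions, not nami's implementation: the sinusoidal time embedding, the additive combination of time and token embeddings, the per-position embedding, and the helper names sinusoidal_time_embedding and build_tokens are all hypothetical, and random matrices stand in for learned weights. Only the shapes follow the description: N scalar tokens, one time embedding per sample broadcast over them, and one projected context token.

```python
import numpy as np


def sinusoidal_time_embedding(t, time_dim):
    """Hypothetical time embedding: map t of shape (B,) to (B, time_dim)."""
    half = time_dim // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / half)
    angles = t[:, None] * freqs[None, :]            # (B, half)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)


def build_tokens(x, t, c=None, model_dim=128, time_dim=32, seed=0):
    """Hypothetical tokenization: one token per scalar, time broadcast, one context token."""
    rng = np.random.default_rng(seed)               # random stand-ins for learned weights
    batch = x.shape[0]
    flat = x.reshape(batch, -1, 1)                  # (B, N, 1): each scalar is one token
    n_tokens = flat.shape[1]

    w_in = rng.standard_normal((1, model_dim))
    tokens = flat @ w_in                            # (B, N, model_dim)

    # Assumed: a per-position embedding so the model can tell features apart.
    pos = rng.standard_normal((n_tokens, model_dim))
    tokens = tokens + pos

    # Time is embedded once per sample, then broadcast across all N tokens.
    w_t = rng.standard_normal((time_dim, model_dim))
    temb = sinusoidal_time_embedding(t, time_dim) @ w_t   # (B, model_dim)
    tokens = tokens + temb[:, None, :]

    ctx = None
    if c is not None:
        # Context of shape (B, condition_dim) becomes a single token.
        w_c = rng.standard_normal((c.shape[-1], model_dim))
        ctx = (c @ w_c)[:, None, :]                 # (B, 1, model_dim)
    return tokens, ctx


x = np.zeros((2, 4, 3))          # batch of 2 events, 12 scalars each
t = np.array([0.1, 0.9])         # one time per sample
c = np.ones((2, 5))              # condition_dim = 5 (arbitrary)
tokens, ctx = build_tokens(x, t, c)
n = tokens.shape[1]
scores_per_head = n * n          # self-attention scores per head, per layer
```

With the event shape (4, 3) above there are N = 12 tokens, so each attention head forms 12 × 12 = 144 scores per layer; doubling the number of scalar features to 24 quadruples that to 576.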

Parameters:
property event_ndim: int
forward(x, t, c=None)

Evaluate the velocity field for the flattened event x at time t, optionally conditioned on context c.

Note

Although the forward pass is defined in this method, call the module instance, e.g. field(x, t, c), rather than calling forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.
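The hook behaviour described in the Note is standard torch.nn.Module semantics and can be checked with a toy module. ToyField below is a hypothetical stand-in that only mimics the forward(x, t, c=None) signature; it is not TransformerVelocityField, and the tensor shapes are arbitrary.

```python
import torch
from torch import nn


class ToyField(nn.Module):
    """Hypothetical stand-in with the same forward(x, t, c=None) signature."""

    def forward(self, x, t, c=None):
        return -x * t  # placeholder dynamics, not a real velocity field


calls = []
field = ToyField()
# Forward hooks are called as hook(module, positional_args, output).
field.register_forward_hook(lambda module, args, output: calls.append("hook"))

x = torch.randn(2, 3)
t = torch.full((2, 1), 0.5)

v1 = field(x, t)          # __call__ runs the registered hook around forward()
v2 = field.forward(x, t)  # direct call bypasses the hook machinery
```

Both calls return the same tensor, but calls records the hook only once, for the field(x, t) call.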

Return type:

Tensor

Parameters: