nami.fields.transformer
Classes
TransformerVelocityField: Transformer velocity field over flattened event tokens.
- class nami.fields.transformer.TransformerVelocityField(dim, *, model_dim=128, depth=4, num_heads=4, time_dim=32, condition_dim=0, mlp_ratio=4.0, dropout=0.0, activation='gelu')
Bases: VectorField
Transformer velocity field over flattened event tokens.
Each scalar feature in the flattened event is treated as one token. Time is embedded once per sample, then broadcast across all tokens. Optional context is projected to a single cross-attention token, so conditioning adds only a cost linear in the number of tokens. Because every scalar feature is its own token, however, self-attention cost scales quadratically with the flattened event size.
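The token layout described above can be illustrated with a minimal pure-Python sketch. It is not the library's implementation: the sinusoidal time embedding, the additive feature-index embedding, single-head attention, the residual addition of the cross-attention output, and the scalar readout are all assumptions chosen for illustration, random weights stand in for learned parameters, and `sketch_forward` is a hypothetical name. The function also counts query-key scores so the quadratic self-attention cost and the linear cross-attention cost are visible.

```python
import math
import random


def linear(vec, weight):
    # weight is a list of rows; returns the matrix-vector product weight @ vec.
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


def random_matrix(rows, cols, rng):
    # Random stand-in for a learned weight matrix (assumption, not trained parameters).
    return [[rng.gauss(0.0, 1.0 / math.sqrt(cols)) for _ in range(cols)] for _ in range(rows)]


def time_embedding(t, time_dim):
    # Sinusoidal embedding of scalar time t (assumed; the real class may differ).
    half = time_dim // 2
    freqs = [math.exp(-math.log(10000.0) * i / half) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]


def attention(queries, keys):
    # Single-head scaled dot-product attention with values equal to keys.
    # Returns the outputs and the number of query-key scores computed.
    d = len(queries[0])
    outputs, n_scores = [], 0
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
        n_scores += len(scores)
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        outputs.append([sum(w * k[j] for w, k in zip(weights, keys)) / total for j in range(d)])
    return outputs, n_scores


def sketch_forward(x, t, c=None, model_dim=8, time_dim=4, seed=0):
    """Hypothetical sketch: one token per scalar in x, shared time embedding, one context token."""
    rng = random.Random(seed)
    dim = len(x)
    value_proj = [rng.gauss(0.0, 1.0) for _ in range(model_dim)]  # scalar value -> model_dim
    feature_emb = random_matrix(dim, model_dim, rng)              # one row per feature index
    time_proj = random_matrix(model_dim, time_dim, rng)           # time_dim -> model_dim

    # Time is embedded once per sample ...
    t_emb = linear(time_embedding(t, time_dim), time_proj)
    # ... and broadcast (added) to every feature token.
    tokens = [[xi * w + e + te for w, e, te in zip(value_proj, feature_emb[i], t_emb)]
              for i, xi in enumerate(x)]

    # Self-attention over all feature tokens computes dim * dim scores.
    tokens, self_scores = attention(tokens, tokens)

    cross_scores = 0
    if c:
        # Context is projected to a single token, so cross-attention computes only dim scores.
        context_token = linear(c, random_matrix(model_dim, len(c), rng))
        cross, cross_scores = attention(tokens, [context_token])
        tokens = [[a + b for a, b in zip(tok, ctx)] for tok, ctx in zip(tokens, cross)]

    # Read out one scalar per token, so the velocity has the same size as x.
    readout = [rng.gauss(0.0, 1.0 / math.sqrt(model_dim)) for _ in range(model_dim)]
    velocity = [sum(r * h for r, h in zip(readout, tok)) for tok in tokens]
    return velocity, self_scores, cross_scores
```

For example, `sketch_forward([0.1, -0.3, 0.5, 0.2], t=0.5, c=[1.0, 0.0])` returns a 4-element velocity together with 16 self-attention scores and 4 cross-attention scores; doubling the event to 8 scalars gives 64 and 8.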
- Parameters:
- forward(x, t, c=None)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
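The difference between calling the instance and calling `forward` directly can be illustrated with a toy class. `ToyModule`, `ToyField`, and `add_forward_hook` are hypothetical stand-ins that only mimic the documented behavior (the instance call runs registered hooks around `forward`); this is not PyTorch's implementation, and the placeholder computation in `ToyField.forward` is not a real velocity field.

```python
class ToyModule:
    """Hypothetical stand-in for the hook-running call pattern; not torch.nn.Module."""

    def __init__(self):
        self._forward_hooks = []

    def add_forward_hook(self, hook):
        # hook(module, inputs, output) runs after forward when the instance is called.
        self._forward_hooks.append(hook)

    def __call__(self, *args, **kwargs):
        output = self.forward(*args, **kwargs)
        for hook in self._forward_hooks:
            hook(self, args, output)
        return output

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must override forward")


class ToyField(ToyModule):
    def forward(self, x, t, c=None):
        # Placeholder computation standing in for a velocity field.
        return [xi * t for xi in x]


calls = []
field = ToyField()
field.add_forward_hook(lambda module, inputs, output: calls.append(output))

field([1.0, 2.0], 0.5)          # runs forward, then the hook
field.forward([1.0, 2.0], 0.5)  # runs forward only; the hook is skipped
print(len(calls))  # 1
```

Only the instance call reached the hook, which is why the note recommends `field(x, t, c)` over `field.forward(x, t, c)`.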