Masked Flow Matching with variable-cardinality sets#

Demonstrate the masked_fm_loss and masked_sample for generating variable-size sets of 2D data samples. Each event has a different numberof objects padded to a common N_max; the mask keeps padding out of the loss and out of the ODE trajectory.

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from nami import StandardNormal, RK4, masked_fm_loss, masked_sample

toy: variable-size particle events#

Each event contains 3-8 “particles” sampled on a circle of radius 1, with a little Gaussian noise. Pad every event to N_MAX = 8 and build a binary mask.

N_MAX = 8      # maximum particles per event
DIM = 2        # features per particle (x, y)
N_EVENTS = 10000 # training set size

torch.manual_seed(0)

def make_event(n_particles, n_max=N_MAX):
    """Return (padded_particles, mask) for one event."""
    theta = torch.linspace(0, 2 * torch.pi, n_particles + 1)[:-1]
    pts = torch.stack([theta.cos(), theta.sin()], dim=-1)  # (n, 2)
    pts = pts + 0.05 * torch.randn_like(pts)               # small noise
    padded = torch.zeros(n_max, DIM)
    padded[:n_particles] = pts
    mask = torch.zeros(n_max)
    mask[:n_particles] = 1.0
    return padded, mask

cardinalities = torch.randint(3, N_MAX + 1, (N_EVENTS,))
events, masks = zip(*[make_event(int(n)) for n in cardinalities])
x_data = torch.stack(events)   # (N_EVENTS, N_MAX, DIM)
mask_data = torch.stack(masks) # (N_EVENTS, N_MAX)

print(f"Data shape : {x_data.shape}")

print(f"Mask shape : {mask_data.shape}")
print(f"Cardinality range: {cardinalities.min().item()} {cardinalities.max().item()}")
Data shape : torch.Size([10000, 8, 2])
Mask shape : torch.Size([10000, 8])
Cardinality range: 3 8

Quick look at a few training events (real particles as dots, padding absent):

fig, axes = plt.subplots(1, 4, figsize=(12, 3))
for ax, idx in zip(axes, range(4)):
    n = int(mask_data[idx].sum())
    pts = x_data[idx, :n]
    ax.scatter(pts[:, 0], pts[:, 1], s=40)
    ax.set_title(f"Event {idx}  (n={n})")
    ax.set_xlim(-1.6, 1.6); ax.set_ylim(-1.6, 1.6)
    ax.set_aspect("equal")
plt.tight_layout()
plt.show()
../../../_images/84a89f8677dfe4a44ddb3da09f35921f55844405b693c75ff64c23877cd10662.png

Velocity field#

A small MLP that flattens the (N_MAX, DIM) event, concatenates time, and predicts a velocity for every particle slot. event_ndim = 2 tells nami the last two dims form one data point.

class ParticleField(nn.Module):
    def __init__(self, n_max, dim, hidden=128):
        super().__init__()
        flat = n_max * dim
        self.net = nn.Sequential(
            nn.Linear(flat + 1, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, flat),
        )
        self._shape = (n_max, dim)

    @property
    def event_ndim(self):
        return 2

    def forward(self, x, t, c=None):
        flat = x.reshape(*x.shape[:-2], -1)
        t_exp = t.unsqueeze(-1).expand(*flat.shape[:-1], 1)
        out = self.net(torch.cat([flat, t_exp], dim=-1))
        return out.reshape_as(x)


field = ParticleField(N_MAX, DIM)
print(f"Parameters: {sum(p.numel() for p in field.parameters()):,}")
Parameters: 20,880

Training with masked_fm_loss#

Padding should never contribute to the gradient.

optimizer = torch.optim.Adam(field.parameters(), lr=1e-3)
steps = 5000
losses = []
for step in range(1, steps + 1):
    idx = torch.randint(0, N_EVENTS, (128,))
    x_target = x_data[idx]
    x_source = torch.randn_like(x_target)
    m = mask_data[idx]

    loss = masked_fm_loss(field, x_target, x_source, m)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if step % 500 == 0 or step == 1:
        print(f"Step {step:>5}/{steps}  loss = {loss.item():.4f}")

plt.plot(losses)
plt.xlabel("Step"); plt.ylabel("Loss")
plt.title("masked_fm_loss training curve")
plt.show()
Step     1/5000  loss = 1.4380
Step   500/5000  loss = 0.3670
Step  1000/5000  loss = 0.2752
Step  1500/5000  loss = 0.3172
Step  2000/5000  loss = 0.3152
Step  2500/5000  loss = 0.3075
Step  3000/5000  loss = 0.3562
Step  3500/5000  loss = 0.3054
Step  4000/5000  loss = 0.2698
Step  4500/5000  loss = 0.2913
Step  5000/5000  loss = 0.2762
../../../_images/873c8d9d7ea941e35efe6d06cb482099e1e1481927af721a71f5476f93c3c1d4.png

sample with masked_sample#

generate events with different cardinalities by passing per-event masks. Padded positions should stay exactly zero throughout integration.

base = StandardNormal(event_shape=(N_MAX, DIM))
solver = RK4(steps=50)

# build masks for 8 samples with cardinalities 3,4,5,6,7,8,5,4
sample_ns = [3, 4, 5, 6, 7, 8, 5, 4]
sample_mask = torch.zeros(len(sample_ns), N_MAX)
for i, n in enumerate(sample_ns):
    sample_mask[i, :n] = 1.0

with torch.no_grad():
    samples = masked_sample(
        field, base, solver, sample_mask,
        sample_shape=(len(sample_ns),),
    )

print(f"Samples shape: {samples.shape}")
print(f"Padding is zero: {torch.all(samples[0, sample_ns[0]:] == 0).item()}")
Samples shape: torch.Size([8, 8, 2])
Padding is zero: True
fig, axes = plt.subplots(2, 4, figsize=(14, 6))
for ax, i in zip(axes.flat, range(len(sample_ns))):
    n = sample_ns[i]
    pts = samples[i, :n].numpy()
    ax.scatter(pts[:, 0], pts[:, 1], s=50, zorder=3)
    ax.set_title(f"Gen.(n={n})")
    ax.set_xlim(-2, 2); ax.set_ylim(-2, 2)
    ax.set_aspect("equal")
    ax.grid(True, alpha=0.3)
plt.suptitle("masked_sample variable-cardinality generation", fontsize=13)
plt.tight_layout()
plt.show()
../../../_images/9112c4c2fcd964ed65402a3a52b6b021ea5563c0e41036ff1c4ca7fb6fe5a358.png