nami.interpolants.transforms

Classes

DriftFromVelocityScore(velocity_model, ...)

Combine velocity and score into probability-flow drift.

MarkovizationDriftFromVelocityScore(...)

Combine velocity and score into markovization SDE drift.

MirrorVelocityFromScore(score_model, ...)

Create mirror-flow velocity v_mirror(x, t) = gamma*gamma_dot*s(x, t).

ScoreFromNoise(eta_model, gamma_schedule[, eps])

Convert a noise-prediction model eta(x, t) into a score model s(x, t).

class nami.interpolants.transforms.DriftFromVelocityScore(velocity_model, score_model, gamma_schedule)

Bases: Module

Combine velocity and score into probability-flow drift.

Computes u(x, t) = v(x, t) - gamma(t) * gamma_dot(t) * s(x, t). For the markovization SDE drift, use MarkovizationDriftFromVelocityScore.
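
A minimal PyTorch sketch of the drift formula, not the library's implementation. ToyGammaSchedule (with gamma(t) = sqrt(t * (1 - t))), ConstantField, per_sample, and DriftSketch are hypothetical names. The sketch assumes that gamma_schedule exposes gamma(t) and gamma_dot(t) as methods and that t holds one time value per batch element, broadcast over the remaining dimensions of x.

```python
import torch
from torch import nn


class ToyGammaSchedule:
    """Hypothetical schedule with gamma(t) = sqrt(t * (1 - t)); not the library's GammaSchedule."""

    def gamma(self, t: torch.Tensor) -> torch.Tensor:
        return torch.sqrt(t * (1.0 - t))

    def gamma_dot(self, t: torch.Tensor) -> torch.Tensor:
        # Derivative of sqrt(t(1 - t)); singular at t = 0 and t = 1.
        return (1.0 - 2.0 * t) / (2.0 * torch.sqrt(t * (1.0 - t)))


class ConstantField(nn.Module):
    """Toy stand-in for a learned field model with signature (x, t, c)."""

    def __init__(self, value: float):
        super().__init__()
        self.value = value

    def forward(self, x, t, c=None):
        return torch.full_like(x, self.value)


def per_sample(coeff: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Reshape a (batch,) coefficient to (batch, 1, ..., 1) so it broadcasts against x.
    return coeff.reshape((-1,) + (1,) * (x.dim() - 1))


class DriftSketch(nn.Module):
    """Sketch of u(x, t) = v(x, t) - gamma(t) * gamma_dot(t) * s(x, t)."""

    def __init__(self, velocity_model, score_model, gamma_schedule):
        super().__init__()
        self.velocity_model = velocity_model
        self.score_model = score_model
        self.gamma_schedule = gamma_schedule

    def forward(self, x, t, c=None):
        v = self.velocity_model(x, t, c)
        s = self.score_model(x, t, c)
        gg_dot = self.gamma_schedule.gamma(t) * self.gamma_schedule.gamma_dot(t)
        return v - per_sample(gg_dot, x) * s


drift = DriftSketch(ConstantField(1.0), ConstantField(2.0), ToyGammaSchedule())
x = torch.zeros(4, 3)
t = torch.full((4,), 0.25)
u = drift(x, t)  # gamma * gamma_dot = 0.25 at t = 0.25, so every entry is 1 - 0.25 * 2 = 0.5
```

Holding both fields as attributes of one nn.Module registers them as submodules, so calls such as .to(device) or .parameters() on the wrapper reach both models.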

Parameters:
  • velocity_model (Module) – The velocity field model v(x, t, c).

  • score_model (Module) – The score field model s(x, t, c).

  • gamma_schedule (GammaSchedule) – The noise schedule providing gamma(t) and gamma_dot(t).

velocity_model

The velocity field model.

Type:

nn.Module

score_model

The score field model.

Type:

nn.Module

gamma_schedule

The noise schedule.

Type:

GammaSchedule

property event_ndim: int | None
forward(x, t, c=None)

Compute the probability-flow drift.

Parameters:
  • x (Tensor) – The state variable.

  • t (Tensor) – The time variable.

  • c (Tensor | None) – Optional conditioning information, by default None.

Returns:

The drift value u(x, t) = v(x, t) - gamma*gamma_dot*s(x, t).

Return type:

Tensor

class nami.interpolants.transforms.MarkovizationDriftFromVelocityScore(velocity_model, score_model, gamma_schedule, *, diffusion2)

Bases: Module

Combine velocity and score into markovization SDE drift.

Computes b(x, t) = u(x, t) + 0.5 * g(t)^2 * s(x, t), where u(x, t) = v(x, t) - gamma(t) * gamma_dot(t) * s(x, t). Equivalently: b(x, t) = v(x, t) + (-gamma*gamma_dot + 0.5*g^2) * s(x, t).

diffusion2 is the squared diffusion coefficient g(t)^2, given as a constant float or a plain callable (Tensor) -> Tensor. It is not registered as an nn.Module submodule, so learnable diffusion schedules should be managed separately.
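
A minimal PyTorch sketch of how a constant float and a callable diffusion2 can both be turned into a tensor g(t)^2, and how the resulting drift could drive one Euler-Maruyama step of dX = b(X, t) dt + g(t) dW. ToyGammaSchedule, ConstantField, MarkovizationDriftSketch, and euler_maruyama_step are hypothetical names, not part of nami. The sketch assumes gamma_schedule exposes gamma(t) and gamma_dot(t) methods and that t holds one value per batch element; it stores diffusion2 as an ordinary attribute and does not handle nn.Module inputs.

```python
from typing import Callable, Union

import torch
from torch import nn


class ToyGammaSchedule:
    """Hypothetical schedule with gamma(t) = sqrt(t * (1 - t)); not the library's GammaSchedule."""

    def gamma(self, t):
        return torch.sqrt(t * (1.0 - t))

    def gamma_dot(self, t):
        return (1.0 - 2.0 * t) / (2.0 * torch.sqrt(t * (1.0 - t)))


class ConstantField(nn.Module):
    """Toy stand-in for a learned field model with signature (x, t, c)."""

    def __init__(self, value: float):
        super().__init__()
        self.value = value

    def forward(self, x, t, c=None):
        return torch.full_like(x, self.value)


class MarkovizationDriftSketch(nn.Module):
    """Sketch of b(x, t) = v(x, t) + (-gamma * gamma_dot + 0.5 * g^2) * s(x, t)."""

    def __init__(self, velocity_model, score_model, gamma_schedule, *,
                 diffusion2: Union[float, Callable[[torch.Tensor], torch.Tensor]]):
        super().__init__()
        self.velocity_model = velocity_model
        self.score_model = score_model
        self.gamma_schedule = gamma_schedule
        self.diffusion2 = diffusion2  # plain float or (Tensor) -> Tensor callable

    def g2(self, t):
        # Evaluate g(t)^2 as a tensor shaped like t, whether diffusion2 is constant or callable.
        if callable(self.diffusion2):
            return self.diffusion2(t)
        return torch.full_like(t, float(self.diffusion2))

    def forward(self, x, t, c=None):
        v = self.velocity_model(x, t, c)
        s = self.score_model(x, t, c)
        coeff = -self.gamma_schedule.gamma(t) * self.gamma_schedule.gamma_dot(t) + 0.5 * self.g2(t)
        return v + coeff.reshape((-1,) + (1,) * (x.dim() - 1)) * s


def euler_maruyama_step(drift, x, t, dt, generator=None):
    # One step of dX = b(X, t) dt + g(t) dW with g(t) = sqrt(diffusion2(t)).
    g = drift.g2(t).sqrt().reshape((-1,) + (1,) * (x.dim() - 1))
    noise = torch.randn(x.shape, generator=generator, dtype=x.dtype)
    return x + drift(x, t) * dt + g * (dt ** 0.5) * noise


v, s, sched = ConstantField(1.0), ConstantField(2.0), ToyGammaSchedule()
x = torch.zeros(4, 3)
t = torch.full((4,), 0.25)
b_const = MarkovizationDriftSketch(v, s, sched, diffusion2=0.5)(x, t)  # coeff = -0.25 + 0.25 = 0, so b = 1.0
b_call = MarkovizationDriftSketch(v, s, sched, diffusion2=lambda tau: 4.0 * tau)(x, t)  # coeff = 0.25, so b = 1.5
x_next = euler_maruyama_step(MarkovizationDriftSketch(v, s, sched, diffusion2=0.5), x, t, dt=0.01,
                             generator=torch.Generator().manual_seed(0))
```

With diffusion2 = 0 the step reduces to an explicit Euler step of the probability-flow ODE, since the 0.5 * g^2 * s term and the noise term both vanish.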

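Parameters:
  • velocity_model (Module) – The velocity field model v(x, t, c).

  • score_model (Module) – The score field model s(x, t, c).

  • gamma_schedule (GammaSchedule) – The noise schedule providing gamma(t) and gamma_dot(t).

  • diffusion2 (float | Callable[[Tensor], Tensor]) – Keyword-only. The squared diffusion coefficient g(t)^2.
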
property event_ndim: int | None
forward(x, t, c=None)

Compute the markovization SDE drift.

Parameters:
  • x (Tensor) – The state variable.

  • t (Tensor) – The time variable.

  • c (Tensor | None) – Optional conditioning information, by default None.

Returns:

The drift value b(x, t) = v(x, t) + (-gamma*gamma_dot + 0.5*g^2) * s(x, t).

Return type:

Tensor

Note

Call the module instance (for example module(x, t, c)) rather than forward() directly: the instance call runs any registered hooks, whereas calling forward() directly skips them.

class nami.interpolants.transforms.MirrorVelocityFromScore(score_model, gamma_schedule)

Bases: Module

Create mirror-flow velocity v_mirror(x, t) = gamma*gamma_dot*s(x, t).
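
A minimal PyTorch sketch of the mirror-flow formula, not the library's implementation. ToyGammaSchedule (with gamma(t) = sqrt(t * (1 - t))), ConstantField, and MirrorVelocitySketch are hypothetical; the sketch assumes gamma_schedule exposes gamma(t) and gamma_dot(t) methods and that t holds one value per batch element.

```python
import torch
from torch import nn


class ToyGammaSchedule:
    """Hypothetical schedule with gamma(t) = sqrt(t * (1 - t)); not the library's GammaSchedule."""

    def gamma(self, t):
        return torch.sqrt(t * (1.0 - t))

    def gamma_dot(self, t):
        return (1.0 - 2.0 * t) / (2.0 * torch.sqrt(t * (1.0 - t)))


class ConstantField(nn.Module):
    """Toy stand-in for a learned score model with signature (x, t, c)."""

    def __init__(self, value: float):
        super().__init__()
        self.value = value

    def forward(self, x, t, c=None):
        return torch.full_like(x, self.value)


class MirrorVelocitySketch(nn.Module):
    """Sketch of v_mirror(x, t) = gamma(t) * gamma_dot(t) * s(x, t)."""

    def __init__(self, score_model, gamma_schedule):
        super().__init__()
        self.score_model = score_model
        self.gamma_schedule = gamma_schedule

    def forward(self, x, t, c=None):
        s = self.score_model(x, t, c)
        gg_dot = self.gamma_schedule.gamma(t) * self.gamma_schedule.gamma_dot(t)
        # Broadcast the per-sample coefficient over the non-batch dimensions of x.
        return gg_dot.reshape((-1,) + (1,) * (x.dim() - 1)) * s


mirror = MirrorVelocitySketch(ConstantField(2.0), ToyGammaSchedule())
t = torch.tensor([0.25, 0.75])
v_mirror = mirror(torch.zeros(2, 3), t)  # gamma * gamma_dot = 0.5 * (1 - 2t): rows are 0.5 and -0.5
```

For this toy schedule gamma*gamma_dot equals 0.5 * (1 - 2t), so the mirror velocity changes sign at t = 0.5.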

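Parameters:
  • score_model (Module) – The score field model s(x, t, c).

  • gamma_schedule (GammaSchedule) – The noise schedule providing gamma(t) and gamma_dot(t).
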
property event_ndim: int | None
forward(x, t, c=None)

Compute the mirror-flow velocity.

Parameters:
  • x (Tensor) – The state variable.

  • t (Tensor) – The time variable.

  • c (Tensor | None) – Optional conditioning information, by default None.

Returns:

The velocity v_mirror(x, t) = gamma*gamma_dot*s(x, t).

Return type:

Tensor

class nami.interpolants.transforms.ScoreFromNoise(eta_model, gamma_schedule, eps=1e-12)

Bases: Module

Convert a noise-prediction model eta(x, t) into a score model s(x, t).

Parameters:
  • eta_model (Module) – The noise prediction model that takes (x, t, c) and returns noise.

  • gamma_schedule (GammaSchedule) – The noise schedule for converting noise to score.

  • eps (float) – Small epsilon value to prevent division by zero, by default 1e-12.
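
The exact conversion is not spelled out here. The PyTorch sketch below assumes the standard Gaussian-noise relation s(x, t) = -eta(x, t) / gamma(t), with gamma clamped below by eps; the real ScoreFromNoise may apply eps differently (for example gamma + eps). ToyGammaSchedule (with gamma(t) = sqrt(t * (1 - t))), ConstantNoise, and ScoreFromNoiseSketch are hypothetical names.

```python
import torch
from torch import nn


class ToyGammaSchedule:
    """Hypothetical schedule with gamma(t) = sqrt(t * (1 - t)); not the library's GammaSchedule."""

    def gamma(self, t):
        return torch.sqrt(t * (1.0 - t))


class ConstantNoise(nn.Module):
    """Toy stand-in for a noise-prediction model eta(x, t, c)."""

    def forward(self, x, t, c=None):
        return torch.ones_like(x)


class ScoreFromNoiseSketch(nn.Module):
    """Sketch assuming s(x, t) = -eta(x, t) / max(gamma(t), eps)."""

    def __init__(self, eta_model, gamma_schedule, eps: float = 1e-12):
        super().__init__()
        self.eta_model = eta_model
        self.gamma_schedule = gamma_schedule
        self.eps = eps

    def forward(self, x, t, c=None):
        eta = self.eta_model(x, t, c)
        # Clamp gamma away from zero so the division stays finite where gamma(t) = 0.
        gamma = self.gamma_schedule.gamma(t).clamp(min=self.eps)
        return -eta / gamma.reshape((-1,) + (1,) * (x.dim() - 1))


score = ScoreFromNoiseSketch(ConstantNoise(), ToyGammaSchedule())
t = torch.tensor([0.25, 0.0])
s = score(torch.zeros(2, 3), t)  # row 0: -1 / sqrt(0.1875), about -2.309; row 1: -1 / eps, large but finite
```

At t = 0 the toy schedule gives gamma = 0, so the clamp is what keeps the second row finite; the result is still of order 1/eps, which is why endpoint times usually deserve special care.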

eta_model

The wrapped noise prediction model.

Type:

nn.Module

gamma_schedule

The noise schedule.

Type:

GammaSchedule

eps

Numerical stability constant.

Type:

float

property event_ndim: int | None
forward(x, t, c=None)

Compute the score s(x, t) from the noise prediction eta(x, t, c).

Parameters:
  • x (Tensor) – The state variable.

  • t (Tensor) – The time variable.

  • c (Tensor | None) – Optional conditioning information, by default None.

Returns:

The score value s(x, t).

Return type:

Tensor