nami.interpolants.transforms
Classes
DriftFromVelocityScore | Combine velocity and score into probability-flow drift.
MarkovizationDriftFromVelocityScore | Combine velocity and score into markovization SDE drift.
MirrorVelocityFromScore | Create mirror-flow velocity v_mirror(x, t) = gamma*gamma_dot*s(x, t).
ScoreFromNoise | Convert a noise-prediction model eta(x, t) into a score model s(x, t).
- class nami.interpolants.transforms.DriftFromVelocityScore(velocity_model, score_model, gamma_schedule)
Bases: Module
Combine velocity and score into probability-flow drift.
Computes u(x, t) = v(x, t) - gamma(t) * gamma_dot(t) * s(x, t).
For the markovization SDE drift, use MarkovizationDriftFromVelocityScore.
- Parameters:
velocity_model (Module) – The velocity field model v(x, t, c).
score_model (Module) – The score field model s(x, t, c).
gamma_schedule (GammaSchedule) – The noise schedule providing gamma(t) and gamma_dot(t).
- velocity_model
The velocity field model.
- Type:
nn.Module
- score_model
The score field model.
- Type:
nn.Module
- gamma_schedule
The noise schedule.
- Type:
GammaSchedule
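As a rough illustration of the formula documented for DriftFromVelocityScore, the sketch below evaluates u(x, t) for toy fields. The function name probability_flow_drift, the toy fields, and the representation of the schedule as two plain callables are assumptions made for this example; the actual GammaSchedule interface is not shown on this page. Only arithmetic operators are used, so the same function would also accept tensors in place of floats.

```python
def probability_flow_drift(velocity, score, gamma, gamma_dot, x, t, c=None):
    """Return u(x, t) = v(x, t) - gamma(t) * gamma_dot(t) * s(x, t)."""
    return velocity(x, t, c) - gamma(t) * gamma_dot(t) * score(x, t, c)


# Toy fields, chosen only for illustration (hypothetical, not from the library).
def velocity(x, t, c=None):
    return 2.0 * x


def score(x, t, c=None):
    return -x  # score of a standard Gaussian density


def gamma(t):
    return t * (1.0 - t)  # toy schedule that vanishes at t = 0 and t = 1


def gamma_dot(t):
    return 1.0 - 2.0 * t  # time derivative of the toy schedule


u = probability_flow_drift(velocity, score, gamma, gamma_dot, x=1.0, t=0.25)
print(u)  # 2.0 - 0.1875 * 0.5 * (-1.0) = 2.09375
```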
- class nami.interpolants.transforms.MarkovizationDriftFromVelocityScore(velocity_model, score_model, gamma_schedule, *, diffusion2)
Bases: Module
Combine velocity and score into markovization SDE drift.
Computes b(x, t) = u(x, t) + 0.5 * g(t)^2 * s(x, t), where u(x, t) = v(x, t) - gamma(t) * gamma_dot(t) * s(x, t).
Equivalently: b(x, t) = v(x, t) + (-gamma(t) * gamma_dot(t) + 0.5 * g(t)^2) * s(x, t).
diffusion2 is the squared diffusion coefficient g(t)^2, given as a constant float or a plain callable (Tensor) -> Tensor. It is not registered as an nn.Module submodule, so any parameters of a learnable diffusion schedule are not tracked by this module and should be managed separately.
- Parameters:
velocity_model (nn.Module)
score_model (nn.Module)
gamma_schedule (GammaSchedule)
diffusion2 (float | Callable[[torch.Tensor], torch.Tensor])
- forward(x, t, c=None)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
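The single-expression form of b(x, t) and the two ways of supplying diffusion2 can be sketched as follows. This is a minimal, dependency-free illustration, not the library's implementation: markovization_drift and the toy fields are hypothetical names, the schedule is represented by two plain callables, and floats stand in for tensors.

```python
def markovization_drift(velocity, score, gamma, gamma_dot, diffusion2, x, t, c=None):
    """Return b(x, t) = v + (-gamma * gamma_dot + 0.5 * g^2) * s."""
    # diffusion2 is g(t)^2: either a constant or a callable of t.
    g2 = diffusion2(t) if callable(diffusion2) else diffusion2
    coeff = -gamma(t) * gamma_dot(t) + 0.5 * g2
    return velocity(x, t, c) + coeff * score(x, t, c)


# Toy fields, chosen only for illustration (hypothetical, not from the library).
def velocity(x, t, c=None):
    return 2.0 * x


def score(x, t, c=None):
    return -x  # score of a standard Gaussian density


def gamma(t):
    return t * (1.0 - t)  # toy schedule


def gamma_dot(t):
    return 1.0 - 2.0 * t  # its time derivative


# Constant g^2 = 0.5 and a callable g^2(t) = 2t agree at t = 0.25.
b_const = markovization_drift(velocity, score, gamma, gamma_dot, 0.5, x=1.0, t=0.25)
b_call = markovization_drift(velocity, score, gamma, gamma_dot, lambda t: 2.0 * t, x=1.0, t=0.25)
print(b_const, b_call)  # both 1.84375
```

Evaluating the score once and folding both terms into one coefficient matches the "Equivalently" form in the class description; it gives the same value as first computing u(x, t) and then adding 0.5 * g(t)^2 * s(x, t).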
- class nami.interpolants.transforms.MirrorVelocityFromScore(score_model, gamma_schedule)
Bases: Module
Create mirror-flow velocity v_mirror(x, t) = gamma(t) * gamma_dot(t) * s(x, t).
- Parameters:
score_model (nn.Module)
gamma_schedule (GammaSchedule)
- forward(x, t, c=None)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
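A small sketch of the mirror-flow velocity, under the same simplifying assumptions as elsewhere in these examples: mirror_velocity and the toy fields are hypothetical, the schedule is given as two plain callables, and floats stand in for tensors. Comparing with the formula documented for DriftFromVelocityScore, the probability-flow drift equals the velocity minus this mirror term.

```python
def mirror_velocity(score, gamma, gamma_dot, x, t, c=None):
    """Return v_mirror(x, t) = gamma(t) * gamma_dot(t) * s(x, t)."""
    return gamma(t) * gamma_dot(t) * score(x, t, c)


# Toy fields, chosen only for illustration (hypothetical, not from the library).
def score(x, t, c=None):
    return -x  # score of a standard Gaussian density


def gamma(t):
    return t * (1.0 - t)  # toy schedule


def gamma_dot(t):
    return 1.0 - 2.0 * t  # its time derivative


v_mirror = mirror_velocity(score, gamma, gamma_dot, x=1.0, t=0.25)
print(v_mirror)  # 0.1875 * 0.5 * (-1.0) = -0.09375

# With a toy velocity v(x, t) = 2x, the probability-flow drift is v - v_mirror.
u = 2.0 * 1.0 - v_mirror
print(u)  # 2.09375
```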
- class nami.interpolants.transforms.ScoreFromNoise(eta_model, gamma_schedule, eps=1e-12)
Bases: Module
Convert a noise-prediction model eta(x, t) into a score model s(x, t).
- Parameters:
eta_model (Module) – The noise prediction model that takes (x, t, c) and returns noise.
gamma_schedule (GammaSchedule) – The noise schedule for converting noise to score.
eps (float) – Small epsilon value to prevent division by zero, by default 1e-12.
- eta_model
The wrapped noise prediction model.
- Type:
nn.Module
- gamma_schedule
The noise schedule.
- Type:
GammaSchedule
- forward(x, t, c=None)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
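This page does not state the conversion formula, so the sketch below rests on two assumptions: that the score follows the usual stochastic-interpolant relation s(x, t) = -eta(x, t) / gamma(t), and that eps is simply added to the denominator. The library may guard the division differently (for example by clamping). The names score_from_noise and the toy fields are hypothetical, and floats stand in for tensors.

```python
import math


def score_from_noise(eta, gamma, x, t, c=None, eps=1e-12):
    """Assumed relation: s(x, t) = -eta(x, t) / (gamma(t) + eps)."""
    # Adding eps keeps the result finite where gamma(t) = 0 (assumed guard).
    return -eta(x, t, c) / (gamma(t) + eps)


# Toy fields, chosen only for illustration (hypothetical, not from the library).
def eta(x, t, c=None):
    return 0.5 * x  # stand-in for a noise-prediction model


def gamma(t):
    return t * (1.0 - t)  # toy schedule that vanishes at t = 0 and t = 1


s_mid = score_from_noise(eta, gamma, x=1.0, t=0.25)   # close to -0.5 / 0.1875
s_edge = score_from_noise(eta, gamma, x=1.0, t=0.0)   # large in magnitude but finite
print(s_mid, math.isfinite(s_edge))
```

Under this assumption the score grows without bound in magnitude as gamma(t) approaches zero, which is why the endpoints of the schedule need some form of guard.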