nami.core.typing

Classes

DivergenceEstimator(*args, **kwargs)

Interface for computing the divergence of the velocity field in log-likelihood calculations.

NoiseSchedule(*args, **kwargs)

Interface for diffusion-model noise schedules with forward process x_t = alpha(t) * x_0 + sigma(t) * epsilon.

ODESolver(*args, **kwargs)

Interface for ODE integrators.

ProbabilityPath(*args, **kwargs)

Interface for interpolation (probability) paths in flow-matching (FM) models.

SDESolver(*args, **kwargs)

Interface for SDE integrators.

class nami.core.typing.DivergenceEstimator(*args, **kwargs)

Bases: Protocol

Interface for computing the divergence of the velocity field, used when evaluating log-likelihoods via the change-of-variables formula.
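The method names of this protocol are not listed on this page, so the following is only a sketch of what a divergence estimator computes, written as free functions with hypothetical signatures and NumPy arrays standing in for Tensor. For a flow dx/dt = v(x, t), the change-of-variables formula gives d log p(x_t)/dt = -div v(x_t, t). An exact estimator sums the Jacobian diagonal one coordinate at a time (D Jacobian-vector products); Hutchinson's estimator averages eps^T J eps over random Rademacher probes eps, which is unbiased and costs one product per probe. As an assumption to avoid an autodiff dependency, Jacobian-vector products are approximated here by central finite differences; a Tensor implementation would normally use autodiff instead.

```python
import numpy as np
from typing import Callable, Optional

# Hypothetical velocity-field signature: v(x, t) -> dx/dt, x of shape (N, D).
# Assumes v acts on each row (sample) independently, as batched fields usually do.
VelocityField = Callable[[np.ndarray, float], np.ndarray]


def _jvp(v: VelocityField, x: np.ndarray, t: float,
         direction: np.ndarray, h: float) -> np.ndarray:
    """Approximate the Jacobian-vector product J(x) @ direction by central differences."""
    return (v(x + h * direction, t) - v(x - h * direction, t)) / (2.0 * h)


def exact_divergence(v: VelocityField, x: np.ndarray, t: float,
                     h: float = 1e-4) -> np.ndarray:
    """Trace of dv/dx for every row of x, one basis direction at a time."""
    n, d = x.shape
    div = np.zeros(n)
    for i in range(d):
        e_i = np.zeros(d)
        e_i[i] = 1.0
        div += _jvp(v, x, t, e_i, h)[:, i]  # i-th diagonal entry of the Jacobian
    return div


def hutchinson_divergence(v: VelocityField, x: np.ndarray, t: float,
                          n_probes: int = 1, h: float = 1e-4,
                          rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Unbiased stochastic estimate E[eps^T J eps] using Rademacher probes eps."""
    rng = np.random.default_rng() if rng is None else rng
    est = np.zeros(x.shape[0])
    for _ in range(n_probes):
        eps = rng.choice([-1.0, 1.0], size=x.shape)  # one probe per sample
        est += np.sum(eps * _jvp(v, x, t, eps, h), axis=1)
    return est / n_probes


# Usage: the diagonal linear field v(y) = y * scales has divergence sum(scales) = 2.5.
scales = np.array([1.0, 2.0, -0.5])
points = np.random.default_rng(0).normal(size=(4, 3))
print(exact_divergence(lambda y, t: y * scales, points, 0.0))
```

For fields with off-diagonal Jacobian entries, a single Hutchinson probe is noisy per sample, so n_probes trades cost for variance.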

class nami.core.typing.NoiseSchedule(*args, **kwargs)

Bases: Protocol

Interface for diffusion-model noise schedules with forward process x_t = alpha(t) * x_0 + sigma(t) * epsilon.

Methods:

  • alpha(t): signal scaling coefficient at time t

  • sigma(t): noise scaling coefficient at time t

  • snr(t): signal-to-noise ratio, alpha(t)^2 / sigma(t)^2

  • drift(x, t): the drift term f(x, t) of the SDE dx = f(x, t) dt + g(t) dW

  • diffusion(t): the diffusion coefficient g(t)
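The five methods are not independent when the schedule's SDE is linear in x, as in the common variance-preserving and variance-exploding families (this page does not say whether nami assumes this). Writing f(x, t) = f(t) x and assuming alpha(0) = 1, sigma(0) = 0 and Gaussian noise, matching the mean and variance of x_t fixes drift and diffusion in terms of alpha and sigma:

```latex
x_t = \alpha(t)\,x_0 + \sigma(t)\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)

\mathrm{snr}(t) = \frac{\alpha(t)^2}{\sigma(t)^2}

f(x, t) = \frac{\mathrm{d}\log\alpha(t)}{\mathrm{d}t}\,x,
\qquad
g(t)^2 = \frac{\mathrm{d}\sigma(t)^2}{\mathrm{d}t} - 2\,\frac{\mathrm{d}\log\alpha(t)}{\mathrm{d}t}\,\sigma(t)^2
```

The drift relation follows from the mean obeying dm/dt = f(t) m, and the diffusion relation from the variance obeying dv/dt = 2 f(t) v + g(t)^2 with v = sigma(t)^2.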

alpha(t)
Return type:

Tensor

Parameters:

t (Tensor)

diffusion(t)
Return type:

Tensor

Parameters:

t (Tensor)

drift(x, t)
Return type:

Tensor

Parameters:
sigma(t)
Return type:

Tensor

Parameters:

t (Tensor)

snr(t)
Return type:

Tensor

Parameters:

t (Tensor)
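As an illustration of a class that provides these five methods, here is a sketch of a variance-preserving schedule with a linearly increasing beta(t). The class name LinearVPSchedule and its default beta_min/beta_max are hypothetical choices for the example, not part of nami, and NumPy arrays or floats stand in for Tensor. With beta(t) = beta_min + t (beta_max - beta_min), alpha(t) = exp(-0.5 * integral of beta from 0 to t) and sigma(t) = sqrt(1 - alpha(t)^2), the drift is -0.5 beta(t) x and the diffusion is sqrt(beta(t)), which keeps alpha^2 + sigma^2 = 1 at every t.

```python
import numpy as np


class LinearVPSchedule:
    """Hypothetical variance-preserving schedule with linear beta(t) on t in [0, 1].

    t may be a float or an array that broadcasts against x.
    """

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0):
        self.beta_min = beta_min
        self.beta_max = beta_max

    def beta(self, t):
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def alpha(self, t):
        # integral_0^t beta(s) ds for the linear beta above
        integral = self.beta_min * t + 0.5 * (self.beta_max - self.beta_min) * t ** 2
        return np.exp(-0.5 * integral)

    def sigma(self, t):
        return np.sqrt(1.0 - self.alpha(t) ** 2)

    def snr(self, t):
        # alpha^2 / sigma^2; diverges at t = 0, where sigma(0) = 0
        return self.alpha(t) ** 2 / self.sigma(t) ** 2

    def drift(self, x, t):
        # f(x, t) = (d log alpha / dt) * x = -0.5 * beta(t) * x
        return -0.5 * self.beta(t) * x

    def diffusion(self, t):
        # g(t) = sqrt(beta(t)) keeps the marginal variance alpha^2 + sigma^2 equal to 1
        return np.sqrt(self.beta(t))


# Usage: draw x_t from the forward process x_t = alpha(t) * x_0 + sigma(t) * epsilon.
sched = LinearVPSchedule()
rng = np.random.default_rng(0)
x0 = rng.normal(size=(8, 2))
t = 0.3
x_t = sched.alpha(t) * x0 + sched.sigma(t) * rng.normal(size=x0.shape)
```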

class nami.core.typing.ODESolver(*args, **kwargs)

Bases: Protocol

Interface for ODE integrators.

Methods:

  • integrate: solve dx/dt = f(x, t) from t0 to t1, given the initial state x0

  • integrate_augmented: jointly solve for the state and the change in log-probability

integrate(f, x0, *, t0, t1, atol=1e-06, rtol=1e-05, steps=None)
Return type:

Tensor

Parameters:
integrate_augmented(f_aug, x0, logp0, *, t0, t1, atol=1e-06, rtol=1e-05, steps=None)
Return type:

tuple[Tensor, Tensor]

Parameters:
is_sde: bool
requires_steps: bool
supports_rsample: bool
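A minimal sketch of a fixed-step solver with the documented integrate and integrate_augmented signatures follows, using classical fourth-order Runge-Kutta (RK4) and NumPy arrays in place of Tensor. Several conventions are assumptions, since this page does not state them: f is called as f(x, t); f_aug(x, t) returns the pair (dx/dt, d log p/dt); x0 has shape (N, D) and logp0 has shape (N,). The class name RK4Solver is hypothetical, atol/rtol are accepted but ignored because there is no adaptive step control, and supports_rsample is left out because its meaning is not documented here, so this class would not fully satisfy the protocol as written.

```python
import numpy as np
from typing import Callable, Optional, Tuple

# NumPy arrays stand in for Tensor in this sketch.
Field = Callable[[np.ndarray, float], np.ndarray]
# Hypothetical convention: f_aug(x, t) -> (dx/dt, d log p/dt), shapes (N, D) and (N,).
AugField = Callable[[np.ndarray, float], Tuple[np.ndarray, np.ndarray]]


class RK4Solver:
    """Fixed-step classical Runge-Kutta sketch of the ODESolver interface (hypothetical)."""

    is_sde: bool = False
    requires_steps: bool = True  # fixed-step: the caller must pass `steps`

    def integrate(self, f: Field, x0: np.ndarray, *, t0: float, t1: float,
                  atol: float = 1e-6, rtol: float = 1e-5,
                  steps: Optional[int] = None) -> np.ndarray:
        # atol/rtol are accepted for signature compatibility but unused.
        if steps is None:
            raise ValueError("RK4Solver is fixed-step; pass steps=...")
        h = (t1 - t0) / steps
        x, t = x0, t0
        for _ in range(steps):
            k1 = f(x, t)
            k2 = f(x + 0.5 * h * k1, t + 0.5 * h)
            k3 = f(x + 0.5 * h * k2, t + 0.5 * h)
            k4 = f(x + h * k3, t + h)
            x = x + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
            t = t + h
        return x

    def integrate_augmented(self, f_aug: AugField, x0: np.ndarray, logp0: np.ndarray, *,
                            t0: float, t1: float, atol: float = 1e-6, rtol: float = 1e-5,
                            steps: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
        d = x0.shape[1]

        # Append log p as an extra column so each RK4 step advances state and density together.
        def f(z: np.ndarray, t: float) -> np.ndarray:
            dx, dlogp = f_aug(z[:, :d], t)
            return np.concatenate([dx, dlogp[:, None]], axis=1)

        z0 = np.concatenate([x0, logp0[:, None]], axis=1)
        z1 = self.integrate(f, z0, t0=t0, t1=t1, atol=atol, rtol=rtol, steps=steps)
        return z1[:, :d], z1[:, d]
```

For a continuous normalizing flow with velocity v, f_aug would return (v(x, t), -div v(x, t)), since d log p(x_t)/dt = -div v(x_t, t); the divergence would come from a DivergenceEstimator.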
class nami.core.typing.ProbabilityPath(*args, **kwargs)

Bases: Protocol

Interface for interpolation (probability) paths in flow-matching (FM) models.

Methods:

  • sample_xt: given data (x_target), noise (x_source) and time t, return the interpolated point x_t along the path

  • target_ut: the ground-truth velocity field u_t used as the regression target in the loss

sample_xt(x_target, x_source, t)
Return type:

Tensor

Parameters:
target_ut(x_target, x_source, t)
Return type:

Tensor

Parameters:
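One simple path with the sample_xt/target_ut signatures above is the straight line between source and target, x_t = (1 - t) * x_source + t * x_target, whose time derivative is the constant velocity u_t = x_target - x_source. The sketch below assumes t = 0 at the source (noise) and t = 1 at the target (data), a convention some codebases reverse; the class name LinearPath is hypothetical, and NumPy arrays stand in for Tensor.

```python
import numpy as np


def _as_column(t, x):
    """Reshape a per-sample time vector (N,) so it broadcasts against x of shape (N, ...)."""
    t = np.asarray(t, dtype=float)
    return t.reshape(-1, *([1] * (x.ndim - 1))) if t.ndim == 1 else t


class LinearPath:
    """Hypothetical straight-line path from x_source (t = 0) to x_target (t = 1)."""

    def sample_xt(self, x_target, x_source, t):
        t = _as_column(t, x_target)
        return (1.0 - t) * x_source + t * x_target

    def target_ut(self, x_target, x_source, t):
        # d/dt of sample_xt; constant in t for a straight line, so t is unused
        return x_target - x_source


# Usage: build one training pair (x_t, u_t) for a batch.
rng = np.random.default_rng(0)
data, noise = rng.normal(size=(4, 2)), rng.normal(size=(4, 2))
t = rng.uniform(size=4)
path = LinearPath()
x_t = path.sample_xt(data, noise, t)
u_t = path.target_ut(data, noise, t)
```

In training, a model v(x_t, t) would then be regressed onto u_t, for example with a mean-squared error.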
class nami.core.typing.SDESolver(*args, **kwargs)

Bases: Protocol

Interface for SDE integrators.

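Methods:

  • integrate: simulate the SDE defined by drift and diffusion from t0 to t1, starting from x0, using the given number of steps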

integrate(drift, diffusion, x0, *, t0, t1, steps)
Return type:

Tensor

Parameters:
is_sde: bool
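The simplest scheme matching the integrate(drift, diffusion, x0, *, t0, t1, steps) signature is Euler-Maruyama: x <- x + drift(x, t) * h + diffusion(t) * dW, where dW is Gaussian with variance |h|. The sketch assumes drift(x, t) and diffusion(t) are called the same way as the NoiseSchedule methods, uses NumPy arrays in place of Tensor, and stores a random generator on the instance because the documented signature has no RNG argument; the class name EulerMaruyamaSolver is hypothetical.

```python
import numpy as np
from typing import Callable, Optional


class EulerMaruyamaSolver:
    """Fixed-step Euler-Maruyama sketch of the SDESolver interface (hypothetical class)."""

    is_sde: bool = True

    def __init__(self, rng: Optional[np.random.Generator] = None):
        # Assumption: randomness comes from a stored generator, not an integrate() argument.
        self.rng = np.random.default_rng() if rng is None else rng

    def integrate(self, drift: Callable, diffusion: Callable, x0: np.ndarray,
                  *, t0: float, t1: float, steps: int) -> np.ndarray:
        # Assumed conventions: drift(x, t) -> array like x; diffusion(t) -> broadcastable scale.
        h = (t1 - t0) / steps  # negative when integrating backward in time
        x, t = np.asarray(x0, dtype=float), t0
        for _ in range(steps):
            dw = self.rng.normal(scale=np.sqrt(abs(h)), size=x.shape)  # Brownian increment
            x = x + drift(x, t) * h + diffusion(t) * dw
            t = t + h
        return x


# Usage: pure Brownian motion from 0 over [0, 1] has variance close to 1 at t = 1.
solver = EulerMaruyamaSolver(rng=np.random.default_rng(0))
samples = solver.integrate(lambda x, t: np.zeros_like(x), lambda t: 1.0,
                           np.zeros((10000, 1)), t0=0.0, t1=1.0, steps=50)
print(samples.var())
```

Euler-Maruyama has strong order 1/2, so higher-order schemes may need far fewer steps for the same accuracy; the fixed step count mirrors the required steps argument in the protocol.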