# Source code for nami.components.time
from __future__ import annotations
import math
import torch
from torch import nn
def _broadcast_time(
t: torch.Tensor,
leading_shape: tuple[int, ...],
*,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
"""Broadcast time values to a target leading shape."""
t = torch.as_tensor(t, device=device, dtype=dtype)
return torch.broadcast_to(t, leading_shape)
class ScalarTimeEmbedding(nn.Module):
    """Return scalar time as a one-dimensional feature.

    No learned parameters and no transformation: the (broadcast) time value
    itself is the single output feature.
    """

    # Width of the trailing feature axis produced by ``forward``.
    out_dim = 1

    def forward(
        self,
        t: torch.Tensor,
        *,
        leading_shape: tuple[int, ...],
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Broadcast ``t`` to ``leading_shape`` and add a size-1 feature axis.

        Returns:
            Tensor of shape ``(*leading_shape, 1)`` on ``device`` with ``dtype``.
        """
        broadcast = _broadcast_time(t, leading_shape, device=device, dtype=dtype)
        return broadcast[..., None]
class SinusoidalTimeEmbedding(nn.Module):
    """Map scalar time to sinusoidal features.

    With ``half = dim // 2`` and ``half >= 1``, the output concatenates
    ``sin(t * f_k)`` for ``k = 0 .. half - 1``, then ``cos(t * f_k)`` for the
    same ``k``. The frequencies are ``f_k = max_period ** (-k / half)``.
    For ``max_period > 1`` they decrease geometrically from 1 towards
    ``1 / max_period``. When ``dim`` is odd, the raw time value is appended
    as the final feature.

    When ``dim=1`` no sinusoids are produced; the output is simply the
    raw scalar time (equivalent to :class:`ScalarTimeEmbedding`).

    Args:
        dim: Number of output features; must be positive.
        max_period: Sets the frequency range (see above); must be positive
            and finite.

    Raises:
        ValueError: If ``dim`` is not positive, or ``max_period`` is not a
            positive finite number.
    """

    def __init__(self, dim: int, *, max_period: float = 10000.0):
        super().__init__()
        if dim <= 0:
            msg = f"dim must be positive, got {dim}"
            raise ValueError(msg)
        # Written as ``not (finite and > 0)`` so NaN (which compares False to
        # everything) is rejected. inf is rejected because log(inf) * 0 in
        # ``forward`` would yield a NaN frequency.
        if not (math.isfinite(max_period) and max_period > 0):
            msg = f"max_period must be positive and finite, got {max_period}"
            raise ValueError(msg)
        self.dim = int(dim)
        self.max_period = float(max_period)

    @property
    def out_dim(self) -> int:
        """Width of the trailing feature axis produced by ``forward`` (``dim``)."""
        return self.dim

    def extra_repr(self) -> str:
        """Show the configuration in ``repr(module)``."""
        return f"dim={self.dim}, max_period={self.max_period}"

    def forward(
        self,
        t: torch.Tensor,
        *,
        leading_shape: tuple[int, ...],
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Embed ``t`` after broadcasting it to ``leading_shape``.

        Returns:
            Tensor of shape ``(*leading_shape, dim)`` on ``device``. For a
            floating-point ``dtype`` the result has exactly that dtype.
        """
        # float16/bfloat16 can lose substantial precision in t * freqs and the
        # following sin/cos. Do the arithmetic in float32 and cast the finished
        # embedding back to the requested dtype.
        if dtype in (torch.float16, torch.bfloat16):
            work_dtype = torch.float32
        else:
            work_dtype = dtype
        t = _broadcast_time(t, leading_shape, device=device, dtype=work_dtype)
        t_col = t.unsqueeze(-1)
        half = self.dim // 2
        if half == 0:
            # dim == 1: the raw time value is the only feature.
            emb = t_col
        else:
            scale = math.log(self.max_period) / half
            freqs = torch.exp(
                -scale * torch.arange(half, device=device, dtype=work_dtype)
            )
            args = t_col * freqs
            parts = [torch.sin(args), torch.cos(args)]
            if self.dim % 2 == 1:
                # Odd dim: the leftover slot carries the raw time value.
                parts.append(t_col)
            emb = torch.cat(parts, dim=-1)
        if work_dtype != dtype:
            emb = emb.to(dtype)
        return emb