Source code for nami.components.activation
from __future__ import annotations
from torch import nn
_ACTIVATIONS: dict[str, type[nn.Module]] = {
"relu": nn.ReLU,
"silu": nn.SiLU,
"gelu": nn.GELU,
"tanh": nn.Tanh,
"elu": nn.ELU,
"leaky_relu": nn.LeakyReLU,
"selu": nn.SELU,
"swish": nn.SiLU,
"mish": nn.Mish,
"hard_swish": nn.Hardswish,
"hard_sigmoid": nn.Hardsigmoid,
}
[docs]
def get_activation(name: str) -> nn.Module:
activation = _ACTIVATIONS.get(name)
if activation is None:
msg = f"Unknown activation: {name!r}. Available: {sorted(_ACTIVATIONS)}"
raise ValueError(msg)
return activation()
__all__ = ["get_activation"]