latent_dim = 40
latent = jax.random.normal(jax.random.PRNGKey(42), (latent_dim,)) * 0.01
grid = jnp.repeat(
jnp.linspace(-1, 1, 40)[None],
96000,
axis=0,
)
modulated_siren = ModulatedSiren(hidden_channels=[64, 1], synthesis_act=nn.selu)
print(latent.shape)
print(grid.shape)
out, variables = modulated_siren.init_with_output(
jax.random.PRNGKey(42),
grid,
latent,
)
print(out.shape)
print(
modulated_siren.tabulate(
jax.random.PRNGKey(42),
jnp.empty_like(grid),
jnp.empty_like(latent),
column_kwargs={"no_wrap": True},
table_kwargs={"expand": True},
console_kwargs={"width": 120},
depth=2,
)
)
MLP models
::: {#cell-3 .cell}
from typing import Any, Callable, Optional, Sequence, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import lax
from jax.random import uniform
from jax.typing import ArrayLike
:::
::: {#cell-4 .cell}
class MLP(nn.Module):
"""
MLP with SELU activation and LeCun normal initialization.
"""
hidden_channels: Sequence[int] # number of hidden channels
    activation: Callable = nn.selu
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
use_bias: bool = True
layer_norm: bool = False
@nn.compact
def __call__(
self,
x: jnp.ndarray,
) -> jnp.ndarray:
        for i, channels in enumerate(self.hidden_channels):
x = nn.Dense(
features=channels,
kernel_init=self.kernel_init,
use_bias=self.use_bias,
)(x)
            if i != len(self.hidden_channels) - 1:
if self.layer_norm:
x = nn.LayerNorm()(x)
x = self.activation(x)
        return x
:::
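As a quick sanity check, a minimal `MLP` usage sketch (the widths and input shape below are arbitrary choices, not from the original notebook):
mlp = MLP(hidden_channels=[32, 32, 1], layer_norm=True)
mlp_vars = mlp.init(jax.random.PRNGKey(0), jnp.ones((8, 2)))  # batch of 8 two-dimensional inputs
print(mlp.apply(mlp_vars, jnp.ones((8, 2))).shape)  # (8, 1)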
Siren implementation adapted from an external source.
::: {#cell-6 .cell}
def siren_init(weight_std, dtype):
    """Uniform init in [-weight_std, weight_std]; complex dtypes get independent real and imaginary parts."""
    def init_fun(key, shape, dtype=dtype):
        if jnp.issubdtype(dtype, jnp.complexfloating):
            key1, key2 = jax.random.split(key)
            real_dtype = jnp.zeros((), dtype).real.dtype
            a = uniform(key1, shape, real_dtype) * 2 * weight_std - weight_std
            b = uniform(key2, shape, real_dtype) * 2 * weight_std - weight_std
            return a + 1j * b
        else:
            return uniform(key, shape, dtype) * 2 * weight_std - weight_std
    return init_fun
def grid_init(grid_dimension, dtype):
def init_fun(dtype=dtype):
coord_axis = [jnp.linspace(-3, 3, d) for d in grid_dimension]
grid = jnp.stack(jnp.meshgrid(*coord_axis), -1)
return jnp.asarray(grid, dtype)
return init_fun
class Sine(nn.Module):
w0: float = 1.0
dtype: Any = jnp.float32
@nn.compact
def __call__(self, inputs: ArrayLike) -> ArrayLike:
inputs = jnp.asarray(inputs, self.dtype)
return jnp.sin(self.w0 * inputs)
class SirenLayer(nn.Module):
features: int = 32
w0: float = 1.0
c: float = 6.0
is_first: bool = False
use_bias: bool = True
act: Callable = jnp.sin
precision: Any = None
dtype: Any = jnp.float32
@nn.compact
def __call__(self, inputs: ArrayLike) -> ArrayLike:
inputs = jnp.asarray(inputs, self.dtype)
input_dim = inputs.shape[-1]
# Linear projection with init proposed in SIREN paper
weight_std = (
(1 / input_dim) if self.is_first else jnp.sqrt(self.c / input_dim) / self.w0
)
kernel = self.param(
"kernel", siren_init(weight_std, self.dtype), (input_dim, self.features)
)
kernel = jnp.asarray(kernel, self.dtype)
y = lax.dot_general(
inputs,
kernel,
(((inputs.ndim - 1,), (0,)), ((), ())),
precision=self.precision,
)
if self.use_bias:
bias = self.param("bias", uniform, (self.features,))
bias = jnp.asarray(bias, self.dtype)
y = y + bias
        # Scale by w0, then apply the layer activation (sine by default)
        return self.act(self.w0 * y)
class ModulatedLayer(nn.Module):
features: int = 32
is_first: bool = False
synthesis_act: Callable = jnp.sin
modulator_act: Callable = nn.relu
precision: Any = None
dtype: Any = jnp.float32
w0_first_layer: float = 30.0
w0: float = 1.0
@nn.compact
def __call__(
self,
input: ArrayLike,
latent: ArrayLike,
        hidden: Optional[ArrayLike],
) -> Tuple[ArrayLike, ArrayLike]:
# Get new modulation amplitude
        if self.synthesis_act is jnp.sin:
synth_dense = SirenLayer(
features=self.features,
w0=self.w0_first_layer if self.is_first else self.w0,
is_first=self.is_first,
act=self.synthesis_act,
dtype=self.dtype,
)
else:
synth_dense = nn.Dense(
features=self.features,
precision=self.precision,
dtype=self.dtype,
)
modulator_dense = nn.Dense(
features=self.features,
precision=self.precision,
dtype=self.dtype,
name="mod_dense",
)
if self.is_first:
# Prepare hidden state
hidden_state_init = nn.Dense(
self.features,
precision=self.precision,
dtype=self.dtype,
)
hidden = hidden_state_init(latent)
# Build modulation signal and generate
mod_input = jnp.concatenate([hidden, latent])
alpha = self.modulator_act(modulator_dense(mod_input))
        synth_out = synth_dense(input)  # SirenLayer applies its activation internally
        if not isinstance(synth_dense, SirenLayer):
            synth_out = self.synthesis_act(synth_out)
        output = alpha * synth_out
return output, alpha
class Siren(nn.Module):
hidden_channels: Sequence[int] # number of hidden channels including output
w0: float = 1.0 # Frequency of the sine activations
w0_first_layer: float = 1.0 # Frequency of the sine activations in the first layer
use_bias: bool = True # Whether to use bias in the layers
final_activation: Callable = lambda x: x # Identity
dtype: Any = jnp.float32
@nn.compact
def __call__(
self,
inputs: ArrayLike,
) -> ArrayLike:
x = jnp.asarray(inputs, self.dtype)
for layer_idx, channels in enumerate(self.hidden_channels[:-1]):
is_first = layer_idx == 0
x = SirenLayer(
features=channels,
w0=self.w0_first_layer if is_first else self.w0,
is_first=is_first,
use_bias=self.use_bias,
)(x)
# Last layer, with different activation function
x = SirenLayer(
features=self.hidden_channels[-1],
w0=self.w0,
is_first=False,
use_bias=self.use_bias,
act=self.final_activation,
)(x)
return x
class ModulatedSiren(nn.Module):
hidden_channels: Sequence[int] # number of hidden channels including output
synthesis_act: Callable = jnp.sin
modulator_act: Callable = nn.relu
w0_first_layer: float = 30.0
dtype: Any = jnp.float32
@nn.compact
def __call__(
self,
inputs: ArrayLike,
latent: ArrayLike,
) -> ArrayLike:
x = jnp.asarray(inputs, self.dtype)
latent = jnp.asarray(latent, self.dtype)
hidden = None
for layer_idx, channels in enumerate(self.hidden_channels[:-1]):
is_first = layer_idx == 0
x, hidden = ModulatedLayer(
features=channels,
is_first=is_first,
synthesis_act=self.synthesis_act,
modulator_act=self.modulator_act,
dtype=self.dtype,
w0_first_layer=self.w0_first_layer,
)(x, latent, hidden)
# Last layer
x = nn.Dense(
self.hidden_channels[-1],
dtype=self.dtype,
name="output_layer",
)(x)
        return x
:::
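A minimal sketch exercising `grid_init` and the plain `Siren` (the widths, grid size, and w0 values below are arbitrary choices):
siren = Siren(hidden_channels=[64, 64, 1], w0=1.0, w0_first_layer=30.0)
coords = grid_init((1024,), jnp.float32)()  # (1024, 1) coordinates in [-3, 3]
siren_vars = siren.init(jax.random.PRNGKey(0), coords)
print(siren.apply(siren_vars, coords).shape)  # (1024, 1)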
::: {#cell-8 .cell}
class RandomFourierFeatures(nn.Module):
scale: float
n_features: int
@nn.compact
def __call__(self, x):
        # Random projection matrix B ~ N(0, scale^2), stored as a non-trainable buffer.
        # Note: the PRNGKey is hard-coded, so B is identical across instances.
        B = (
            self.variable(
                "buffers",
                "B",
                nn.initializers.normal(stddev=1.0),
                jax.random.PRNGKey(42),
                (
                    x.shape[-1],
                    self.n_features // 2,
                ),
            ).value
            * self.scale
        )
x_proj = (2 * jnp.pi * x) @ B
        return jnp.concatenate([jnp.sin(x_proj), jnp.cos(x_proj)], axis=-1)
:::
rff = RandomFourierFeatures(scale=1.0, n_features=256)
variables = rff.init(jax.random.PRNGKey(42), jnp.ones((1, 2)))
out = rff.apply(variables, jnp.ones((1, 2)))
print(out)
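Fourier features are typically used as the input encoding for a coordinate MLP. A minimal sketch chaining the `rff` features above into the `MLP` module (the grid, widths, and output dimension are arbitrary):
coords2d = jnp.stack(jnp.meshgrid(jnp.linspace(-1, 1, 32), jnp.linspace(-1, 1, 32)), -1).reshape(-1, 2)
feats = rff.apply(variables, coords2d)  # (1024, 256) Fourier features
mlp_head = MLP(hidden_channels=[128, 128, 3])
head_vars = mlp_head.init(jax.random.PRNGKey(1), feats)
print(mlp_head.apply(head_vars, feats).shape)  # (1024, 3)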