```python
latent_dim = 40
# Random latent code and a (96000, 40) grid of input coordinates
latent = jax.random.normal(jax.random.PRNGKey(42), (latent_dim,)) * 0.01
grid = jnp.repeat(
    jnp.linspace(-1, 1, 40)[None],
    96000,
    axis=0,
)

modulated_siren = ModulatedSiren(hidden_channels=[64, 1], synthesis_act=nn.selu)

print(latent.shape)
print(grid.shape)

# Initialise the model and run a forward pass in one call
out, variables = modulated_siren.init_with_output(
    jax.random.PRNGKey(42),
    grid,
    latent,
)
print(out.shape)

# Print a summary table of the model's layers and parameter counts
print(
    modulated_siren.tabulate(
        jax.random.PRNGKey(42),
        jnp.empty_like(grid),
        jnp.empty_like(latent),
        column_kwargs={"no_wrap": True},
        table_kwargs={"expand": True},
        console_kwargs={"width": 120},
        depth=2,
    )
)
```
MLP models
::: {#cell-3 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
from typing import Any, Callable, Sequence, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import lax
from jax.random import uniform
from jax.typing import ArrayLike
```
:::
::: {#cell-4 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
class MLP(nn.Module):
    """
    MLP with SELU activation and LeCun normal initialization.
    """

    hidden_channels: Sequence[int]  # number of hidden channels
    activation: nn.Module = nn.selu
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    use_bias: bool = True
    layer_norm: bool = False

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        for channels in self.hidden_channels:
            x = nn.Dense(
                features=channels,
                kernel_init=self.kernel_init,
                use_bias=self.use_bias,
            )(x)
            if channels != self.hidden_channels[-1]:
                if self.layer_norm:
                    x = nn.LayerNorm()(x)
                x = self.activation(x)
        return x
```
:::
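As a quick sanity check, here is a minimal usage sketch of `MLP`. The layer sizes and input shape are illustrative assumptions, not values from the notebook, and it relies on the imports and class defined above.

```python
# Minimal usage sketch (illustrative shapes): a 3-layer MLP mapping 2-D inputs to a scalar output.
mlp = MLP(hidden_channels=[64, 64, 1])
x = jnp.ones((16, 2))                        # batch of 16 two-dimensional inputs
params = mlp.init(jax.random.PRNGKey(0), x)  # initialise parameters
y = mlp.apply(params, x)                     # forward pass
print(y.shape)  # (16, 1)
```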
Siren implementation taken from here
::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
def siren_init(weight_std, dtype):
    # Uniform init in [-weight_std, weight_std]; complex dtypes draw the real
    # and imaginary parts independently.
    def init_fun(key, shape, dtype=dtype):
        if dtype == jnp.dtype(jnp.array([1j])):
            key1, key2 = jax.random.split(key)
            dtype = jnp.dtype(jnp.array([1j]).real)
            a = uniform(key1, shape, dtype) * 2 * weight_std - weight_std
            b = uniform(key2, shape, dtype) * 2 * weight_std - weight_std
            return a + 1j * b
        else:
            return uniform(key, shape, dtype) * 2 * weight_std - weight_std

    return init_fun


def grid_init(grid_dimension, dtype):
    # Returns an initializer producing a coordinate grid over [-3, 3] in each dimension.
    def init_fun(dtype=dtype):
        coord_axis = [jnp.linspace(-3, 3, d) for d in grid_dimension]
        grid = jnp.stack(jnp.meshgrid(*coord_axis), -1)
        return jnp.asarray(grid, dtype)

    return init_fun


class Sine(nn.Module):
    w0: float = 1.0
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, inputs: ArrayLike) -> ArrayLike:
        inputs = jnp.asarray(inputs, self.dtype)
        return jnp.sin(self.w0 * inputs)


class SirenLayer(nn.Module):
    features: int = 32
    w0: float = 1.0
    c: float = 6.0
    is_first: bool = False
    use_bias: bool = True
    act: Callable = jnp.sin
    precision: Any = None
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, inputs: ArrayLike) -> ArrayLike:
        inputs = jnp.asarray(inputs, self.dtype)
        input_dim = inputs.shape[-1]

        # Linear projection with init proposed in SIREN paper
        weight_std = (
            (1 / input_dim) if self.is_first else jnp.sqrt(self.c / input_dim) / self.w0
        )

        kernel = self.param(
            "kernel", siren_init(weight_std, self.dtype), (input_dim, self.features)
        )
        kernel = jnp.asarray(kernel, self.dtype)

        y = lax.dot_general(
            inputs,
            kernel,
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )

        if self.use_bias:
            bias = self.param("bias", uniform, (self.features,))
            bias = jnp.asarray(bias, self.dtype)
            y = y + bias

        return self.w0 * y


class ModulatedLayer(nn.Module):
    features: int = 32
    is_first: bool = False
    synthesis_act: Callable = jnp.sin
    modulator_act: Callable = nn.relu
    precision: Any = None
    dtype: Any = jnp.float32
    w0_first_layer: float = 30.0
    w0: float = 1.0

    @nn.compact
    def __call__(
        self,
        input: ArrayLike,
        latent: ArrayLike,
        hidden: ArrayLike,
    ) -> Tuple[ArrayLike, ArrayLike]:
        # Get new modulation amplitude
        if self.synthesis_act in [jnp.sin]:
            synth_dense = SirenLayer(
                features=self.features,
                w0=self.w0_first_layer if self.is_first else self.w0,
                is_first=self.is_first,
                act=self.synthesis_act,
                dtype=self.dtype,
            )
        else:
            synth_dense = nn.Dense(
                features=self.features,
                precision=self.precision,
                dtype=self.dtype,
            )
        modulator_dense = nn.Dense(
            features=self.features,
            precision=self.precision,
            dtype=self.dtype,
            name="mod_dense",
        )

        if self.is_first:
            # Prepare hidden state
            hidden_state_init = nn.Dense(
                self.features,
                precision=self.precision,
                dtype=self.dtype,
            )
            hidden = hidden_state_init(latent)

        # Build modulation signal and generate
        mod_input = jnp.concatenate([hidden, latent])
        alpha = self.modulator_act(modulator_dense(mod_input))
        synth_dense_output = self.synthesis_act(synth_dense(input))
        output = alpha * synth_dense_output
        return output, alpha


class Siren(nn.Module):
    hidden_channels: Sequence[int]  # number of hidden channels including output
    w0: float = 1.0  # Frequency of the sine activations
    w0_first_layer: float = 1.0  # Frequency of the sine activations in the first layer
    use_bias: bool = True  # Whether to use bias in the layers
    final_activation: Callable = lambda x: x  # Identity
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(
        self,
        inputs: ArrayLike,
    ) -> ArrayLike:
        x = jnp.asarray(inputs, self.dtype)

        for layer_idx, channels in enumerate(self.hidden_channels[:-1]):
            is_first = layer_idx == 0

            x = SirenLayer(
                features=channels,
                w0=self.w0_first_layer if is_first else self.w0,
                is_first=is_first,
                use_bias=self.use_bias,
            )(x)

        # Last layer, with different activation function
        x = SirenLayer(
            features=self.hidden_channels[-1],
            w0=self.w0,
            is_first=False,
            use_bias=self.use_bias,
            act=self.final_activation,
        )(x)

        return x


class ModulatedSiren(nn.Module):
    hidden_channels: Sequence[int]  # number of hidden channels including output
    synthesis_act: Callable = jnp.sin
    modulator_act: Callable = nn.relu
    w0_first_layer: float = 30.0
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(
        self,
        inputs: ArrayLike,
        latent: ArrayLike,
    ) -> ArrayLike:
        x = jnp.asarray(inputs, self.dtype)
        latent = jnp.asarray(latent, self.dtype)
        hidden = None

        for layer_idx, channels in enumerate(self.hidden_channels[:-1]):
            is_first = layer_idx == 0

            x, hidden = ModulatedLayer(
                features=channels,
                is_first=is_first,
                synthesis_act=self.synthesis_act,
                modulator_act=self.modulator_act,
                dtype=self.dtype,
                w0_first_layer=self.w0_first_layer,
            )(x, latent, hidden)

        # Last layer
        x = nn.Dense(
            self.hidden_channels[-1],
            dtype=self.dtype,
            name="output_layer",
        )(x)
        return x
```
:::
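For completeness, a minimal usage sketch of the plain `Siren` module; the layer sizes, `w0` values, and coordinate grid below are illustrative assumptions rather than values from the notebook.

```python
# Minimal usage sketch (illustrative values): a SIREN over 1-D coordinates.
siren = Siren(hidden_channels=[64, 64, 1], w0=1.0, w0_first_layer=30.0)
coords = jnp.linspace(-1, 1, 256)[:, None]          # 256 one-dimensional coordinates
params = siren.init(jax.random.PRNGKey(0), coords)  # initialise parameters
signal = siren.apply(params, coords)                # forward pass
print(signal.shape)  # (256, 1)
```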
::: {#cell-8 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
class RandomFourierFeatures(nn.Module):
    scale: float
    n_features: int

    @nn.compact
    def __call__(self, x):
        # Fixed random projection matrix, stored as a non-trainable "buffers" variable
        B = (
            self.variable(
                "buffers",
                "B",
                nn.initializers.normal(),
                jax.random.PRNGKey(42),
                (
                    x.shape[-1],
                    self.n_features // 2,
                ),
            ).value
            * self.scale
        )
        # Fourier feature mapping: concatenate sin and cos of the projected inputs
        x_proj = (2 * jnp.pi * x) @ B
        return jnp.concatenate([jnp.sin(x_proj), jnp.cos(x_proj)], axis=-1)
```
:::
```python
rff = RandomFourierFeatures(scale=1.0, n_features=256)
variables = rff.init(jax.random.PRNGKey(42), jnp.ones((1, 2)))
out = rff.apply(variables, jnp.ones((1, 2)))
print(out)
```