# MLP models

::: {#cell-3 .cell}

from typing import Any, Callable, Optional, Sequence, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import lax
from jax.random import uniform
from jax.typing import ArrayLike

:::

::: {#cell-4 .cell}

class MLP(nn.Module):
    """
    MLP with SELU activation and LeCun normal initialization.
    """

    hidden_channels: Sequence[int]  # sizes of all layers, including the output
    activation: Callable = nn.selu
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    use_bias: bool = True
    layer_norm: bool = False

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        for idx, channels in enumerate(self.hidden_channels):
            x = nn.Dense(
                features=channels,
                kernel_init=self.kernel_init,
                use_bias=self.use_bias,
            )(x)
            # Normalization and activation on all but the final layer
            if idx < len(self.hidden_channels) - 1:
                if self.layer_norm:
                    x = nn.LayerNorm()(x)
                x = self.activation(x)
        return x

:::
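
A quick shape check for `MLP` (illustrative values, not part of the exported module):

mlp = MLP(hidden_channels=[64, 64, 1], layer_norm=True)
variables = mlp.init(jax.random.PRNGKey(0), jnp.ones((8, 2)))
out = mlp.apply(variables, jnp.ones((8, 2)))
print(out.shape)  # (8, 1)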

The SIREN implementation below (sinusoidal representation networks; Sitzmann et al., 2020) is taken from here.
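
Its weight initialization, which `siren_init` below implements through `weight_std`, draws weights uniformly with a fan-in- and frequency-dependent bound: for fan-in $n$,

$$
W \sim \mathcal{U}\!\left(-\frac{1}{n},\, \frac{1}{n}\right) \ \text{(first layer)},
\qquad
W \sim \mathcal{U}\!\left(-\frac{\sqrt{c/n}}{\omega_0},\, \frac{\sqrt{c/n}}{\omega_0}\right) \ \text{(later layers)},
$$

with $c = 6$ by default.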

::: {#cell-6 .cell}

def siren_init(weight_std, dtype):
    """Initializer drawing weights uniformly from [-weight_std, weight_std].

    Complex dtypes get independently drawn real and imaginary parts.
    """

    def init_fun(key, shape, dtype=dtype):
        if jnp.issubdtype(dtype, jnp.complexfloating):
            key1, key2 = jax.random.split(key)
            real_dtype = jnp.zeros((), dtype).real.dtype
            a = uniform(key1, shape, real_dtype, -weight_std, weight_std)
            b = uniform(key2, shape, real_dtype, -weight_std, weight_std)
            return a + 1j * b
        return uniform(key, shape, dtype, -weight_std, weight_std)

    return init_fun


def grid_init(grid_dimension, dtype):
    """Initializer returning a fixed coordinate grid spanning [-3, 3] along each dimension."""

    def init_fun(dtype=dtype):
        coord_axes = [jnp.linspace(-3, 3, d) for d in grid_dimension]
        grid = jnp.stack(jnp.meshgrid(*coord_axes), -1)
        return jnp.asarray(grid, dtype)

    return init_fun


class Sine(nn.Module):
    w0: float = 1.0
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, inputs: ArrayLike) -> ArrayLike:
        inputs = jnp.asarray(inputs, self.dtype)
        return jnp.sin(self.w0 * inputs)


class SirenLayer(nn.Module):
    features: int = 32
    w0: float = 1.0
    c: float = 6.0
    is_first: bool = False
    use_bias: bool = True
    act: Callable = jnp.sin
    precision: Any = None
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, inputs: ArrayLike) -> ArrayLike:
        inputs = jnp.asarray(inputs, self.dtype)
        input_dim = inputs.shape[-1]

        # Linear projection with init proposed in SIREN paper
        weight_std = (
            (1 / input_dim) if self.is_first else jnp.sqrt(self.c / input_dim) / self.w0
        )

        kernel = self.param(
            "kernel", siren_init(weight_std, self.dtype), (input_dim, self.features)
        )
        kernel = jnp.asarray(kernel, self.dtype)

        y = lax.dot_general(
            inputs,
            kernel,
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )

        if self.use_bias:
            # Bias drawn uniformly from [0, 1)
            bias = self.param("bias", uniform, (self.features,))
            bias = jnp.asarray(bias, self.dtype)
            y = y + bias

        # Scale by w0 and apply the activation (sine by default); without
        # this, stacked SirenLayers would collapse into a linear map
        return self.act(self.w0 * y)


class ModulatedLayer(nn.Module):
    features: int = 32
    is_first: bool = False
    synthesis_act: Callable = jnp.sin
    modulator_act: Callable = nn.relu
    precision: Any = None
    dtype: Any = jnp.float32
    w0_first_layer: float = 30.0
    w0: float = 1.0

    @nn.compact
    def __call__(
        self,
        inputs: ArrayLike,
        latent: ArrayLike,
        hidden: Optional[ArrayLike] = None,
    ) -> Tuple[ArrayLike, ArrayLike]:
        # Get new modulation amplitude
        # Synthesis branch: a SIREN layer for sine activations, a plain Dense otherwise
        if self.synthesis_act is jnp.sin:
            synth_dense = SirenLayer(
                features=self.features,
                w0=self.w0_first_layer if self.is_first else self.w0,
                is_first=self.is_first,
                act=self.synthesis_act,
                precision=self.precision,
                dtype=self.dtype,
            )
        else:
            synth_dense = nn.Dense(
                features=self.features,
                precision=self.precision,
                dtype=self.dtype,
            )

        modulator_dense = nn.Dense(
            features=self.features,
            precision=self.precision,
            dtype=self.dtype,
            name="mod_dense",
        )

        if self.is_first:
            # Prepare hidden state
            hidden_state_init = nn.Dense(
                self.features,
                precision=self.precision,
                dtype=self.dtype,
            )
            hidden = hidden_state_init(latent)

        # Build the modulation signal and generate the layer output
        mod_input = jnp.concatenate([hidden, latent])
        alpha = self.modulator_act(modulator_dense(mod_input))
        if self.synthesis_act is jnp.sin:
            # SirenLayer already applies the sine activation internally
            synth_output = synth_dense(inputs)
        else:
            synth_output = self.synthesis_act(synth_dense(inputs))
        output = alpha * synth_output
        return output, alpha


class Siren(nn.Module):
    hidden_channels: Sequence[int]  # number of hidden channels including output
    w0: float = 1.0  # Frequency of the sine activations
    w0_first_layer: float = 1.0  # Frequency of the sine activations in the first layer
    use_bias: bool = True  # Whether to use bias in the layers
    final_activation: Callable = lambda x: x  # Identity
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(
        self,
        inputs: ArrayLike,
    ) -> ArrayLike:
        x = jnp.asarray(inputs, self.dtype)

        for layer_idx, channels in enumerate(self.hidden_channels[:-1]):
            is_first = layer_idx == 0

            x = SirenLayer(
                features=channels,
                w0=self.w0_first_layer if is_first else self.w0,
                is_first=is_first,
                use_bias=self.use_bias,
                dtype=self.dtype,
            )(x)

        # Last layer, with different activation function
        x = SirenLayer(
            features=self.hidden_channels[-1],
            w0=self.w0,
            is_first=False,
            use_bias=self.use_bias,
            act=self.final_activation,
            dtype=self.dtype,
        )(x)

        return x


class ModulatedSiren(nn.Module):
    hidden_channels: Sequence[int]  # number of hidden channels including output
    synthesis_act: Callable = jnp.sin
    modulator_act: Callable = nn.relu
    w0_first_layer: float = 30.0
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(
        self,
        inputs: ArrayLike,
        latent: ArrayLike,
    ) -> ArrayLike:
        x = jnp.asarray(inputs, self.dtype)
        latent = jnp.asarray(latent, self.dtype)
        hidden = None  # initialized from the latent inside the first ModulatedLayer

        for layer_idx, channels in enumerate(self.hidden_channels[:-1]):
            is_first = layer_idx == 0

            x, hidden = ModulatedLayer(
                features=channels,
                is_first=is_first,
                synthesis_act=self.synthesis_act,
                modulator_act=self.modulator_act,
                dtype=self.dtype,
                w0_first_layer=self.w0_first_layer,
            )(x, latent, hidden)

        # Last layer
        x = nn.Dense(
            self.hidden_channels[-1],
            dtype=self.dtype,
            name="output_layer",
        )(x)
        return x

:::
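
The plain `Siren` can be exercised in the same way as the modulated variant below; a minimal shape check with illustrative values:

siren = Siren(hidden_channels=[64, 64, 1], w0_first_layer=30.0)
coords = jnp.linspace(-1, 1, 128)[:, None]  # 128 one-dimensional coordinates
siren_out, siren_vars = siren.init_with_output(jax.random.PRNGKey(0), coords)
print(siren_out.shape)  # (128, 1)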

latent_dim = 40
latent = jax.random.normal(jax.random.PRNGKey(42), (latent_dim,)) * 0.01

grid = jnp.repeat(
    jnp.linspace(-1, 1, 40)[None],
    96000,
    axis=0,
)

modulated_siren = ModulatedSiren(hidden_channels=[64, 1], synthesis_act=nn.selu)

print(latent.shape)
print(grid.shape)
out, variables = modulated_siren.init_with_output(
    jax.random.PRNGKey(42),
    grid,
    latent,
)

print(out.shape)
print(
    modulated_siren.tabulate(
        jax.random.PRNGKey(42),
        jnp.empty_like(grid),
        jnp.empty_like(latent),
        column_kwargs={"no_wrap": True},
        table_kwargs={"expand": True},
        console_kwargs={"width": 120},
        depth=2,
    )
)
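
Random Fourier features embed low-dimensional coordinates $x$ as $\gamma(x) = [\sin(2\pi x B), \cos(2\pi x B)]$, where $B$ is a fixed Gaussian matrix and `scale` sets its standard deviation, and hence the bandwidth of the embedding (Tancik et al., 2020).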

::: {#cell-8 .cell}

class RandomFourierFeatures(nn.Module):
    scale: float  # bandwidth of the frequency matrix
    n_features: int  # output size (half sine, half cosine features)

    @nn.compact
    def __call__(self, x):
        # Fixed (non-trainable) Gaussian frequency matrix, kept in the
        # "buffers" collection so it is excluded from gradient updates;
        # the hardcoded seed makes the embedding deterministic. Unit
        # standard deviation, so that `scale` alone sets the bandwidth.
        B = (
            self.variable(
                "buffers",
                "B",
                nn.initializers.normal(stddev=1.0),
                jax.random.PRNGKey(42),
                (
                    x.shape[-1],
                    self.n_features // 2,
                ),
            ).value
            * self.scale
        )
        x_proj = (2 * jnp.pi * x) @ B
        return jnp.concatenate([jnp.sin(x_proj), jnp.cos(x_proj)], axis=-1)

:::

rff = RandomFourierFeatures(scale=1.0, n_features=256)
variables = rff.init(jax.random.PRNGKey(42), jnp.ones((1, 2)))
out = rff.apply(variables, jnp.ones((1, 2)))
print(out)
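
The Fourier embedding composes naturally with the `MLP` defined above. A minimal sketch (the `FourierMLP` wrapper is hypothetical, for illustration only):

class FourierMLP(nn.Module):
    # Hypothetical wrapper: Fourier-embed the coordinates, then apply an MLP
    @nn.compact
    def __call__(self, x):
        x = RandomFourierFeatures(scale=10.0, n_features=256)(x)
        return MLP(hidden_channels=[64, 64, 1])(x)

fourier_mlp = FourierMLP()
fm_vars = fourier_mlp.init(jax.random.PRNGKey(0), jnp.ones((8, 2)))
print(fourier_mlp.apply(fm_vars, jnp.ones((8, 2))).shape)  # (8, 1)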