# Convolutional models

::: {#cell-3 .cell .export}

import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Tuple
from functools import partial

:::

::: {#cell-4 .cell .export}

Conv3x3 = partial(nn.Conv, kernel_size=(3, 3))


class ConvRelu2(nn.Module):
    """Two unpadded convolutions & relus.

    Attributes:
      features: Num convolutional features.
      padding: Type of padding: 'SAME' or 'VALID'.
      norm: Whether to use batchnorm at the end or not.
    """

    features: int
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        x = Conv3x3(features=self.features, name="conv1", padding=self.padding)(x)
        if self.norm in ["batch"]:
            x = nn.BatchNorm(use_running_average=not self.training, axis_name="batch")(
                x
            )
        x = nn.relu(x)
        x = Conv3x3(features=self.features, name="conv2", padding=self.padding)(x)
        if self.norm in ["batch"]:
            x = nn.BatchNorm(use_running_average=not self.training, axis_name="batch")(
                x
            )
        x = nn.relu(x)
        return x


class DeConv3x3(nn.Module):
    """Deconvolution layer for upscaling.

    Attributes:
      features: Number of convolutional features.
      padding: Type of padding: 'SAME' or 'VALID'.
      norm: Normalization type; 'batch' applies BatchNorm after the
        convolution, any other value applies no normalization.
      training: Whether BatchNorm uses batch statistics (True) or running
        averages (False).
    """

    features: int
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        """Applies transposed convolution with 3x3 kernel."""
        # NOTE: In the scenic code this is a deconvolution.
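        # With padding == "SAME": input_dilation=(2, 2) expands a spatial size N
        # to 2N - 1 by inserting zeros between pixels; a 3x3 kernel with the
        # asymmetric padding (1, 2) then yields an output of exactly 2N, i.e. a
        # clean 2x upsampling.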

        if self.padding == "SAME":
            padding = ((1, 2), (1, 2))
        elif self.padding == "VALID":
            padding = ((0, 0), (0, 0))
        else:
            raise ValueError(f"Unkonwn padding: {self.padding}")
        x = nn.Conv(
            features=self.features,
            kernel_size=(3, 3),
            input_dilation=(2, 2),
            padding=padding,
        )(x)
        if self.norm in ["batch"]:
            x = nn.BatchNorm(use_running_average=not self.training, axis_name="batch")(
                x
            )
        return x


class DownsampleBlock(nn.Module):
    """Two unpadded convolutions & downsample 2x.

    Attributes:
      features: Num convolutional features.
      padding: Type of padding: 'SAME' or 'VALID'.
      norm: Whether to use batchnorm at the end or not.
    """

    features: int
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        x = ConvRelu2(
            features=self.features,
            padding=self.padding,
            norm=self.norm,
            training=self.training,
        )(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        return x


class UpsampleBlock(nn.Module):
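    """Two 3x3 conv+ReLU layers followed by a 2x transposed-convolution upsample.

    Attributes:
      features: Number of convolutional features; the upsampling step emits
        features // 2 channels.
      kernel_size: Convolution kernel size (currently unused; the sub-blocks
        use fixed 3x3 kernels).
      padding: Type of padding: 'SAME' or 'VALID'.
      norm: Normalization type; 'batch' applies BatchNorm, any other value
        applies no normalization.
      training: Whether BatchNorm uses batch statistics (True) or running
        averages (False).
    """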
    features: int
    kernel_size: Tuple[int, int] = (3, 3)
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(self, x):
        x = ConvRelu2(
            self.features,
            padding=self.padding,
            norm=self.norm,
            training=self.training,
        )(x)
        x = DeConv3x3(
            self.features // 2,
            padding=self.padding,
            norm=self.norm,
            training=self.training,
        )(x)
        return x


class ConvEncoder(nn.Module):
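    """Convolutional encoder: a stack of DownsampleBlocks.

    Each entry of block_size adds one DownsampleBlock with that many features,
    halving the spatial resolution.
    """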
    block_size: Tuple[int, ...] = (16, 32, 64)
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        for features in self.block_size:
            x = DownsampleBlock(
                features=features,
                padding=self.padding,
                norm=self.norm,
                training=self.training,
            )(x)

        return x


class ConvDecoder(nn.Module):
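    """Convolutional decoder: a stack of UpsampleBlocks plus a final projection.

    Each entry of block_size adds one UpsampleBlock, doubling the spatial
    resolution; a final 3x3 convolution maps to output_features channels.
    """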
    output_features: int = 3
    block_size: Tuple[int, ...] = (16, 32, 64)
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        for features in self.block_size:
            x = UpsampleBlock(
                features=features,
                padding=self.padding,
                norm=self.norm,
                training=self.training,
            )(x)

        # Final convolution to reconstruct the output
        x = nn.Conv(
            features=self.output_features,
            kernel_size=(3, 3),
            padding=self.padding,
        )(x)
        return x

:::
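
As a quick sanity sketch (illustrative only, not part of the exported module; the names `x_in`, `x_down`, and `x_up` are made up here), a `DownsampleBlock` halves the spatial dimensions via its 2x2 max pool, and `DeConv3x3` doubles them back through the dilated convolution:

```python
x_in = jnp.ones((40, 40, 3))

down = DownsampleBlock(features=8)
down_vars = down.init(jax.random.PRNGKey(0), x_in)
x_down = down.apply(down_vars, x_in)
assert x_down.shape == (20, 20, 8)  # 40x40 -> 20x20

up = DeConv3x3(features=8)
up_vars = up.init(jax.random.PRNGKey(0), x_down)
x_up = up.apply(up_vars, x_down)
assert x_up.shape == (40, 40, 8)  # 20x20 -> 40x40
```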

::: {#cell-5 .cell .test}

dummy_2d = jnp.ones((40, 40, 3))

block_size = (12, 24, 48)
padding = "SAME"
norm = "layer"

conv_encoder = ConvEncoder(
    block_size=block_size,
    padding=padding,
    norm=norm,
)
conv_vars = conv_encoder.init(jax.random.PRNGKey(0), jnp.ones_like(dummy_2d))
out = conv_encoder.apply(conv_vars, dummy_2d)

conv_decoder = ConvDecoder(
    output_features=3,
    block_size=block_size,
    padding=padding,
    norm=norm,
)

conv_dec_vars = conv_decoder.init(jax.random.PRNGKey(0), jnp.ones_like(out))
out = conv_decoder.apply(conv_dec_vars, out)

assert out.shape == dummy_2d.shape

:::
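
The test above exercises the default `norm="layer"` path, which applies no normalization. Below is a minimal sketch of the batch-norm variant, assuming evaluation mode (`training=False`): because the `BatchNorm` layers are constructed with `axis_name="batch"`, computing batch statistics in training mode would additionally require running under `jax.vmap`/`jax.pmap` with that axis name, whereas eval mode only reads the stored running averages.

```python
bn_encoder = ConvEncoder(block_size=block_size, norm="batch", training=False)
bn_vars = bn_encoder.init(jax.random.PRNGKey(0), dummy_2d)  # params + batch_stats
bn_out = bn_encoder.apply(bn_vars, dummy_2d)

# Three 2x downsampling stages: 40 -> 20 -> 10 -> 5 spatial resolution.
assert bn_out.shape == (5, 5, block_size[-1])
```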