Convolutional models
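Plain convolutional building blocks in Flax Linen: paired 3x3 convolutions with ReLUs, 2x downsampling and upsampling blocks, and an encoder/decoder pair assembled by stacking them.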
::: {#cell-3 .cell}
```python
#| export
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Tuple
from functools import partial
```
:::
::: {#cell-4 .cell}
```python
#| export
Conv3x3 = partial(nn.Conv, kernel_size=(3, 3))


class ConvRelu2(nn.Module):
    """Two 3x3 convolutions, each followed by a ReLU.

    Attributes:
      features: Number of convolutional features.
      padding: Type of padding: 'SAME' or 'VALID'.
      norm: Normalization type; 'batch' applies BatchNorm after each
        convolution, any other value applies none.
      training: Whether BatchNorm uses batch statistics (True) or
        running averages (False).
    """

    features: int
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = Conv3x3(features=self.features, name="conv1", padding=self.padding)(x)
        if self.norm == "batch":
            x = nn.BatchNorm(use_running_average=not self.training, axis_name="batch")(x)
        x = nn.relu(x)
        x = Conv3x3(features=self.features, name="conv2", padding=self.padding)(x)
        if self.norm == "batch":
            x = nn.BatchNorm(use_running_average=not self.training, axis_name="batch")(x)
        x = nn.relu(x)
        return x
class DeConv3x3(nn.Module):
    """Upscaling layer that doubles the spatial resolution.

    Attributes:
      features: Number of convolutional features.
      padding: Type of padding: 'SAME' or 'VALID'.
      norm: Normalization type; 'batch' applies BatchNorm at the end,
        any other value applies none.
      training: Whether BatchNorm uses batch statistics (True) or
        running averages (False).
    """

    features: int
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Applies a 2x input-dilated convolution with a 3x3 kernel."""
        # NOTE: In the scenic code this is a deconvolution; here the
        # equivalent upscaling is done with an input-dilated convolution.
        if self.padding == "SAME":
            padding = ((1, 2), (1, 2))
        elif self.padding == "VALID":
            padding = ((0, 0), (0, 0))
        else:
            raise ValueError(f"Unknown padding: {self.padding}")
        x = nn.Conv(
            features=self.features,
            kernel_size=(3, 3),
            input_dilation=(2, 2),
            padding=padding,
        )(x)
        if self.norm == "batch":
            x = nn.BatchNorm(use_running_average=not self.training, axis_name="batch")(x)
        return x
class DownsampleBlock(nn.Module):
    """Two convolutions with ReLUs, then 2x max-pool downsampling.

    Attributes:
      features: Number of convolutional features.
      padding: Type of padding: 'SAME' or 'VALID'.
      norm: Normalization type; 'batch' applies BatchNorm, any other
        value applies none.
      training: Whether BatchNorm uses batch statistics.
    """

    features: int
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = ConvRelu2(
            features=self.features,
            padding=self.padding,
            norm=self.norm,
            training=self.training,
        )(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        return x
class UpsampleBlock(nn.Module):
    """Two convolutions with ReLUs, then a 2x DeConv3x3 upscale.

    Attributes:
      features: Number of convolutional features; the upscaling layer
        emits features // 2 channels.
      kernel_size: Currently unused; both sub-layers use 3x3 kernels.
      padding: Type of padding: 'SAME' or 'VALID'.
      norm: Normalization type; 'batch' applies BatchNorm, any other
        value applies none.
      training: Whether BatchNorm uses batch statistics.
    """

    features: int
    kernel_size: Tuple[int, int] = (3, 3)
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(self, x):
        x = ConvRelu2(
            self.features,
            padding=self.padding,
            norm=self.norm,
            training=self.training,
        )(x)
        x = DeConv3x3(
            self.features // 2,
            padding=self.padding,
            norm=self.norm,
            training=self.training,
        )(x)
        return x
class ConvEncoder(nn.Module):
    """Stack of DownsampleBlocks; each block halves the spatial size.

    Attributes:
      block_size: Number of features in each downsampling block.
      padding: Type of padding: 'SAME' or 'VALID'.
      norm: Normalization type; 'batch' applies BatchNorm, any other
        value applies none.
      training: Whether BatchNorm uses batch statistics.
    """

    block_size: Tuple[int, ...] = (16, 32, 64)
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        for features in self.block_size:
            x = DownsampleBlock(
                features=features,
                padding=self.padding,
                norm=self.norm,
                training=self.training,
            )(x)
        return x
class ConvDecoder(nn.Module):
    """Stack of UpsampleBlocks plus a final convolution.

    Each block doubles the spatial size, and the final convolution maps
    back to output_features channels.

    Attributes:
      output_features: Number of channels in the reconstructed output.
      block_size: Number of features in each upsampling block.
      padding: Type of padding: 'SAME' or 'VALID'.
      norm: Normalization type; 'batch' applies BatchNorm, any other
        value applies none.
      training: Whether BatchNorm uses batch statistics.
    """

    output_features: int = 3
    block_size: Tuple[int, ...] = (16, 32, 64)
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        for features in self.block_size:
            x = UpsampleBlock(
                features=features,
                padding=self.padding,
                norm=self.norm,
                training=self.training,
            )(x)
        # Final convolution to reconstruct the output.
        x = nn.Conv(
            features=self.output_features,
            kernel_size=(3, 3),
            padding=self.padding,
        )(x)
        return x
```
:::
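
The padding arithmetic above is easy to get wrong, so here is a quick shape check (a sketch added for illustration, not part of the original notebook): DeConv3x3 doubles the spatial size under SAME padding, and under VALID padding each 3x3 convolution in ConvRelu2 trims two pixels per spatial dimension.

```python
# Illustrative shape checks (not part of the exported module).
dummy = jnp.ones((10, 10, 4))

# SAME padding: input dilation 2 gives 10*2-1 = 19 dilated pixels, the
# asymmetric (1, 2) padding brings it to 22, and the 3x3 kernel leaves 20.
deconv = DeConv3x3(features=8)
deconv_vars = deconv.init(jax.random.PRNGKey(0), dummy)
assert deconv.apply(deconv_vars, dummy).shape == (20, 20, 8)

# VALID padding: each 3x3 convolution trims 2 pixels, so 10 -> 8 -> 6.
conv = ConvRelu2(features=8, padding="VALID")
conv_vars = conv.init(jax.random.PRNGKey(0), dummy)
assert conv.apply(conv_vars, dummy).shape == (6, 6, 8)
```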
::: {#cell-5 .cell}
```python
#| test
# A single (unbatched) 40x40 RGB input; flax convolutions treat all leading
# dimensions as batch dimensions, so no explicit batch axis is needed.
dummy_2d = jnp.ones((40, 40, 3))
block_size = (12, 24, 48)
padding = "SAME"
norm = "layer"

conv_encoder = ConvEncoder(
    block_size=block_size,
    padding=padding,
    norm=norm,
)
conv_vars = conv_encoder.init(jax.random.PRNGKey(0), jnp.ones_like(dummy_2d))
out = conv_encoder.apply(conv_vars, dummy_2d)  # 40x40 -> 5x5 after three 2x pools

conv_decoder = ConvDecoder(
    output_features=3,
    block_size=block_size,
    padding=padding,
    norm=norm,
)
conv_dec_vars = conv_decoder.init(jax.random.PRNGKey(0), jnp.ones_like(out))
out = conv_decoder.apply(conv_dec_vars, out)  # 5x5 -> 40x40 after three 2x upscales

assert out.shape == dummy_2d.shape
```
:::
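
Note that the BatchNorm layers above pass `axis_name="batch"`, so `norm="batch"` only works inside a `jax.pmap` (or `jax.vmap`) that binds that axis name, and the `batch_stats` collection must be marked mutable while training. A minimal sketch of that pattern, assuming parameters replicated across local devices (the function names here are illustrative, not from this notebook):

```python
# Illustrative sketch: running ConvRelu2 with norm="batch".
# BatchNorm(axis_name="batch") averages its statistics across the named
# axis, so the module must be called inside a pmap that binds "batch".
n_dev = jax.local_device_count()
x = jnp.ones((n_dev, 4, 16, 16, 3))  # (devices, per-device batch, H, W, C)

model = ConvRelu2(features=8, norm="batch", training=True)

@partial(jax.pmap, axis_name="batch")
def init_fn(xi):
    # Same PRNGKey on every device, so the parameters come out replicated.
    return model.init(jax.random.PRNGKey(0), xi)

variables = init_fn(x)

@partial(jax.pmap, axis_name="batch")
def apply_fn(variables, xi):
    # batch_stats must be mutable while training so the running
    # averages can be updated.
    return model.apply(variables, xi, mutable=["batch_stats"])

out, new_state = apply_fn(variables, x)
assert out.shape == (n_dev, 4, 16, 16, 8)
```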