# Convolutional models
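This section defines the convolutional building blocks: a double 3x3 convolution block (`ConvRelu2`), a transposed-convolution upsampler (`DeConv3x3`), downsampling and upsampling blocks built from them, and the `ConvEncoder`/`ConvDecoder` stacks that chain those blocks.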
::: {#cell-3 .cell 0='e' 1='x' 2='p' 3='o' 4='r' 5='t'}
```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Tuple
from functools import partial
```
:::
::: {#cell-4 .cell 0='e' 1='x' 2='p' 3='o' 4='r' 5='t'}
```python
Conv3x3 = partial(nn.Conv, kernel_size=(3, 3))


class ConvRelu2(nn.Module):
    """Two 3x3 convolutions, each followed by a relu.

    Attributes:
      features: Num convolutional features.
      padding: Type of padding: 'SAME' or 'VALID'.
      norm: Type of normalization; 'batch' applies BatchNorm, any other
        value applies no normalization.
      training: Whether the model is being trained (controls BatchNorm's
        running averages).
    """

    features: int
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        x = Conv3x3(features=self.features, name="conv1", padding=self.padding)(x)
        if self.norm in ["batch"]:
            x = nn.BatchNorm(use_running_average=not self.training, axis_name="batch")(
                x
            )
        x = nn.relu(x)
        x = Conv3x3(features=self.features, name="conv2", padding=self.padding)(x)
        if self.norm in ["batch"]:
            x = nn.BatchNorm(use_running_average=not self.training, axis_name="batch")(
                x
            )
        x = nn.relu(x)
        return x


class DeConv3x3(nn.Module):
    """Transposed-convolution layer for 2x upscaling.

    Attributes:
      features: Num convolutional features.
      padding: Type of padding: 'SAME' or 'VALID'.
      norm: Type of normalization; 'batch' applies BatchNorm, any other
        value applies no normalization.
      training: Whether the model is being trained.
    """

    features: int
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        """Applies a transposed convolution with a 3x3 kernel."""
        # NOTE: In the scenic code this is a deconvolution; here it is
        # expressed as a regular convolution over a 2x input-dilated input.
        if self.padding == "SAME":
            padding = ((1, 2), (1, 2))
        elif self.padding == "VALID":
            padding = ((0, 0), (0, 0))
        else:
            raise ValueError(f"Unknown padding: {self.padding}")
        x = nn.Conv(
            features=self.features,
            kernel_size=(3, 3),
            input_dilation=(2, 2),
            padding=padding,
        )(x)
        if self.norm in ["batch"]:
            x = nn.BatchNorm(use_running_average=not self.training, axis_name="batch")(
                x
            )
        return x


class DownsampleBlock(nn.Module):
    """Two convolutions followed by 2x downsampling.

    Attributes:
      features: Num convolutional features.
      padding: Type of padding: 'SAME' or 'VALID'.
      norm: Type of normalization; 'batch' applies BatchNorm.
      training: Whether the model is being trained.
    """

    features: int
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        x = ConvRelu2(
            features=self.features,
            padding=self.padding,
            norm=self.norm,
            training=self.training,
        )(x)
        # Halve the spatial resolution with a 2x2 max pool.
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        return x


class UpsampleBlock(nn.Module):
    """Two convolutions followed by 2x upsampling.

    Attributes:
      features: Num convolutional features.
      kernel_size: Convolution kernel size (currently unused by the body).
      padding: Type of padding: 'SAME' or 'VALID'.
      norm: Type of normalization; 'batch' applies BatchNorm.
      training: Whether the model is being trained.
    """

    features: int
    kernel_size: Tuple[int, int] = (3, 3)
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(self, x):
        x = ConvRelu2(
            self.features,
            padding=self.padding,
            norm=self.norm,
            training=self.training,
        )(x)
        # Upsample 2x while halving the feature count.
        x = DeConv3x3(
            self.features // 2,
            padding=self.padding,
            norm=self.norm,
            training=self.training,
        )(x)
        return x


class ConvEncoder(nn.Module):
    """Stack of DownsampleBlocks, one per entry in block_size."""

    block_size: Tuple[int, ...] = (16, 32, 64)
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        for features in self.block_size:
            x = DownsampleBlock(
                features=features,
                padding=self.padding,
                norm=self.norm,
                training=self.training,
            )(x)
        return x


class ConvDecoder(nn.Module):
    """Stack of UpsampleBlocks followed by a final output convolution."""

    output_features: int = 3
    block_size: Tuple[int, ...] = (16, 32, 64)
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        for features in self.block_size:
            x = UpsampleBlock(
                features=features,
                padding=self.padding,
                norm=self.norm,
                training=self.training,
            )(x)
        # Final convolution to reconstruct the output channels.
        x = nn.Conv(
            features=self.output_features,
            kernel_size=(3, 3),
            padding=self.padding,
        )(x)
        return x
```
:::
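`DeConv3x3` expresses a stride-2 transposed convolution as a regular convolution over a 2x input-dilated input: for spatial size n, dilation yields 2n - 1 samples, the asymmetric `((1, 2), (1, 2))` padding adds 3 more, and the 3x3 valid convolution removes 2, for an output of exactly 2n. Below is a minimal shape check of that arithmetic, assuming only the definitions above (an illustrative sketch, not one of the notebook's cells):

```python
# Illustrative sanity check: DeConv3x3 with padding="SAME" doubles H and W.
# Per spatial dimension: (n - 1) * 2 + 1 + (1 + 2) - 3 + 1 == 2 * n.
deconv = DeConv3x3(features=8, padding="SAME", norm="layer")
x_small = jnp.ones((5, 5, 4))  # (H, W, C), unbatched like the test below
deconv_vars = deconv.init(jax.random.PRNGKey(0), x_small)
y_small = deconv.apply(deconv_vars, x_small)
assert y_small.shape == (10, 10, 8), y_small.shape
```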
::: {#cell-5 .cell 0='t' 1='e' 2='s' 3='t'}
```python
dummy_2d = jnp.ones((40, 40, 3))

block_size = (12, 24, 48)
padding = "SAME"
norm = "layer"

conv_encoder = ConvEncoder(
    block_size=block_size,
    padding=padding,
    norm=norm,
)
conv_vars = conv_encoder.init(jax.random.PRNGKey(0), jnp.ones_like(dummy_2d))
out = conv_encoder.apply(conv_vars, dummy_2d)

conv_decoder = ConvDecoder(
    output_features=3,
    block_size=block_size,
    padding=padding,
    norm=norm,
)
conv_dec_vars = conv_decoder.init(jax.random.PRNGKey(0), jnp.ones_like(out))
out = conv_decoder.apply(conv_dec_vars, out)

# Round trip: encoding then decoding should reproduce the input shape.
assert out.shape == dummy_2d.shape
```
:::
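Since the encoder halves the spatial resolution once per block and the decoder doubles it once per block, the two compose into an autoencoder-style round trip. A minimal sketch of that composition, assuming the modules defined above; the `ConvAutoEncoder` wrapper and its name are ours, not part of the exported code:

```python
# Hypothetical wrapper (not exported above): chain ConvEncoder and ConvDecoder.
class ConvAutoEncoder(nn.Module):
    output_features: int = 3
    block_size: Tuple[int, ...] = (16, 32, 64)
    padding: str = "SAME"
    norm: str = "layer"
    training: bool = True

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Downsample once per block, then upsample back once per block.
        x = ConvEncoder(
            block_size=self.block_size,
            padding=self.padding,
            norm=self.norm,
            training=self.training,
        )(x)
        x = ConvDecoder(
            output_features=self.output_features,
            block_size=self.block_size,
            padding=self.padding,
            norm=self.norm,
            training=self.training,
        )(x)
        return x


autoencoder = ConvAutoEncoder(block_size=block_size, padding=padding, norm=norm)
ae_vars = autoencoder.init(jax.random.PRNGKey(0), jnp.ones_like(dummy_2d))
reconstruction = autoencoder.apply(ae_vars, dummy_2d)
assert reconstruction.shape == dummy_2d.shape
```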