Convolutional models



ConvDecoder

 ConvDecoder (output_features:int=3, block_size:Tuple[int,...]=(16, 32, 64),
              padding:str='SAME', norm:str='layer', training:bool=True,
              parent:Union[flax.linen.module.Module, flax.core.scope.Scope,
              flax.linen.module._Sentinel, NoneType]=<flax.linen.module._Sentinel object>,
              name:Optional[str]=None)


ConvEncoder

 ConvEncoder (block_size:Tuple[int,...]=(16, 32, 64), padding:str='SAME',
              norm:str='layer', training:bool=True,
              parent:Union[flax.linen.module.Module, flax.core.scope.Scope,
              flax.linen.module._Sentinel, NoneType]=<flax.linen.module._Sentinel object>,
              name:Optional[str]=None)


UpsampleBlock

 UpsampleBlock (features:int, kernel_size:Tuple[int,int]=(3, 3),
                padding:str='SAME', norm:str='layer', training:bool=True,
                parent:Union[flax.linen.module.Module, flax.core.scope.Scope,
                flax.linen.module._Sentinel, NoneType]=<flax.linen.module._Sentinel object>,
                name:Optional[str]=None)


DownsampleBlock

 DownsampleBlock (features:int, padding:str='SAME', norm:str='layer',
                  training:bool=True,
                  parent:Union[flax.linen.module.Module, flax.core.scope.Scope,
                  flax.linen.module._Sentinel, NoneType]=<flax.linen.module._Sentinel object>,
                  name:Optional[str]=None)

Two convolutions (padded or unpadded, depending on `padding`) followed by a 2x downsample.

Attributes:
  features: Number of convolutional features.
  padding: Type of padding, 'SAME' or 'VALID'.
  norm: Normalization applied at the end, given as a string (default 'layer').



DeConv3x3

 DeConv3x3 (features:int, padding:str='SAME', norm:str='layer',
            training:bool=True,
            parent:Union[flax.linen.module.Module, flax.core.scope.Scope,
            flax.linen.module._Sentinel, NoneType]=<flax.linen.module._Sentinel object>,
            name:Optional[str]=None)

Deconvolution (transposed convolution) layer for upscaling.

Attributes:
  features: Number of convolutional features.
  padding: Type of padding, 'SAME' or 'VALID'.
  norm: Normalization applied at the end, given as a string (default 'layer').



ConvRelu2

 ConvRelu2 (features:int, padding:str='SAME', norm:str='layer',
            training:bool=True,
            parent:Union[flax.linen.module.Module, flax.core.scope.Scope,
            flax.linen.module._Sentinel, NoneType]=<flax.linen.module._Sentinel object>,
            name:Optional[str]=None)

Two convolutions (padded or unpadded, depending on `padding`), each followed by a ReLU.

Attributes:
  features: Number of convolutional features.
  padding: Type of padding, 'SAME' or 'VALID'.
  norm: Normalization applied at the end, given as a string (default 'layer').

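import jax
import jax.numpy as jnp

# ConvEncoder and ConvDecoder are the modules documented above.

# A single unbatched 40x40 image with 3 channels (height, width, channels).
dummy_2d = jnp.ones((40, 40, 3))

# Shared hyperparameters so the decoder mirrors the encoder.
block_size = (12, 24, 48)
padding = "SAME"
norm = "layer"

# Encoder: initialize parameters, then run a forward pass.
conv_encoder = ConvEncoder(
    block_size=block_size,
    padding=padding,
    norm=norm,
)
conv_vars = conv_encoder.init(jax.random.PRNGKey(0), jnp.ones_like(dummy_2d))
latent = conv_encoder.apply(conv_vars, dummy_2d)

# Decoder: map the latent back to an image with 3 output channels.
conv_decoder = ConvDecoder(
    output_features=3,
    block_size=block_size,
    padding=padding,
    norm=norm,
)

conv_dec_vars = conv_decoder.init(jax.random.PRNGKey(0), jnp.ones_like(latent))
out = conv_decoder.apply(conv_dec_vars, latent)

# The encoder/decoder round trip should reproduce the input shape.
assert out.shape == dummy_2d.shape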