dummy_2d = jnp.ones((40, 40, 3))
block_size = (12, 24, 48)
padding = "SAME"
norm = "layer"

conv_encoder = ConvEncoder(
    block_size=block_size,
    padding=padding,
    norm=norm,
)
conv_vars = conv_encoder.init(jax.random.PRNGKey(0), jnp.ones_like(dummy_2d))
out = conv_encoder.apply(conv_vars, dummy_2d)

conv_decoder = ConvDecoder(
    output_features=3,
    block_size=block_size,
    padding=padding,
    norm=norm,
)
conv_dec_vars = conv_decoder.init(jax.random.PRNGKey(0), jnp.ones_like(out))
out = conv_decoder.apply(conv_dec_vars, out)

assert out.shape == dummy_2d.shape
Convolutional models
ConvDecoder
ConvDecoder (output_features:int=3, block_size:Tuple[int,...]=(16, 32, 64), padding:str='SAME', norm:str='layer', training:bool=True, parent:Union[flax.linen.module.Module,flax.core.scope.Scope,flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x11a1a75e0>, name:Optional[str]=None)
ConvEncoder
ConvEncoder (block_size:Tuple[int,...]=(16, 32, 64), padding:str='SAME', norm:str='layer', training:bool=True, parent:Union[flax.linen.module.Module,flax.core.scope.Scope,flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x11a1a75e0>, name:Optional[str]=None)
UpsampleBlock
UpsampleBlock (features:int, kernel_size:Tuple[int,int]=(3, 3), padding:str='SAME', norm:str='layer', training:bool=True, parent:Union[flax.linen.module.Module,flax.core.scope.Scope,flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x11a1a75e0>, name:Optional[str]=None)
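A minimal usage sketch for UpsampleBlock, following the same init/apply pattern as the encoder/decoder example at the top. The input shape, the feature count, and the amount of spatial upsampling are assumptions for illustration, not part of the documented API.

```python
import jax
import jax.numpy as jnp

# Hypothetical (H, W, C) feature map, unbatched as in the example above.
feat = jnp.ones((20, 20, 48))

up_block = UpsampleBlock(features=24, kernel_size=(3, 3), padding="SAME", norm="layer")
up_vars = up_block.init(jax.random.PRNGKey(0), jnp.ones_like(feat))
up_out = up_block.apply(up_vars, feat)

# Expected to increase the spatial resolution (assumed 2x, mirroring DownsampleBlock).
print(up_out.shape)
```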
DownsampleBlock
DownsampleBlock (features:int, padding:str='SAME', norm:str='layer', training:bool=True, parent:Union[flax.linen.module.Module,flax.core.scope.Scope,flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x11a1a75e0>, name:Optional[str]=None)
*Two unpadded convolutions & downsample 2x.*

Attributes:
- features: Num convolutional features.
- padding: Type of padding: 'SAME' or 'VALID'.
- norm: Whether to use batchnorm at the end or not.
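A usage sketch for DownsampleBlock under the same assumptions: an unbatched (H, W, C) input and the default 'layer' norm, so apply needs no mutable collections. The halved spatial shape in the comment follows the docstring rather than anything verified here.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((40, 40, 3))  # hypothetical input image

down_block = DownsampleBlock(features=16, padding="SAME", norm="layer")
down_vars = down_block.init(jax.random.PRNGKey(0), jnp.ones_like(x))
down_out = down_block.apply(down_vars, x)

# Per the docstring this downsamples 2x, so roughly (20, 20, 16) is expected here.
print(down_out.shape)
```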
DeConv3x3
DeConv3x3 (features:int, padding:str='SAME', norm:str='layer', training:bool=True, parent:Union[flax.linen.module.Module,flax.core.scope.Scope,flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x11a1a75e0>, name:Optional[str]=None)
*Deconvolution layer for upscaling.*

Attributes:
- features: Num convolutional features.
- padding: Type of padding: 'SAME' or 'VALID'.
- norm: Whether to use batchnorm at the end or not.
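A sketch of DeConv3x3 in isolation, again assuming an unbatched (H, W, C) input and the default 'layer' norm; the input resolution and feature count are made up for the example.

```python
import jax
import jax.numpy as jnp

feat = jnp.ones((10, 10, 64))  # hypothetical low-resolution feature map

deconv = DeConv3x3(features=32, padding="SAME", norm="layer")
deconv_vars = deconv.init(jax.random.PRNGKey(0), jnp.ones_like(feat))
deconv_out = deconv.apply(deconv_vars, feat)

# The docstring only says this layer upscales; the exact output factor is not asserted.
print(deconv_out.shape)
```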
ConvRelu2
ConvRelu2 (features:int, padding:str='SAME', norm:str='layer', training:bool=True, parent:Union[flax.linen.module.Module,flax.core.scope.Scope,flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x11a1a75e0>, name:Optional[str]=None)
*Two unpadded convolutions & relus.*

Attributes:
- features: Num convolutional features.
- padding: Type of padding: 'SAME' or 'VALID'.
- norm: Whether to use batchnorm at the end or not.
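A sketch of ConvRelu2 on its own, with the same unbatched-input and 'layer'-norm assumptions as above. With 'SAME' padding the spatial dimensions are expected to be preserved, but that expectation is an assumption, not an assertion.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((40, 40, 3))  # hypothetical input

conv_relu = ConvRelu2(features=16, padding="SAME", norm="layer")
cr_vars = conv_relu.init(jax.random.PRNGKey(0), jnp.ones_like(x))
cr_out = conv_relu.apply(cr_vars, x)

# With 'SAME' padding, expect roughly (40, 40, 16); exact behaviour depends on the block.
print(cr_out.shape)
```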