Dataloading utilities

::: {#cell-3 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

from pathlib import Path
import numpy as np
import jax.numpy as jnp
from typing import Optional, List, Union
from numpy.lib.stride_tricks import sliding_window_view
from einops import rearrange
from collections.abc import Iterable
from absl import logging

:::

::: {#cell-4 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def standardize(x, mean, std, only_scale=True):
    """Normalise ``x`` with the given statistics.

    With ``only_scale`` true (the default) ``x`` is only divided by ``std``;
    otherwise ``mean`` is subtracted before dividing.
    """
    if only_scale:
        return x / std
    return (x - mean) / std


def unstandardize(x, mean, std, only_scale=True):
    """Invert ``standardize``: multiply by ``std`` and, unless ``only_scale``, add ``mean``."""
    rescaled = x * std
    return rescaled if only_scale else rescaled + mean


def load_as_big_array(files):
    """Load every ``.npy`` path in ``files`` and stack the arrays along a new axis 0."""
    loaded = list(map(np.load, files))
    return np.stack(loaded, axis=0)


def reshape_array(array):
    """Merge the first two axes of a 4-D ``(runs, timesteps, gridpoints, channels)`` array.

    Returns an array of shape ``(runs * timesteps, gridpoints, channels)``; a
    non-4-D input raises ``ValueError`` from the tuple unpacking.
    """
    runs, steps, points, chans = array.shape
    merged_shape = (runs * steps, points, chans)
    return np.reshape(array, merged_shape)

:::

::: {#cell-5 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def create_grid(
    height: int,
    width: int,
    min_val: float = -1.0,
    max_val: float = 1.0,
):
    """
    Create a 2-D coordinate grid with values between min_val and max_val inclusive.

    `height` points are sampled for one coordinate and `width` points for the
    other, each evenly spaced over [min_val, max_val].

    Returns an array of shape (width, height, 2):
      - grid[i, j, 0] is the i-th of the `width` samples;
      - grid[i, j, 1] is the j-th of the `height` samples.

    NOTE(review): the axis order (width first) comes from jnp.meshgrid's default
    "xy" indexing and is kept for backward compatibility; for square grids the
    shape is simply (n, n, 2).
    """
    # Previously `-min_val` was passed here, which collapsed the default
    # [-1, 1] range into a constant grid of ones.
    ys = jnp.linspace(min_val, max_val, height)
    xs = jnp.linspace(min_val, max_val, width)

    # "xy" indexing: both outputs have shape (len(xs), len(ys)).
    y, x = jnp.meshgrid(ys, xs)

    grid = jnp.stack([x, y], axis=-1)
    return grid

:::

::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def hankel_matrix(
    x,  # input array with shape (B, time_steps, grid_size, channels)
    depth: int = 2,  # number of consecutive time steps stacked into each output row
) -> np.ndarray:  # shape (B, time_steps - (depth - 1), grid_size * depth, channels)
    """
    Build a Hankel (time-delay) embedding of `x` along axis 1.

    Output row `t` concatenates, along axis 2, the `depth` consecutive slices
    x[:, t], x[:, t + 1], ..., x[:, t + depth - 1]; element-wise:

        out[b, t, k * grid_size + g, c] == x[b, t + k, g, c]

    depth=1 returns a copy of `x`. Trailing axes beyond axis 2 are carried along.

    Raises:
        ValueError: if depth < 1 or depth exceeds the length of axis 1.
    """
    x = np.asarray(x)
    n_steps = x.shape[1]
    if not 1 <= depth <= n_steps:
        raise ValueError(
            f"depth must be between 1 and {n_steps} (the length of axis 1), got {depth}"
        )
    n_rows = n_steps - (depth - 1)
    # Each shifted slice has shape (B, n_rows, grid_size, ...); concatenating them on
    # axis 2 places delay k in columns [k * grid_size, (k + 1) * grid_size).
    return np.concatenate([x[:, k : k + n_rows] for k in range(depth)], axis=2)

:::

::: {#cell-7 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

# Toy batch for hankel_matrix: 3 identical trajectories of shape (5, 4) with 2 channels.
a = np.arange(20).reshape(5, 4)
a = np.stack([a, a, a], axis=0)
a = np.stack([a, a * 2], axis=-1)  # second channel is the first one doubled
d = 2  # must match the depth passed below; used only in the expected shape

b = hankel_matrix(a, depth=2)
# Expected: axis 1 shrinks by (depth - 1) and axis 2 grows by a factor of depth.
assert b.shape == (a.shape[0], a.shape[1] - (d - 1), a.shape[2] * d, a.shape[3])

:::

::: {#cell-8 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

import jax
from timeit import default_timer as timer
from functools import partial
from typing import Tuple

:::

::: {#cell-9 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def slice_tensor_multi(
    key: jax.random.PRNGKey,
    x: jnp.ndarray,  # input array with shape (T, ...)
    num_slices: int,  # number of slices to take
    num_input: int,  # number of time steps as input
    num_target: int,  # number of time steps as target
    split_mode: str = "no_overlap",  # no_split, overlap, no_overlap
):
    """
    Take `num_slices` random windows of `num_input + num_target` steps from `x`.

    Warning: here the time axis is the FIRST axis of the tensor.

    Start indices are drawn with replacement from every valid start,
    0 .. T - (num_input + num_target) inclusive.

    Returns, depending on `split_mode`:
      - "no_overlap": (inputs, targets) of shapes (num_slices, num_input, ...) and
        (num_slices, num_target, ...); targets immediately follow the inputs.
      - "overlap": (inputs, targets), both (num_slices, num_input + num_target - 1, ...);
        targets are the inputs shifted forward by one step.
      - "no_split": one array of shape (num_slices, num_input + num_target, ...).

    Raises:
        ValueError: for an unknown `split_mode`, or if T < num_input + num_target.
    """
    if split_mode not in ("no_overlap", "overlap", "no_split"):
        raise ValueError(f"Unknown split_mode: {split_mode!r}")

    total_samples = x.shape[0]
    num_steps = num_input + num_target
    # Number of valid start positions; the last valid start is total_samples - num_steps
    # (choice samples from [0, num_starts), so the +1 keeps that last window reachable).
    num_starts = total_samples - num_steps + 1
    if num_starts < 1:
        raise ValueError(
            f"Cannot take windows of {num_steps} steps from {total_samples} time steps"
        )
    indices = jax.random.choice(key, num_starts, shape=(num_slices,))

    def slice_tensor(start_index):
        if split_mode == "no_overlap":
            input_slices = jax.lax.dynamic_slice_in_dim(
                x, start_index, num_input, axis=0
            )
            target_slices = jax.lax.dynamic_slice_in_dim(
                x, start_index + num_input, num_target, axis=0
            )
            return input_slices, target_slices
        if split_mode == "overlap":
            input_slices = jax.lax.dynamic_slice_in_dim(
                x, start_index, num_steps - 1, axis=0
            )
            target_slices = jax.lax.dynamic_slice_in_dim(
                x, start_index + 1, num_steps - 1, axis=0
            )
            return input_slices, target_slices
        # "no_split" (validated above)
        return jax.lax.dynamic_slice_in_dim(x, start_index, num_steps, axis=0)

    # vmap over the start indices adds the leading num_slices axis.
    return jax.vmap(slice_tensor)(indices)


def slice_tensor_single(
    key: jax.random.PRNGKey,
    x: jnp.ndarray,  # input array with shape (B, T, ...)
    num_input,  # number of time steps to slice
    num_target,  # number of time steps to predict
    split_mode: str = "no_overlap",  # no_split, overlap, no_overlap
):
    """
    Cut one random window of `num_input + num_target` steps along axis 1 of `x`.

    A single start index is drawn and shared by every row of the batch. It is
    drawn from every valid start, 0 .. T - (num_input + num_target) inclusive.

    Returns, depending on `split_mode`:
      - "no_overlap": (inputs, targets) of shapes (B, num_input, ...) and
        (B, num_target, ...); targets immediately follow the inputs.
      - "overlap": (inputs, targets), both (B, num_input + num_target - 1, ...);
        targets are the inputs shifted forward by one step.
      - "no_split": one array of shape (B, num_input + num_target, ...).

    Raises:
        ValueError: for an unknown `split_mode`, or if T < num_input + num_target.
    """
    if split_mode not in ("no_overlap", "overlap", "no_split"):
        raise ValueError(f"Unknown split_mode: {split_mode!r}")

    total_samples = x.shape[1]
    num_steps = num_input + num_target
    # choice samples from [0, num_starts); the +1 makes the last valid start reachable.
    num_starts = total_samples - num_steps + 1
    if num_starts < 1:
        raise ValueError(
            f"Cannot take windows of {num_steps} steps from {total_samples} time steps"
        )
    start_index = jax.random.choice(key, num_starts)

    if split_mode == "no_overlap":
        input_slices = jax.lax.dynamic_slice_in_dim(x, start_index, num_input, axis=1)
        target_slices = jax.lax.dynamic_slice_in_dim(
            x, start_index + num_input, num_target, axis=1
        )
        return input_slices, target_slices
    if split_mode == "overlap":
        input_slices = jax.lax.dynamic_slice_in_dim(
            x, start_index, num_steps - 1, axis=1
        )
        target_slices = jax.lax.dynamic_slice_in_dim(
            x, start_index + 1, num_steps - 1, axis=1
        )
        return input_slices, target_slices
    # "no_split" (validated above)
    return jax.lax.dynamic_slice_in_dim(x, start_index, num_steps, axis=1)


def split_xy(
    data,
    num_input,
    num_target,
):
    """
    Split `data` along axis 1 into an input part and a target part.

    The first `num_input` steps become x and ALL remaining steps become y.

    NOTE(review): `num_target` is accepted for signature parity with the other
    slicers but is not used, so y is not truncated to `num_target` steps.
    """
    cut = num_input
    n_total = data.shape[1]
    head = jax.lax.slice_in_dim(data, 0, cut, axis=1)
    tail = jax.lax.slice_in_dim(data, cut, n_total, axis=1)
    return head, tail


@partial(jax.jit, static_argnames=("mode", "num_input", "num_target", "batch_size"))
def select_slices(
    key: jax.random.PRNGKey,
    idx: jnp.ndarray,  # batch indices
    data: jnp.ndarray,  # input array with shape (B, T, ...)
    indices: jnp.ndarray,  # indices of the batches
    num_input: int,  # number of time steps to slice
    num_target: int,  # number of time steps to predict
    mode: str,  # single_random_*, many_random_*, split or passthrough
    batch_size: int,  # number of slices to take in many_random_* modes
) -> Union[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
    """
    Gather `data[indices[idx]]` and slice it along time according to `mode`.

    Modes:
      - "single_random_<split>": one random window shared by the whole gathered
        batch (`slice_tensor_single`).
      - "many_random_<split>": `batch_size` random windows of the gathered
        array's first axis (`slice_tensor_multi`); with a scalar `idx` into 1-D
        `indices` that array is one trajectory.
      - "split": (first `num_input` steps, remaining steps) via `split_xy`.
      - "passthrough": the gathered data unchanged.
    where <split> is one of "no_overlap", "overlap", "no_split".

    Returns a (inputs, targets) tuple, or a single array for "*_no_split" and
    "passthrough".

    Raises:
        ValueError: for an unknown mode (at trace time, since `mode` is static).
    """
    single_modes = {
        "single_random_no_overlap": "no_overlap",
        "single_random_overlap": "overlap",
        "single_random_no_split": "no_split",
    }
    many_modes = {
        "many_random_no_overlap": "no_overlap",
        "many_random_overlap": "overlap",
        "many_random_no_split": "no_split",
    }

    batch_data = data[indices[idx]]
    if mode in single_modes:
        return slice_tensor_single(
            key, batch_data, num_input, num_target, single_modes[mode]
        )
    if mode in many_modes:
        return slice_tensor_multi(
            key, batch_data, batch_size, num_input, num_target, many_modes[mode]
        )
    if mode == "split":
        return split_xy(batch_data, num_input, num_target)
    if mode == "passthrough":
        return batch_data
    # Previously an unknown mode silently returned None.
    raise ValueError(f"Unknown mode: {mode!r}")

:::

::: {#cell-10 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

# many_random_no_overlap: `batch_size` windows are cut from the single trajectory
# selected by idx=0; inputs and targets are consecutive, non-overlapping segments.
# B, T, W, C, num_input and num_target are reused by the following test cells.
B, T, W, C = 16, 4000, 40, 2
x = jnp.ones((B, T, W, C))

num_input = 1
num_target = 199

x, y = select_slices(
    key=jax.random.PRNGKey(65),
    idx=0,
    data=x,
    indices=jnp.arange(B),
    num_input=num_input,
    num_target=num_target,
    mode="many_random_no_overlap",
    batch_size=B,
)
assert x.shape == (B, num_input, W, C)
assert y.shape == (B, num_target, W, C)

:::

::: {#cell-11 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

# many_random_overlap: inputs and targets both span num_input + num_target - 1 = 199
# steps; the targets are the inputs shifted forward by one step.
x = jnp.ones((B, T, W, C))

x, y = select_slices(
    key=jax.random.PRNGKey(65),
    idx=0,
    data=x,
    indices=jnp.arange(B),
    num_input=1,
    num_target=199,
    mode="many_random_overlap",
    batch_size=B,
)
assert x.shape == (B, 199, W, C)
assert y.shape == (B, 199, W, C)

:::

::: {#cell-12 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

# single_random_no_overlap: idx is an array, so all B trajectories are gathered and
# one random start index is shared by the whole batch.
x = jnp.ones((B, T, W, C))


x, y = select_slices(
    key=jax.random.PRNGKey(65),
    idx=jnp.arange(B),
    data=x,
    indices=jnp.arange(B),
    num_input=num_input,
    num_target=num_target,
    mode="single_random_no_overlap",
    batch_size=None,  # not used by the single_random_* modes
)
assert x.shape == (B, num_input, W, C)
assert y.shape == (B, num_target, W, C)

:::

::: {#cell-13 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class JaxDataloader:
    """
    Iterate over an array in batches, delegating batch construction to `select_slices`.

    Indexing depends on `mode`:

    - "many_random_*": `self.indices` is 1-D, with one entry per trajectory. Each
      step selects ONE trajectory, and `select_slices` cuts `batch_size` random
      windows from it, so an epoch has `len(data)` steps.
    - any other mode: `self.indices` has shape (num_batches, batch_size), and each
      step selects one row of trajectories. If `len(data)` is not divisible by
      `batch_size`, the batch size is lowered to the largest divisor of
      `len(data)` that does not exceed it, so every batch is full.

    Once exhausted, the iterator rewinds itself (and reshuffles when
    `shuffle=True`), so the same object can be iterated over several epochs.
    """

    # Modes in which one step yields several slices of a single trajectory.
    _MANY_RANDOM_MODES = (
        "many_random_no_overlap",
        "many_random_overlap",
        "many_random_no_split",
    )

    def __init__(
        self,
        data: np.ndarray,
        num_input: int = -1,  # length of the input segment
        num_target: int = -1,  # length of the target segment
        batch_size: int = 16,  # batch size
        shuffle: bool = True,  # shuffle the data
        drop_last: bool = True,  # kept for compatibility; batches are always full
        key: Optional[jax.random.PRNGKey] = jax.random.PRNGKey(0),  # random key (None -> PRNGKey(0))
        mode: str = "passthrough",  # mode of the dataloader
    ):
        if key is None:
            key = jax.random.PRNGKey(0)
        self.data = data
        self.num_input = num_input
        self.num_target = num_target
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.key = key
        self.mode = mode
        self.data_fits_batch_size = self.data.shape[0] % batch_size == 0
        if self.mode in self._MANY_RANDOM_MODES:
            # in many_random mode the batch size is the number of slices within a single trajectory
            self.indices = jnp.arange(data.shape[0])
        else:
            if not self.data_fits_batch_size:
                logging.info(
                    f"Warning: The data size {self.data.shape[0]} is not divisible by the batch size {batch_size}."
                )
                # Use the largest divisor so the reshape below always succeeds.
                # N // (N // batch_size + 1) is not always a divisor
                # (e.g. N=100, batch_size=16 gives 14).
                self.batch_size = self._largest_divisor_at_most(
                    self.data.shape[0], batch_size
                )
                logging.info(f"Setting the batch size to {self.batch_size}")
            self.indices = jnp.arange(data.shape[0]).reshape(-1, self.batch_size)

        self._reset()

    @staticmethod
    def _largest_divisor_at_most(n: int, limit: int) -> int:
        """Return the largest divisor of `n` that is <= `limit` (1 if there is none larger)."""
        for candidate in range(min(n, limit), 0, -1):
            if n % candidate == 0:
                return candidate
        return 1

    def _reset(self):
        """Rewind to the first batch; with shuffle=True, advance the key and reshuffle."""
        if self.shuffle:
            self.key = jax.random.split(self.key, 1)[0]
            # Shuffle the flattened trajectory indices so that batch membership
            # changes between epochs, not just the order of fixed batches.
            # For 1-D indices this is identical to permuting them directly.
            flat = jax.random.permutation(self.key, self.indices.ravel())
            self.indices = flat.reshape(self.indices.shape)

        self.idx = 0

    def __iter__(self):
        return self

    def get_slices(self):
        """Build the batch at the current position, then advance the position by one."""
        # Derive this step's key from the running key and the batch position.
        self.key = jax.random.fold_in(self.key, self.idx)

        # Call the select_slices function with common parameters
        result = select_slices(
            self.key,
            self.idx,
            self.data,
            self.indices,
            self.num_input,
            self.num_target,
            self.mode,
            self.batch_size,
        )

        self.idx += 1

        return result

    def __next__(self) -> Union[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
        if self.idx >= len(self.indices):
            # End of epoch: rewind (and reshuffle) so the loader can be reused.
            self._reset()
            raise StopIteration
        # Every row of self.indices is a full batch (the batch size is adjusted in
        # __init__), so there is never a short final batch for drop_last to discard.
        return self.get_slices()

    def __len__(self):
        """Number of steps per epoch, i.e. how many times __next__ succeeds."""
        # In many_random modes an epoch has one step per trajectory, not
        # len(data) // batch_size steps.
        return len(self.indices)

:::

::: {#cell-14 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def make_rolling_windows(
    x: np.ndarray,  # (B, T, ...)
    window_size: int,
):
    """
    Cut every stride-1 window of length `window_size` out of the time axis (axis 1) of x.

    (B, T, ...) -> (B * (T - window_size + 1), window_size, ...), trajectory-major:
    out[b * n + i] == x[b, i : i + window_size], where n = T - window_size + 1.

    Warning: the result may be materialised as a copy and can be far larger than x.
    """
    # Windows appear as a new trailing axis: (B, n, ..., window_size).
    windows = sliding_window_view(x, window_size, axis=1)
    # Bring the window axis next to the window-count axis: (B, n, window_size, ...).
    windows = np.moveaxis(windows, -1, 2)
    n_traj, n_windows = windows.shape[0], windows.shape[1]
    return windows.reshape(n_traj * n_windows, *windows.shape[2:])

:::

::: {#cell-15 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

# Rolling windows: each of the B trajectories yields T - win + 1 windows of length win.
B, T, H, W, C = 5, 64, 40, 40, 2
win = 16
dummy = np.ones((B, T, H, W, C))
dummy = make_rolling_windows(dummy, win)
assert dummy.shape == (B * (T - win + 1), win, H, W, C)

:::

::: {#cell-16 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def split_data(
    mapped_data,
    split: List[float],
    extract_channels: List[int],
):
    """
    Split `mapped_data` along axis 0 into contiguous train / val / test parts.

    `split` holds the three fractions (train, val, test). Counts are truncated
    with int(), and any leftover items are added to the training part. If
    `extract_channels` is non-empty, only those indices of the last axis are
    kept in each part.
    """
    n_items = mapped_data.shape[0]
    n_train, n_val, n_test = (int(frac * n_items) for frac in split)

    # Truncation can leave a few items unassigned; give them to the training split.
    leftover = n_items - (n_train + n_val + n_test)
    if leftover > 0:
        n_train += leftover

    assert n_train + n_val + n_test == n_items, (
        "Split fractions do not sum up correctly"
    )

    val_end = n_train + n_val
    parts = (
        mapped_data[:n_train],
        mapped_data[n_train:val_end],
        mapped_data[val_end:],
    )
    if extract_channels:
        parts = tuple(part[..., extract_channels] for part in parts)

    train, val, test = parts
    return train, val, test

:::

::: {#cell-17 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def is_list_like(obj):
    """Return True for iterables such as lists, tuples or arrays, excluding str, bytes and int."""
    if isinstance(obj, (str, bytes, int)):
        return False
    return isinstance(obj, Iterable)


class DirectoryDataModule:
    """
    Build train/val/test `JaxDataloader`s from a single `.npy` trajectory array.

    Pipeline:
      1. Memory-map the file.
      2. Split it along axis 0 with `split_data`.
      3. Optionally truncate or window the time axis (axis 1).
      4. Optionally move the train/val arrays to the default JAX device.
      5. Optionally standardize every split with training-set statistics.
      6. Wrap each split in a `JaxDataloader`.

    The test loader always runs in "passthrough" mode with batch size 1.
    """

    # Every mode this module accepts (the branches handled by `select_slices`).
    _VALID_MODES = (
        "single_random_no_overlap",
        "single_random_overlap",
        "single_random_no_split",
        "many_random_no_overlap",
        "many_random_overlap",
        "many_random_no_split",
        "split",
        "passthrough",
    )

    def __init__(
        self,
        data_array: Union[str, List[str]],
        split: List[float] = [0.8, 0.2, 0.0],  # train, val, test
        batch_size: int = 1,  # batch size per device (time steps in parallel mode)
        extract_channels: List[int] = [],  # channels to keep; [] keeps all without indexing, None indexes all
        num_steps_train: Optional[
            Union[int, List[int]]
        ] = None,  # number of time steps to use for training (int or [input, target])
        num_steps_val: Optional[
            Union[int, List[int]]
        ] = None,  # number of time steps to use for validation (int or [input, target])
        standardize_dataset: bool = False,  # standardize the dataset per channel
        mean: Optional[float] = None,  # mean for standardization
        std: Optional[float] = None,  # std for standardization
        mode: str = "passthrough",  # mode of the dataloader
        windowed: bool = False,  # cut the time axis into non-overlapping windows
        hankelize: int = 0,  # hankelize the data with depth (0 = no hankelization)
        shuffle_train: bool = True,  # shuffle the training data
        cache: bool = False,  # place train/val arrays on the default device
        rolling_windows: bool = False,  # generate stride-1 rolling windows
        total_num_train: Optional[int] = None,  # keep only this many time steps (train)
        total_num_val: Optional[int] = None,  # keep only this many time steps (val)
        total_num_test: Optional[int] = None,  # keep only this many time steps (test)
    ):
        assert mode in self._VALID_MODES, "Invalid mode"
        self.mode = mode
        self.num_steps_train = num_steps_train
        self.num_steps_val = num_steps_val
        self.train_batch_size = batch_size
        self.val_batch_size = batch_size
        # Always define the statistics attributes; overwritten below when standardizing.
        self.mean = mean
        self.std = std

        # Validate every path. A single path is wrapped in a list first: Path() on a
        # list raises TypeError, and iterating a str would check single characters.
        paths = data_array if is_list_like(data_array) else [data_array]
        assert all(Path(p).exists() for p in paths), "The data array does not exist"

        # A list of arrays would have to be concatenated; a single path is loaded directly.
        if is_list_like(data_array):
            # TODO: this will not work as we need to create a third memmap file
            raise NotImplementedError(
                "Concatenating data from different arrays is not supported yet"
            )
        mapped_data = np.load(data_array, mmap_mode="r")

        logging.info(f"Found {mapped_data.shape[0]} trajectories")

        # Split along axis 0 and optionally select channels (last axis).
        if extract_channels is None:
            extract_channels = list(range(mapped_data.shape[-1]))
        train_array, val_array, test_array = split_data(
            mapped_data,
            split,
            extract_channels,
        )

        # Keep only the first total_num_* time steps (axis 1) of each split.
        if total_num_train is not None:
            train_array = train_array[:, :total_num_train, ...]
        if total_num_val is not None:
            val_array = val_array[:, :total_num_val, ...]
        if total_num_test is not None:
            test_array = test_array[:, :total_num_test, ...]

        self.num_steps_input_train, self.num_steps_target_train = self._parse_num_steps(
            num_steps_train,
        )

        self.num_steps_input_val, self.num_steps_target_val = self._parse_num_steps(
            num_steps_val
        )

        if hankelize != 0:
            raise NotImplementedError("Hankelization is not supported yet")

        if windowed:
            # Non-overlapping windows of (input + target) steps become new trajectories.
            train_array = self._windowed_data(
                train_array, self.num_steps_input_train + self.num_steps_target_train
            )
            val_array = self._windowed_data(
                val_array, self.num_steps_input_val + self.num_steps_target_val
            )

        if rolling_windows:
            # num_steps may be an [input, target] pair; the window then spans both.
            train_array = make_rolling_windows(
                train_array, window_size=self._rolling_window_size(num_steps_train)
            )
            val_array = make_rolling_windows(
                val_array, window_size=self._rolling_window_size(num_steps_val)
            )

        # TODO: there is no reliable way to know the shape of the numpy array
        # we save the shape of the first entry
        self.data_shape = train_array[0].shape

        # After slicing: optionally move train/val to the default device.
        if cache:
            timer_start = timer()
            train_array = jax.device_put(train_array).block_until_ready()
            val_array = jax.device_put(val_array).block_until_ready()
            logging.info(f"Data cached in {timer() - timer_start} seconds")

        if standardize_dataset:
            # Per-channel statistics are computed on the training set (unless given)
            # and reused for val/test.
            train_array, mean, std = self._standardize_data(train_array, mean, std)
            val_array, *_ = self._standardize_data(val_array, mean, std)
            test_array, *_ = self._standardize_data(test_array, mean, std)

            self.mean = mean
            self.std = std
            logging.info(
                f"The mean and std of the output are {self.mean} and {self.std}, you should save this for later"
            )

        self.train_dataloader = JaxDataloader(
            train_array,
            num_input=self.num_steps_input_train,
            num_target=self.num_steps_target_train,
            batch_size=batch_size,
            shuffle=shuffle_train,
            drop_last=True,
            mode=self.mode,
        )

        self.val_dataloader = JaxDataloader(
            val_array,
            num_input=self.num_steps_input_val,
            num_target=self.num_steps_target_val,
            batch_size=batch_size,
            shuffle=False,
            drop_last=True,
            mode=self.mode,
        )

        self.test_dataloader = JaxDataloader(
            test_array,
            batch_size=1,  # we will handle the batch size in the slice function
            shuffle=False,
            drop_last=True,
            mode="passthrough",  # always passthrough
        )

        logging.info(
            f"Using {self.num_steps_input_train} input and {self.num_steps_target_train} target steps for training in mode {mode}"
        )

    @staticmethod
    def _rolling_window_size(num_steps):
        """Rolling-window length: the sum of an [input, target] pair, or an int as given."""
        return sum(num_steps) if is_list_like(num_steps) else num_steps

    def _standardize_data(self, data_array, mean, std):
        """
        Apply `standardize` to `data_array`, computing per-channel statistics if needed.

        When `mean` or `std` is None they are computed over every axis except the last
        (channel) axis; only 4-D and 5-D arrays are supported for that. Note that
        `standardize` is called with its default only_scale=True, so the data is
        divided by std without subtracting the mean.

        Returns (standardized_array, mean, std).
        """
        if mean is None or std is None:
            if (
                len(data_array.shape) == 4
            ):  # (batch_size, time_steps, grid_size, channels)
                mean = jnp.mean(data_array, axis=(0, 1, 2))
                std = jnp.std(data_array, axis=(0, 1, 2))
            elif len(data_array.shape) == 5:  # (batch_size, time_steps, H, W, channels)
                mean = jnp.mean(data_array, axis=(0, 1, 2, 3))
                std = jnp.std(data_array, axis=(0, 1, 2, 3))
            else:
                raise ValueError(
                    "The input array has an incorrect number of dimensions"
                )
        data_array = standardize(data_array, mean, std)

        return data_array, mean, std

    def _windowed_data(
        self,
        data_array,
        num_steps,
    ):
        """Cut axis 1 into consecutive windows of `num_steps`, turning each into a new row of axis 0."""
        assert data_array.shape[1] % num_steps == 0, (
            "The number of time steps is not divisible by the slice size"
        )
        data_array = data_array.reshape(-1, num_steps, *data_array.shape[2:])
        return data_array

    def _parse_num_steps(
        self,
        num_steps,
    ):
        """Return (input_steps, target_steps): unpack a pair, or use a scalar for both."""
        if is_list_like(num_steps):
            num_steps_input, num_steps_target = num_steps
        else:
            num_steps_input = num_steps
            num_steps_target = num_steps
        return num_steps_input, num_steps_target

    def get_info(self):
        """
        Return the training input shape as [time_steps, *per-step shape].

        Overlap modes use num_input + num_target - 1 steps; every other mode reports
        num_input steps.
        NOTE(review): for "*_no_split" and "passthrough" the dataloader actually
        yields longer sequences; confirm callers expect the input length here.
        """
        if self.mode in ["single_random_overlap", "many_random_overlap"]:
            return [
                self.num_steps_input_train + self.num_steps_target_train - 1,
                *self.data_shape[1:],
            ]
        else:
            return [self.num_steps_input_train, *self.data_shape[1:]]

:::

::: {#cell-18 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

# Generate a small random dataset and save it; the DirectoryDataModule tests below load it.
# Shape: (num_samples, num_timesteps, grid_size, channels).
num_samples = 100
num_timesteps = 120
grid_size = 20
channels = 2

data = np.random.randn(num_samples, num_timesteps, grid_size, channels)
np.save("/tmp/data.npy", data)

:::

::: {#cell-19 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

batch_size = 2

# many_random_no_overlap with different [input, target] lengths for train and val;
# each batch holds `batch_size` windows cut from a single trajectory.
datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    num_steps_train=[1, 19],
    num_steps_val=[3, 17],
    cache=True,
    standardize_dataset=True,
    mode="many_random_no_overlap",
    windowed=False,
    hankelize=0,
)

train_x, train_y = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, 1, grid_size, channels)
assert train_y.shape == (batch_size, 19, grid_size, channels)

val_x, val_y = next(iter(datamodule.val_dataloader))
assert val_x.shape == (batch_size, 3, grid_size, channels)
assert val_y.shape == (batch_size, 17, grid_size, channels)

# The test loader is always passthrough with batch size 1 (full, unsliced trajectories).
test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)

:::

::: {#cell-20 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

batch_size = 2
# passthrough: each batch is `batch_size` whole, unsliced trajectories.
datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    cache=True,
    standardize_dataset=True,
    mode="passthrough",
    windowed=False,
)

train_x = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, num_timesteps, grid_size, channels)

val_x = next(iter(datamodule.val_dataloader))
assert val_x.shape == (batch_size, num_timesteps, grid_size, channels)

test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)

:::

::: {#cell-21 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

batch_size = 2
num_steps_train = [1, 59]
num_steps_val = [1, 59]
# windowed=True cuts every 120-step trajectory into 120 / (1 + 59) = 2 windows of
# 60 steps, which passthrough then returns unchanged. The test split is not windowed.
datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    num_steps_train=num_steps_train,
    num_steps_val=num_steps_val,
    cache=True,
    standardize_dataset=True,
    mode="passthrough",
    windowed=True,
)

train_x = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, sum(num_steps_train), grid_size, channels)

val_x = next(iter(datamodule.val_dataloader))
assert val_x.shape == (batch_size, sum(num_steps_val), grid_size, channels)

test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)

:::

::: {#cell-22 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

batch_size = 2
# Expected window lengths only: the datamodule below gets [1, 49], i.e. 1 + 49 = 50 steps.
num_steps_train = 50
num_steps_val = 50
# many_random_no_split: each sample is one unsplit window of num_input + num_target steps.
datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    num_steps_train=[1, 49],
    num_steps_val=[1, 49],
    cache=True,
    standardize_dataset=True,
    mode="many_random_no_split",
    windowed=False,
)

train_x = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, num_steps_train, grid_size, channels)

val_x = next(iter(datamodule.val_dataloader))
assert val_x.shape == (batch_size, num_steps_val, grid_size, channels)

test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)

:::

::: {#cell-23 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

batch_size = 16
num_steps_train = 50
num_steps_val = 50
# single_random_no_split: one shared random 50-step window per batch of trajectories.
# The val split has only 10 trajectories (not divisible by 16), so its loader lowers
# the batch size to 10; hence the hard-coded 10 below.
datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    num_steps_train=[1, 49],
    num_steps_val=[1, 49],
    cache=True,
    standardize_dataset=True,
    mode="single_random_no_split",
    windowed=False,
)

train_x = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, num_steps_train, grid_size, channels)

val_x = next(iter(datamodule.val_dataloader))
assert val_x.shape == (10, num_steps_val, grid_size, channels)

test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)

:::

::: {#cell-24 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

batch_size = 16
num_steps_train = [1, 59]
num_steps_val = [1, 59]
# split + windowed: 2 windows of 60 steps per trajectory, each split into x (the
# first num_steps[0] steps) and y (the rest). The val split yields 10 * 2 = 20 windows,
# which is not divisible by 16, so its batch size drops to 10.
datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    num_steps_train=num_steps_train,
    num_steps_val=num_steps_val,
    cache=True,
    standardize_dataset=True,
    mode="split",
    windowed=True,
)

train_x, train_y = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, num_steps_train[0], grid_size, channels)

val_x, val_y = next(iter(datamodule.val_dataloader))
assert val_x.shape == (10, num_steps_val[0], grid_size, channels)

test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)

:::

def iterate_over_dataloader():
    """Drain the global `datamodule`'s training loader once (used for timing)."""
    loader = datamodule.train_dataloader
    for _ in loader:
        pass


# Time one full pass over the training dataloader (IPython magic; runs only in a notebook).
%timeit iterate_over_dataloader()
2.29 ms ± 156 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)