def iterate_over_dataloader():
for x in datamodule.train_dataloader:
pass
%timeit iterate_over_dataloader()
2.29 ms ± 156 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
:::
::: {#cell-3 .cell export}
from pathlib import Path
import numpy as np
import jax.numpy as jnp
from typing import Optional, List, Union
from numpy.lib.stride_tricks import sliding_window_view
from einops import rearrange
from collections.abc import Iterable
from absl import logging
:::
::: {#cell-4 .cell export}
def standardize(x, mean, std, only_scale=True):
    # with only_scale=True (the default) the mean is ignored and x is only divided by std
    return (x - mean) / std if not only_scale else x / std
def unstandardize(x, mean, std, only_scale=True):
    return x * std + mean if not only_scale else x * std
# load the data
def load_as_big_array(files):
return np.stack([np.load(file) for file in files], axis=0)
def reshape_array(array):
n_runs, n_timesteps, n_gridpoints, n_channels = array.shape
return np.reshape(array, (n_runs * n_timesteps, n_gridpoints, n_channels))
:::
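A quick round-trip and shape check for the helpers above (a minimal sketch; the array sizes and `/tmp` file names are illustrative):

::: {.cell}
x = np.random.randn(3, 10, 20, 2)  # (n_runs, n_timesteps, n_gridpoints, n_channels)
mean, std = x.mean(axis=(0, 1, 2)), x.std(axis=(0, 1, 2))
# full standardization round-trip (only_scale=False also uses the mean)
z = standardize(x, mean, std, only_scale=False)
assert np.allclose(unstandardize(z, mean, std, only_scale=False), x)
# merge runs and time into one sample axis
assert reshape_array(x).shape == (3 * 10, 20, 2)
files = []
for i in range(2):
    np.save(f"/tmp/run_{i}.npy", x[i])
    files.append(f"/tmp/run_{i}.npy")
assert load_as_big_array(files).shape == (2, 10, 20, 2)
:::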
::: {#cell-5 .cell export}
def create_grid(
height: int,
width: int,
min_val: float = -1.0,
max_val: float = 1.0,
):
"""
Create a grid of size (height, width) with values between min_val and max_val inclusive.
"""
    y, x = jnp.meshgrid(
        jnp.linspace(min_val, max_val, height),
        jnp.linspace(min_val, max_val, width),
        indexing="ij",  # so the grid comes out with shape (height, width, 2)
    )
grid = jnp.stack([x, y], axis=-1)
return grid
:::
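A shape and corner check for `create_grid` (a minimal sketch with arbitrary sizes):

::: {.cell}
grid = create_grid(height=4, width=3)
assert grid.shape == (4, 3, 2)
# channel 0 holds x (varies along the width axis), channel 1 holds y
assert jnp.allclose(grid[0, 0], jnp.array([-1.0, -1.0]))
assert jnp.allclose(grid[-1, -1], jnp.array([1.0, 1.0]))
:::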
::: {#cell-6 .cell export}
def hankel_matrix(
    x,  # input array with shape (B, grid_size, time_steps, channels)
    depth: int = 2,  # number of shifted copies to stack (Hankel depth)
) -> (
    np.ndarray
):  # returns an array with shape (B, grid_size - (depth - 1), time_steps * depth, channels)
x = x.transpose(0, 2, 1, 3)
d = depth - 1
b = sliding_window_view(x, window_shape=(x.shape[1], x.shape[2] - d), axis=(1, 2))
b = b.transpose(0, 1, 2, 4, 5, 3)
b = b.reshape(x.shape[0], -1, x.shape[2] - d, x.shape[-1])
return b.transpose(0, 2, 1, 3)
:::
::: {#cell-7 .cell test}
a = np.arange(20).reshape(5, 4)
a = np.stack([a, a, a], axis=0)
a = np.stack([a, a * 2], axis=-1)
d = 2
b = hankel_matrix(a, depth=2)
assert b.shape == (a.shape[0], a.shape[1] - (d - 1), a.shape[2] * d, a.shape[3])
:::
::: {#cell-8 .cell export}
import jax
from timeit import default_timer as timer
from functools import partial
from typing import Tuple
:::
::: {#cell-9 .cell export}
def slice_tensor_multi(
key: jax.random.PRNGKey,
x: jnp.ndarray, # input array with shape (T, ...)
num_slices: int, # number of slices to take
num_input: int, # number of time steps as input
num_target: int, # number of time steps as target
split_mode: str = "no_overlap", # no_split, overlap, no_overlap
):
"""
Many random slices of the tensor along the time axis
Warning: Here the time axis is the first axis of the tensor
"""
total_samples = x.shape[0]
num_steps = num_input + num_target
last_start_sample = total_samples - num_steps
indices = jax.random.choice(key, last_start_sample, shape=(num_slices,))
def slice_tensor(start_index):
if split_mode == "no_overlap":
input_slices = jax.lax.dynamic_slice_in_dim(
x, start_index, num_input, axis=0
)
target_slices = jax.lax.dynamic_slice_in_dim(
x, start_index + num_input, num_target, axis=0
)
return input_slices, target_slices
elif split_mode == "overlap":
input_slices = jax.lax.dynamic_slice_in_dim(
x, start_index, num_steps - 1, axis=0
)
target_slices = jax.lax.dynamic_slice_in_dim(
x, start_index + 1, num_steps - 1, axis=0
)
return input_slices, target_slices
elif split_mode == "no_split":
slices = jax.lax.dynamic_slice_in_dim(x, start_index, num_steps, axis=0)
return slices
slices = jax.vmap(slice_tensor)(indices)
return slices
def slice_tensor_single(
key: jax.random.PRNGKey,
x: jnp.ndarray, # input array with shape (B, T, ...)
num_input, # number of time steps to slice
num_target, # number of time steps to predict
split_mode: str = "no_overlap", # no_split, overlap, no_overlap
):
    # sample a single random start index, shared across the whole batch
total_samples = x.shape[1]
num_steps = num_input + num_target
last_start_sample = total_samples - num_steps
start_index = jax.random.choice(
key,
last_start_sample,
)
def slice_tensor(start_index):
if split_mode == "no_overlap":
input_slices = jax.lax.dynamic_slice_in_dim(
x, start_index, num_input, axis=1
)
target_slices = jax.lax.dynamic_slice_in_dim(
x, start_index + num_input, num_target, axis=1
)
return input_slices, target_slices
elif split_mode == "overlap":
input_slices = jax.lax.dynamic_slice_in_dim(
x, start_index, num_steps - 1, axis=1
)
target_slices = jax.lax.dynamic_slice_in_dim(
x, start_index + 1, num_steps - 1, axis=1
)
return input_slices, target_slices
elif split_mode == "no_split":
slices = jax.lax.dynamic_slice_in_dim(x, start_index, num_steps, axis=1)
return slices
# get the slices of the tensor
slices = slice_tensor(start_index)
return slices
def split_xy(
data,
num_input,
num_target,
):
"""
Split the data into input and target
"""
x = jax.lax.slice_in_dim(data, 0, num_input, axis=1)
y = jax.lax.slice_in_dim(data, num_input, data.shape[1], axis=1)
return x, y
@partial(jax.jit, static_argnames=("mode", "num_input", "num_target", "batch_size"))
def select_slices(
    key: jax.random.PRNGKey,
    idx: jnp.ndarray,  # index of the current batch (row of `indices`)
    data: jnp.ndarray,  # input array with shape (B, T, ...)
    indices: jnp.ndarray,  # (possibly shuffled) sample indices, grouped per batch
    num_input: int,  # number of time steps to slice
    num_target: int,  # number of time steps to predict
    mode: str,  # single_random_*, many_random_*, split or passthrough
    batch_size: int,  # number of slices to take in the many_random modes
) -> Tuple[jnp.ndarray, jnp.ndarray]:
batch_data = data[indices[idx]]
if mode == "single_random_no_overlap":
return slice_tensor_single(key, batch_data, num_input, num_target, "no_overlap")
elif mode == "single_random_overlap":
return slice_tensor_single(key, batch_data, num_input, num_target, "overlap")
elif mode == "single_random_no_split":
return slice_tensor_single(key, batch_data, num_input, num_target, "no_split")
elif mode == "many_random_no_overlap":
return slice_tensor_multi(
key, batch_data, batch_size, num_input, num_target, "no_overlap"
)
elif mode == "many_random_overlap":
return slice_tensor_multi(
key, batch_data, batch_size, num_input, num_target, "overlap"
)
elif mode == "many_random_no_split":
return slice_tensor_multi(
key, batch_data, batch_size, num_input, num_target, "no_split"
)
elif mode == "split":
return split_xy(batch_data, num_input, num_target)
elif mode == "passthrough":
return batch_data
:::
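A direct check of `slice_tensor_multi` in `no_split` mode, complementing the `select_slices` tests below (a minimal sketch with arbitrary sizes):

::: {.cell}
x = jnp.arange(100 * 3).reshape(100, 3)  # time axis first
slices = slice_tensor_multi(
    jax.random.PRNGKey(0), x, num_slices=8, num_input=4, num_target=6, split_mode="no_split"
)
assert slices.shape == (8, 10, 3)  # 8 windows of num_input + num_target steps
:::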
::: {#cell-10 .cell test}
B, T, W, C = 16, 4000, 40, 2
x = jnp.ones((B, T, W, C))
num_input = 1
num_target = 199
x, y = select_slices(
key=jax.random.PRNGKey(65),
idx=0,
data=x,
indices=jnp.arange(B),
num_input=num_input,
num_target=num_target,
mode="many_random_no_overlap",
batch_size=B,
)
assert x.shape == (B, num_input, W, C)
assert y.shape == (B, num_target, W, C)
:::
::: {#cell-11 .cell test}
x = jnp.ones((B, T, W, C))
x, y = select_slices(
key=jax.random.PRNGKey(65),
idx=0,
data=x,
indices=jnp.arange(B),
num_input=1,
num_target=199,
mode="many_random_overlap",
batch_size=B,
)
assert x.shape == (B, 199, W, C)
assert y.shape == (B, 199, W, C)
:::
::: {#cell-12 .cell test}
x = jnp.ones((B, T, W, C))
x, y = select_slices(
key=jax.random.PRNGKey(65),
idx=jnp.arange(B),
data=x,
indices=jnp.arange(B),
num_input=num_input,
num_target=num_target,
mode="single_random_no_overlap",
batch_size=None,
)
assert x.shape == (B, num_input, W, C)
assert y.shape == (B, num_target, W, C)
:::
::: {#cell-13 .cell export}
class JaxDataloader:
def __init__(
self,
data: np.ndarray,
num_input: int = -1, # length of the input segment
num_target: int = -1, # length of the target segment
batch_size: int = 16, # batch size
shuffle: bool = True, # shuffle the data
drop_last: bool = True, # drop the last batch if it's smaller than batch_size
key: Optional[jax.random.PRNGKey] = jax.random.PRNGKey(0), # random key
mode: str = "passthrough", # mode of the dataloader
):
self.data = data
self.num_input = num_input
self.num_target = num_target
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.key = key
self.mode = mode
self.data_fits_batch_size = self.data.shape[0] % batch_size == 0
if self.mode in [
"many_random_no_overlap",
"many_random_overlap",
"many_random_no_split",
]:
# in many_random mode the batch size is the number of slices within a single trajectory
self.indices = jnp.arange(data.shape[0])
else:
if not self.data_fits_batch_size:
logging.info(
f"Warning: The data size {self.data.shape[0]} is not divisible by the batch size {batch_size}."
)
                # Adjust the batch size so that it divides the data size
                self.batch_size = self.data.shape[0] // (
                    self.data.shape[0] // batch_size + 1
                )
                logging.info(f"Setting the batch size to {self.batch_size}")
                # Update indices (this reshape assumes the adjusted batch size divides the data size)
                self.indices = jnp.arange(data.shape[0]).reshape(-1, self.batch_size)
else:
self.indices = jnp.arange(data.shape[0]).reshape(-1, batch_size)
self._reset()
def _reset(self):
"""Reset the dataloader for a new iteration over the data."""
if self.shuffle:
self.key = jax.random.split(self.key, 1)[0]
self.indices = jax.random.permutation(self.key, self.indices)
self.idx = 0
def __iter__(self):
return self
def get_slices(self):
        # Fold the batch index into the key so each batch draws fresh randomness
        self.key = jax.random.fold_in(self.key, self.idx)
# Call the select_slices function with common parameters
result = select_slices(
self.key,
self.idx,
self.data,
self.indices,
self.num_input,
self.num_target,
self.mode,
self.batch_size,
)
self.idx += 1
return result
def __next__(self) -> Union[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
if self.idx >= len(self.indices):
# Reset the iterator and raise StopIteration
self._reset()
raise StopIteration
else:
if self.mode not in [
"many_random_no_overlap",
"many_random_overlap",
"many_random_no_split",
]:
                # idx counts batches, so compare the remaining samples against the batch size
                if self.drop_last and len(self.data) - self.idx * self.batch_size < self.batch_size:
# Reset and stop if we're dropping the last batch and it's too small
self._reset()
raise StopIteration
return self.get_slices()
    def __len__(self):
        # One batch per row of self.indices; in the many_random modes each
        # trajectory yields one batch of batch_size slices
        return len(self.indices)
:::
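A standalone sketch of `JaxDataloader` in `split` mode (dummy data; note that for this mode the time axis must equal num_input + num_target):

::: {.cell}
data = np.random.randn(8, 5, 20, 2)  # (B, T, grid_size, channels) with T = 2 + 3
loader = JaxDataloader(data, num_input=2, num_target=3, batch_size=4, mode="split")
for bx, by in loader:
    assert bx.shape == (4, 2, 20, 2)
    assert by.shape == (4, 3, 20, 2)
assert len(loader) == 2
:::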
::: {#cell-14 .cell export}
def make_rolling_windows(
x: np.ndarray, # (B, T, ...)
window_size: int,
):
"""
Generate rolling windows of size window_size over the time axis of x.
Warning: This might generate a large amount of data if the input is large.
"""
x = rearrange(x, "b t ... -> b ... t")
    x = sliding_window_view(x, window_shape=window_size, axis=-1)
x = rearrange(x, "b ... n w -> (b n) w ...")
return x
:::
::: {#cell-15 .cell test}
B, T, H, W, C = 5, 64, 40, 40, 2
win = 16
dummy = np.ones((B, T, H, W, C))
dummy = make_rolling_windows(dummy, win)
assert dummy.shape == (B * (T - win + 1), win, H, W, C)
:::
::: {#cell-16 .cell export}
def split_data(
mapped_data,
split: List[float],
extract_channels: List[int],
):
n_train, n_val, n_test = [
int(fraction * mapped_data.shape[0]) for fraction in split
]
    n_total = n_train + n_val + n_test
    if n_total < mapped_data.shape[0]:
        # assign any rounding leftover to the training split
        n_rest = mapped_data.shape[0] - n_total
        n_train += n_rest
assert n_train + n_val + n_test == mapped_data.shape[0], (
"Split fractions do not sum up correctly"
)
train = mapped_data[:n_train]
val = mapped_data[n_train : n_train + n_val]
test = mapped_data[n_train + n_val :]
if extract_channels:
train = train[..., extract_channels]
val = val[..., extract_channels]
test = test[..., extract_channels]
return train, val, test
:::
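A small check of `split_data` (a minimal sketch; the fractions are chosen so they divide the dummy data evenly):

::: {.cell}
dummy = np.random.randn(10, 12, 20, 3)
train, val, test = split_data(dummy, split=[0.6, 0.2, 0.2], extract_channels=[0, 2])
assert train.shape == (6, 12, 20, 2)
assert val.shape == (2, 12, 20, 2)
assert test.shape == (2, 12, 20, 2)
:::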
::: {#cell-17 .cell export}
def is_list_like(obj):
return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes, int))
class DirectoryDataModule:
def __init__(
self,
data_array: Union[str, List[str]],
split: List[float] = [0.8, 0.2, 0.0], # train, val, test
        batch_size: int = 1,  # batch size; in the many_random modes this is the number of slices per trajectory
        extract_channels: Optional[List[int]] = None,  # extract only these channels (e.g. [0, 1] for u, v); None keeps all
num_steps_train: Optional[
Union[int, List[int]]
] = None, # number of time steps to use for training
num_steps_val: Optional[
Union[int, List[int]]
] = None, # number of time steps to use for validation
standardize_dataset: bool = False, # standardize the dataset per channel
mean: Optional[float] = None, # mean for standardization
std: Optional[float] = None, # std for standardization
mode: str = "passthrough", # mode of the dataloader
        windowed: bool = False,  # slice the data in a windowed way
hankelize: int = 0, # hankelize the data with depth (0 = no hankelization)
shuffle_train: bool = True, # shuffle the training data
cache: bool = False, # cache the data
rolling_windows: bool = False, # generate rolling windows
total_num_train: Optional[int] = None, # total number of training samples
total_num_val: Optional[int] = None, # total number of validation samples
total_num_test: Optional[int] = None, # total number of test samples
):
assert mode in [
"single_random_no_overlap",
"single_random_overlap",
"single_random_no_split",
"many_random_no_overlap",
"many_random_overlap",
"many_random_no_split",
"split",
"passthrough",
], "Invalid mode"
self.mode = mode
self.num_steps_train = num_steps_train
self.num_steps_val = num_steps_val
self.train_batch_size = batch_size
self.val_batch_size = batch_size
        # a single .npy file is memory-mapped; a list of arrays would need concatenation
        if is_list_like(data_array):
            assert all(Path(d).exists() for d in data_array), "The data array does not exist"
            # TODO: this will not work as we need to create a third memmap file
            raise NotImplementedError(
                "Concatenating data from different arrays is not supported yet"
            )
        else:
            assert Path(data_array).exists(), "The data array does not exist"
            mapped_data = np.load(data_array, mmap_mode="r")
logging.info(f"Found {mapped_data.shape[0]} trajectories")
# split
if extract_channels is None:
extract_channels = list(range(mapped_data.shape[-1]))
train_array, val_array, test_array = split_data(
mapped_data,
split,
extract_channels,
)
# slice only the number of time steps
if total_num_train is not None:
train_array = train_array[:, :total_num_train, ...]
if total_num_val is not None:
val_array = val_array[:, :total_num_val, ...]
if total_num_test is not None:
test_array = test_array[:, :total_num_test, ...]
self.num_steps_input_train, self.num_steps_target_train = self._parse_num_steps(
num_steps_train,
)
self.num_steps_input_val, self.num_steps_target_val = self._parse_num_steps(
num_steps_val
)
if hankelize != 0:
raise NotImplementedError("Hankelization is not supported yet")
if windowed:
train_array = self._windowed_data(
train_array, self.num_steps_input_train + self.num_steps_target_train
)
val_array = self._windowed_data(
val_array, self.num_steps_input_val + self.num_steps_target_val
)
        if rolling_windows:
            # the rolling window must cover both the input and the target segment
            train_array = make_rolling_windows(
                train_array,
                window_size=self.num_steps_input_train + self.num_steps_target_train,
            )
            val_array = make_rolling_windows(
                val_array,
                window_size=self.num_steps_input_val + self.num_steps_target_val,
            )
        # TODO: there is no reliable way to know the full array shape up front,
        # so we save the shape of the first entry
        self.data_shape = train_array[0].shape
        # optionally move the (possibly sliced) data onto the accelerator
        if cache:
timer_start = timer()
train_array = jax.device_put(train_array).block_until_ready()
val_array = jax.device_put(val_array).block_until_ready()
logging.info(f"Data cached in {timer() - timer_start} seconds")
if standardize_dataset:
# standardize the output per channel
# the statistics are computed on the training set
train_array, mean, std = self._standardize_data(train_array, mean, std)
val_array, *_ = self._standardize_data(val_array, mean, std)
test_array, *_ = self._standardize_data(test_array, mean, std)
self.mean = mean
self.std = std
logging.info(
f"The mean and std of the output are {self.mean} and {self.std}, you should save this for later"
)
self.train_dataloader = JaxDataloader(
train_array,
num_input=self.num_steps_input_train,
num_target=self.num_steps_target_train,
batch_size=batch_size,
shuffle=shuffle_train,
drop_last=True,
mode=self.mode,
)
self.val_dataloader = JaxDataloader(
val_array,
num_input=self.num_steps_input_val,
num_target=self.num_steps_target_val,
batch_size=batch_size,
shuffle=False,
drop_last=True,
mode=self.mode,
)
self.test_dataloader = JaxDataloader(
test_array,
batch_size=1, # we will handle the batch size in the slice function
shuffle=False,
drop_last=True,
mode="passthrough", # always passthrough
)
logging.info(
f"Using {self.num_steps_input_train} input and {self.num_steps_target_train} target steps for training in mode {mode}"
)
def _standardize_data(self, data_array, mean, std):
if mean is None or std is None:
if (
len(data_array.shape) == 4
): # (batch_size, time_steps, grid_size, channels)
mean = jnp.mean(data_array, axis=(0, 1, 2))
std = jnp.std(data_array, axis=(0, 1, 2))
elif len(data_array.shape) == 5: # (batch_size, time_steps, H, W, channels)
mean = jnp.mean(data_array, axis=(0, 1, 2, 3))
std = jnp.std(data_array, axis=(0, 1, 2, 3))
else:
raise ValueError(
"The input array has an incorrect number of dimensions"
)
        # standardize() defaults to only_scale=True, so this divides by std only
        data_array = standardize(data_array, mean, std)
return data_array, mean, std
def _windowed_data(
self,
data_array,
num_steps,
):
assert data_array.shape[1] % num_steps == 0, (
"The number of time steps is not divisible by the slice size"
)
data_array = data_array.reshape(-1, num_steps, *data_array.shape[2:])
return data_array
def _parse_num_steps(
self,
num_steps,
):
if is_list_like(num_steps):
num_steps_input, num_steps_target = num_steps
else:
num_steps_input = num_steps
num_steps_target = num_steps
return num_steps_input, num_steps_target
def get_info(self):
# return the correct shape of the data after slicing
if self.mode in ["single_random_overlap", "many_random_overlap"]:
return [
self.num_steps_input_train + self.num_steps_target_train - 1,
*self.data_shape[1:],
]
else:
return [self.num_steps_input_train, *self.data_shape[1:]]
:::
::: {#cell-18 .cell test}
# generate and save the data
num_samples = 100
num_timesteps = 120
grid_size = 20
channels = 2
data = np.random.randn(num_samples, num_timesteps, grid_size, channels)
np.save("/tmp/data.npy", data)
:::
::: {#cell-19 .cell test}
batch_size = 2
datamodule = DirectoryDataModule(
data_array="/tmp/data.npy",
split=[0.8, 0.1, 0.1],
batch_size=batch_size,
extract_channels=[0, 1],
num_steps_train=[1, 19],
num_steps_val=[3, 17],
cache=True,
standardize_dataset=True,
mode="many_random_no_overlap",
windowed=False,
hankelize=0,
)
train_x, train_y = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, 1, grid_size, channels)
assert train_y.shape == (batch_size, 19, grid_size, channels)
val_x, val_y = next(iter(datamodule.val_dataloader))
assert val_x.shape == (batch_size, 3, grid_size, channels)
assert val_y.shape == (batch_size, 17, grid_size, channels)
test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
:::
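The reported slice shape from `get_info` should match the input segment (a quick check reusing the datamodule from the cell above):

::: {.cell}
# many_random_no_overlap is a non-overlap mode, so get_info returns the input-segment shape
assert datamodule.get_info() == [1, grid_size, channels]
:::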
::: {#cell-20 .cell test}
batch_size = 2
datamodule = DirectoryDataModule(
data_array="/tmp/data.npy",
split=[0.8, 0.1, 0.1],
batch_size=batch_size,
extract_channels=[0, 1],
cache=True,
standardize_dataset=True,
mode="passthrough",
windowed=False,
)
train_x = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, num_timesteps, grid_size, channels)
val_x = next(iter(datamodule.val_dataloader))
assert val_x.shape == (batch_size, num_timesteps, grid_size, channels)
test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
:::
::: {#cell-21 .cell test}
batch_size = 2
num_steps_train = [1, 59]
num_steps_val = [1, 59]
datamodule = DirectoryDataModule(
data_array="/tmp/data.npy",
split=[0.8, 0.1, 0.1],
batch_size=batch_size,
extract_channels=[0, 1],
num_steps_train=num_steps_train,
num_steps_val=num_steps_val,
cache=True,
standardize_dataset=True,
mode="passthrough",
windowed=True,
)
train_x = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, sum(num_steps_train), grid_size, channels)
val_x = next(iter(datamodule.val_dataloader))
assert val_x.shape == (batch_size, sum(num_steps_val), grid_size, channels)
test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
:::
::: {#cell-22 .cell test}
batch_size = 2
num_steps_train = 50
num_steps_val = 50
datamodule = DirectoryDataModule(
data_array="/tmp/data.npy",
split=[0.8, 0.1, 0.1],
batch_size=batch_size,
extract_channels=[0, 1],
num_steps_train=[1, 49],
num_steps_val=[1, 49],
cache=True,
standardize_dataset=True,
mode="many_random_no_split",
windowed=False,
)
train_x = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, num_steps_train, grid_size, channels)
val_x = next(iter(datamodule.val_dataloader))
assert val_x.shape == (batch_size, num_steps_val, grid_size, channels)
test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
:::
::: {#cell-23 .cell test}
batch_size = 16
num_steps_train = 50
num_steps_val = 50
datamodule = DirectoryDataModule(
data_array="/tmp/data.npy",
split=[0.8, 0.1, 0.1],
batch_size=batch_size,
extract_channels=[0, 1],
num_steps_train=[1, 49],
num_steps_val=[1, 49],
cache=True,
standardize_dataset=True,
mode="single_random_no_split",
windowed=False,
)
train_x = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, num_steps_train, grid_size, channels)
val_x = next(iter(datamodule.val_dataloader))
assert val_x.shape == (10, num_steps_val, grid_size, channels)
test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
:::
::: {#cell-24 .cell test}
batch_size = 16
num_steps_train = [1, 59]
num_steps_val = [1, 59]
datamodule = DirectoryDataModule(
data_array="/tmp/data.npy",
split=[0.8, 0.1, 0.1],
batch_size=batch_size,
extract_channels=[0, 1],
num_steps_train=num_steps_train,
num_steps_val=num_steps_val,
cache=True,
standardize_dataset=True,
mode="split",
windowed=True,
)
train_x, train_y = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, num_steps_train[0], grid_size, channels)
val_x, val_y = next(iter(datamodule.val_dataloader))
assert val_x.shape == (10, num_steps_val[0], grid_size, channels)
test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
:::