# Dataloading utilities

source

### reshape_array

    reshape_array (array)

source

### load_as_big_array

    load_as_big_array (files)

source

### unstandardize

    unstandardize (x, mean, std, only_scale=True)

source

### standardize

    standardize (x, mean, std, only_scale=True)
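
`standardize` and `unstandardize` should invert each other. Below is a minimal round-trip sketch, assuming the usual `(x - mean) / std` convention, with `only_scale=True` skipping the mean shift; that reading of `only_scale` is an assumption, not confirmed by the signature alone.

```python
import numpy as np

x = np.random.randn(4, 100, 20, 2).astype(np.float32)
mean, std = x.mean(), x.std()

# Assumed semantics: standardize maps x to (x - mean) / std when
# only_scale=False, and unstandardize undoes it.
z = standardize(x, mean, std, only_scale=False)
x_rec = unstandardize(z, mean, std, only_scale=False)
assert np.allclose(x, x_rec, atol=1e-5)
```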

source

### create_grid

    create_grid (height:int, width:int, min_val:float=-1.0,
                 max_val:float=1.0)

Create a grid of size (height, width) with values between min_val and max_val inclusive.
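
For intuition, here is a plausible construction consistent with that docstring, built from evenly spaced coordinates; the actual output layout of `create_grid` (e.g. channel ordering) is an assumption.

```python
import jax.numpy as jnp

# A sketch of an evenly spaced coordinate grid with values in
# [min_val, max_val]; create_grid's real return layout may differ.
def create_grid_sketch(height, width, min_val=-1.0, max_val=1.0):
    ys = jnp.linspace(min_val, max_val, height)
    xs = jnp.linspace(min_val, max_val, width)
    yy, xx = jnp.meshgrid(ys, xs, indexing="ij")
    return jnp.stack([yy, xx], axis=-1)  # (height, width, 2)

grid = create_grid_sketch(4, 4)
assert grid.shape == (4, 4, 2)
assert grid.min() == -1.0 and grid.max() == 1.0
```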


source

### hankel_matrix

    hankel_matrix (x, depth:int=2)

|  | Type | Default | Details |
|---|---|---|---|
| x |  |  | input array with shape (B, grid_size, time_steps, channels) |
| depth | int | 2 | stacking depth: axis 1 shrinks by depth-1 while axis 2 grows by a factor of depth (see the example below) |
| **Returns** | **ndarray** |  |  |
```python
import numpy as np

# Build a small (3, 5, 4, 2) test array.
a = np.arange(20).reshape(5, 4)
a = np.stack([a, a, a], axis=0)
a = np.stack([a, a * 2], axis=-1)
d = 2

b = hankel_matrix(a, depth=d)
# Hankelization with depth d trims axis 1 by (d - 1) and stacks d shifted
# copies along axis 2.
assert b.shape == (a.shape[0], a.shape[1] - (d - 1), a.shape[2] * d, a.shape[3])
```

source

### select_slices

    select_slices (key:PRNGKey, idx:jax.Array, data:jax.Array,
                   indices:jax.Array, num_input:int, num_target:int,
                   mode:str, batch_size:int)

|  | Type | Details |
|---|---|---|
| key | PRNGKey |  |
| idx | Array | batch indices |
| data | Array | input array with shape (B, T, …) |
| indices | Array | indices of the batches |
| num_input | int | number of time steps to slice |
| num_target | int | number of time steps to predict |
| mode | str | blocks, sequential, passthrough (the examples below also use the many_random_* / single_random_* modes) |
| batch_size | int | number of slices to take in multi_block mode |
| **Returns** | **Tuple** |  |

source

### split_xy

    split_xy (data, num_input, num_target)

Split the data into input and target.
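
A minimal usage sketch, assuming `split_xy` slices along the time axis, taking the first `num_input` steps as input and the following `num_target` steps as target; this split convention is inferred from the surrounding helpers, not stated explicitly.

```python
import jax.numpy as jnp

# Hypothetical shapes; assumes the time axis is axis 1, as in
# slice_tensor_single below.
data = jnp.ones((4, 30, 40, 2))  # (B, T, W, C)
x, y = split_xy(data, num_input=10, num_target=20)
assert x.shape == (4, 10, 40, 2)
assert y.shape == (4, 20, 40, 2)
```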


source

### slice_tensor_single

    slice_tensor_single (key:PRNGKey, x:jax.Array, num_input,
                         num_target, split_mode:str='no_overlap')

|  | Type | Default | Details |
|---|---|---|---|
| key | PRNGKey |  |  |
| x | Array |  | input array with shape (B, T, …) |
| num_input |  |  | number of time steps to slice |
| num_target |  |  | number of time steps to predict |
| split_mode | str | no_overlap | no_split, overlap, no_overlap |
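
A usage sketch; the assumption that one random (input, target) slice is returned per batch element is inferred from the `single_random_no_overlap` example further below, not from documented behavior.

```python
import jax
import jax.numpy as jnp

# Hypothetical shapes, mirroring the single_random_no_overlap example.
x = jnp.ones((4, 100, 40, 2))  # (B, T, W, C)
inp, tgt = slice_tensor_single(jax.random.PRNGKey(0), x,
                               num_input=10, num_target=20)
assert inp.shape == (4, 10, 40, 2)
assert tgt.shape == (4, 20, 40, 2)
```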

source

### slice_tensor_multi

    slice_tensor_multi (key:PRNGKey, x:jax.Array, num_slices:int,
                        num_input:int, num_target:int,
                        split_mode:str='no_overlap')

Take many random slices of the tensor along the time axis. Warning: here the time axis is the first axis of the tensor.

|  | Type | Default | Details |
|---|---|---|---|
| key | PRNGKey |  |  |
| x | Array |  | input array with shape (T, …) |
| num_slices | int |  | number of slices to take |
| num_input | int |  | number of time steps as input |
| num_target | int |  | number of time steps as target |
| split_mode | str | no_overlap | no_split, overlap, no_overlap |
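
A usage sketch; note the time-first layout. The `(num_slices, …)` shape of the returned pair is an assumption based on the parameter table, not documented output.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1000, 40, 2))  # (T, W, C): time is the FIRST axis here
# Assumed return convention: an (inputs, targets) pair with the slice
# dimension leading.
xs, ys = slice_tensor_multi(jax.random.PRNGKey(0), x, num_slices=8,
                            num_input=10, num_target=20)
assert xs.shape == (8, 10, 40, 2)
assert ys.shape == (8, 20, 40, 2)
```

The examples below exercise `select_slices` in its different modes: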
```python
import jax
import jax.numpy as jnp

B, T, W, C = 16, 4000, 40, 2
x = jnp.ones((B, T, W, C))

num_input = 1
num_target = 199

# many_random_no_overlap: input and target are disjoint windows.
x, y = select_slices(
    key=jax.random.PRNGKey(65),
    idx=0,
    data=x,
    indices=jnp.arange(B),
    num_input=num_input,
    num_target=num_target,
    mode="many_random_no_overlap",
    batch_size=B,
)
assert x.shape == (B, num_input, W, C)
assert y.shape == (B, num_target, W, C)
```
```python
x = jnp.ones((B, T, W, C))

# many_random_overlap: here both the input and target slices come back
# with 199 time steps (the windows overlap).
x, y = select_slices(
    key=jax.random.PRNGKey(65),
    idx=0,
    data=x,
    indices=jnp.arange(B),
    num_input=1,
    num_target=199,
    mode="many_random_overlap",
    batch_size=B,
)
assert x.shape == (B, 199, W, C)
assert y.shape == (B, 199, W, C)
```
```python
x = jnp.ones((B, T, W, C))

# single_random_no_overlap: one random slice per batch element; idx
# carries the batch indices directly.
x, y = select_slices(
    key=jax.random.PRNGKey(65),
    idx=jnp.arange(B),
    data=x,
    indices=jnp.arange(B),
    num_input=num_input,
    num_target=num_target,
    mode="single_random_no_overlap",
    batch_size=None,
)
assert x.shape == (B, num_input, W, C)
assert y.shape == (B, num_target, W, C)
```

source

### JaxDataloader

    JaxDataloader (data:numpy.ndarray, num_input:int=-1, num_target:int=-1,
                   batch_size:int=16, shuffle:bool=True, drop_last:bool=True,
                   key:Optional[PRNGKey]=Array([0, 0], dtype=uint32),
                   mode:str='passthrough')

|  | Type | Default | Details |
|---|---|---|---|
| data | ndarray |  |  |
| num_input | int | -1 | length of the input segment |
| num_target | int | -1 | length of the target segment |
| batch_size | int | 16 | batch size |
| shuffle | bool | True | shuffle the data |
| drop_last | bool | True | drop the last batch if it's smaller than batch_size |
| key | Optional | PRNGKey(0) | random key |
| mode | str | passthrough | mode of the dataloader |
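
A minimal usage sketch in passthrough mode; the assumption that each batch stacks whole trajectories along the first axis mirrors the `DirectoryDataModule` passthrough examples below.

```python
import numpy as np
import jax

data = np.random.randn(32, 120, 20, 2)  # (samples, time, grid, channels)
loader = JaxDataloader(data, batch_size=16, shuffle=True,
                       key=jax.random.PRNGKey(0), mode="passthrough")

# Assumed: passthrough batches contain full trajectories.
batch = next(iter(loader))
assert batch.shape == (16, 120, 20, 2)
```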

source

### make_rolling_windows

    make_rolling_windows (x:numpy.ndarray, window_size:int)

Generate rolling windows of size `window_size` over the time axis of `x`. Warning: this might generate a large amount of data if the input is large.

|  | Type | Details |
|---|---|---|
| x | ndarray | (B, T, …) |
| window_size | int |  |
```python
B, T, H, W, C = 5, 64, 40, 40, 2
win = 16
dummy = np.ones((B, T, H, W, C))
dummy = make_rolling_windows(dummy, win)
# Each of the B series yields T - win + 1 windows of length win.
assert dummy.shape == (B * (T - win + 1), win, H, W, C)
```

source

### split_data

    split_data (mapped_data, split:List[float], extract_channels:List[int])
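
Judging by its arguments, `split_data` presumably partitions `mapped_data` into train/val/test portions according to the `split` fractions and keeps only the channels listed in `extract_channels`; `DirectoryDataModule` below accepts the same `split` and `extract_channels` parameters.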

source

### DirectoryDataModule

    DirectoryDataModule (data_array:Union[str,List[str]],
                         split:List[float]=[0.8, 0.2, 0.0], batch_size:int=1,
                         extract_channels:List[int]=[],
                         num_steps_train:Union[int,List[int],NoneType]=None,
                         num_steps_val:Union[int,List[int],NoneType]=None,
                         standardize_dataset:bool=False,
                         mean:Optional[float]=None, std:Optional[float]=None,
                         mode:str='passthrough', windowed:bool=False,
                         hankelize:int=0, shuffle_train:bool=True,
                         cache:bool=False, rolling_windows:bool=False,
                         total_num_train:Optional[int]=None,
                         total_num_val:Optional[int]=None,
                         total_num_test:Optional[int]=None)

|  | Type | Default | Details |
|---|---|---|---|
| data_array | Union |  |  |
| split | List | [0.8, 0.2, 0.0] | train, val, test |
| batch_size | int | 1 | batch size per device (time steps in parallel mode) |
| extract_channels | List | [] | extract only these channels (u, v) |
| num_steps_train | Union | None | number of time steps to use for training |
| num_steps_val | Union | None | number of time steps to use for validation |
| standardize_dataset | bool | False | standardize the dataset per channel |
| mean | Optional | None | mean for standardization |
| std | Optional | None | std for standardization |
| mode | str | passthrough | mode of the dataloader |
| windowed | bool | False | slice the data in a windowed way |
| hankelize | int | 0 | hankelize the data with depth (0 = no hankelization) |
| shuffle_train | bool | True | shuffle the training data |
| cache | bool | False | cache the data |
| rolling_windows | bool | False | generate rolling windows |
| total_num_train | Optional | None | total number of training samples |
| total_num_val | Optional | None | total number of validation samples |
| total_num_test | Optional | None | total number of test samples |

source

### is_list_like

    is_list_like (obj)
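
A quick hedged example; the exact truthiness rules (e.g. for tuples or strings) are an assumption, not documented here.

```python
# Presumably, containers count as list-like and scalars do not.
assert is_list_like([1, 2, 3])
assert not is_list_like(42)
```

The remaining examples all work on the same synthetic dataset and walk through the `DirectoryDataModule` modes.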
```python
# Generate a synthetic dataset and save it to disk.
num_samples = 100
num_timesteps = 120
grid_size = 20
channels = 2

data = np.random.randn(num_samples, num_timesteps, grid_size, channels)
np.save("/tmp/data.npy", data)
```

With `mode="many_random_no_overlap"` the train and val loaders yield (input, target) pairs whose lengths are given by `num_steps_train` and `num_steps_val`, while the test loader returns full trajectories:

```python
batch_size = 2

datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    num_steps_train=[1, 19],
    num_steps_val=[3, 17],
    cache=True,
    standardize_dataset=True,
    mode="many_random_no_overlap",
    windowed=False,
    hankelize=0,
)

train_x, train_y = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, 1, grid_size, channels)
assert train_y.shape == (batch_size, 19, grid_size, channels)

val_x, val_y = next(iter(datamodule.val_dataloader))
assert val_x.shape == (batch_size, 3, grid_size, channels)
assert val_y.shape == (batch_size, 17, grid_size, channels)

test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
```
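
In `passthrough` mode every loader yields full trajectories, so there is no (input, target) split: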
```python
batch_size = 2
datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    cache=True,
    standardize_dataset=True,
    mode="passthrough",
    windowed=False,
)

train_x = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, num_timesteps, grid_size, channels)

val_x = next(iter(datamodule.val_dataloader))
assert val_x.shape == (batch_size, num_timesteps, grid_size, channels)

test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
```
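
Setting `windowed=True` pre-slices each trajectory into windows of length `sum(num_steps)`, so passthrough batches have that window length (the test loader still returns full trajectories):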
```python
batch_size = 2
num_steps_train = [1, 59]
num_steps_val = [1, 59]
datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    num_steps_train=num_steps_train,
    num_steps_val=num_steps_val,
    cache=True,
    standardize_dataset=True,
    mode="passthrough",
    windowed=True,
)

train_x = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, sum(num_steps_train), grid_size, channels)

val_x = next(iter(datamodule.val_dataloader))
assert val_x.shape == (batch_size, sum(num_steps_val), grid_size, channels)

test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
```
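
In `many_random_no_split` mode the slices are not split into input and target; each batch element is a single window of `num_input + num_target` steps (here 1 + 49 = 50):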
```python
batch_size = 2
num_steps_train = 50
num_steps_val = 50
datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    num_steps_train=[1, 49],
    num_steps_val=[1, 49],
    cache=True,
    standardize_dataset=True,
    mode="many_random_no_split",
    windowed=False,
)

train_x = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, num_steps_train, grid_size, channels)

val_x = next(iter(datamodule.val_dataloader))
assert val_x.shape == (batch_size, num_steps_val, grid_size, channels)

test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
```
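
`single_random_no_split` behaves the same at the batch level (one random slice per trajectory, per the mode name); since the validation split holds only 10 trajectories, the val batch contains 10 samples even though `batch_size` is 16: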
```python
batch_size = 16
num_steps_train = 50
num_steps_val = 50
datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    num_steps_train=[1, 49],
    num_steps_val=[1, 49],
    cache=True,
    standardize_dataset=True,
    mode="single_random_no_split",
    windowed=False,
)

train_x = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, num_steps_train, grid_size, channels)

val_x = next(iter(datamodule.val_dataloader))
assert val_x.shape == (10, num_steps_val, grid_size, channels)

test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
```
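
In `split` mode (here combined with `windowed=True`) the loaders yield (input, target) pairs of `num_steps[0]` and `num_steps[1]` steps respectively; again the smaller validation split caps the val batch at 10: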
```python
batch_size = 16
num_steps_train = [1, 59]
num_steps_val = [1, 59]
datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    num_steps_train=num_steps_train,
    num_steps_val=num_steps_val,
    cache=True,
    standardize_dataset=True,
    mode="split",
    windowed=True,
)

train_x, train_y = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, num_steps_train[0], grid_size, channels)

val_x, val_y = next(iter(datamodule.val_dataloader))
assert val_x.shape == (10, num_steps_val[0], grid_size, channels)

test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
```
Finally, timing a full pass over the training dataloader:

```python
def iterate_over_dataloader():
    for x in datamodule.train_dataloader:
        pass
```

    2.29 ms ± 156 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)