source

### reshape_array

    reshape_array (array)
source

### load_as_big_array

    load_as_big_array (files)
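No docstring is shown. A plausible reading, given the name and the file-based `DirectoryDataModule` below, is that a list of array files is loaded and concatenated. A minimal sketch under that assumption (`load_as_big_array_sketch` is hypothetical, not the library code):

```python
import numpy as np

# Hypothetical sketch only: assumes `files` is a list of .npy paths and that
# the arrays are concatenated along the first (sample) axis.
def load_as_big_array_sketch(files):
    return np.concatenate([np.load(f) for f in files], axis=0)
```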
source

### unstandardize

    unstandardize (x, mean, std, only_scale=True)

source

### standardize

    standardize (x, mean, std, only_scale=True)
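The two functions are inverses of each other. A minimal sketch of the expected contract, assuming `only_scale=True` skips the mean shift and only divides (or multiplies) by `std`; the `_sketch` helpers are assumptions, not the library's implementation:

```python
import numpy as np

def standardize_sketch(x, mean, std, only_scale=True):
    # assumption: only_scale=True divides by std without subtracting the mean
    return x / std if only_scale else (x - mean) / std

def unstandardize_sketch(x, mean, std, only_scale=True):
    # inverse of the transform above
    return x * std if only_scale else x * std + mean

x = np.random.randn(16, 2)
mean, std = x.mean(axis=0), x.std(axis=0)
z = standardize_sketch(x, mean, std, only_scale=False)
assert np.allclose(unstandardize_sketch(z, mean, std, only_scale=False), x)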
source

### create_grid

    create_grid (height:int, width:int, min_val:float=-1.0, max_val:float=1.0)

Create a grid of size (height, width) with values between min_val and max_val inclusive.
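For illustration, a grid with these properties can be built with `np.meshgrid`; whether `create_grid` returns a single array or a pair of coordinate arrays is not shown here, so treat this only as a sketch of the value range:

```python
import numpy as np

height, width = 32, 40
ys = np.linspace(-1.0, 1.0, height)   # min_val .. max_val, endpoints inclusive
xs = np.linspace(-1.0, 1.0, width)
xx, yy = np.meshgrid(xs, ys)          # each of shape (height, width)
assert xx.min() == -1.0 and xx.max() == 1.0
```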
source

### hankel_matrix

    hankel_matrix (x, depth:int=2)

|  | **Type** | **Default** | **Details** |
| -- | -- | -- | -- |
| x |  |  | input array with shape (B, time_steps, grid_size, channels) |
| depth | int | 2 | number of time-shifted copies stacked along the feature axis (depth n shrinks the time axis by n − 1) |
| **Returns** | **ndarray** |  |  |
```python
import numpy as np

a = np.arange(20).reshape(5, 4)
a = np.stack([a, a, a], axis=0)    # (3, 5, 4)
a = np.stack([a, a * 2], axis=-1)  # (3, 5, 4, 2)
d = 2
b = hankel_matrix(a, depth=d)
# the time axis shrinks by d - 1, the feature axis grows by a factor of d
assert b.shape == (a.shape[0], a.shape[1] - (d - 1), a.shape[2] * d, a.shape[3])
```
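The shape assertion above pins down the behaviour: `depth` time-shifted copies are stacked along the third axis while the time axis shrinks by `depth - 1`. A NumPy sketch consistent with that assertion (not the library's actual implementation):

```python
import numpy as np

def hankel_matrix_sketch(x, depth=2):
    # stack `depth` shifted copies along axis 2; window t then holds
    # x[:, t], x[:, t + 1], ..., x[:, t + depth - 1]
    n = x.shape[1] - (depth - 1)
    return np.concatenate([x[:, i : i + n] for i in range(depth)], axis=2)
```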
source

### select_slices

    select_slices (key:PRNGKey, idx:jax.Array, data:jax.Array,
                   indices:jax.Array, num_input:int, num_target:int,
                   mode:str, batch_size:int)

|  | **Type** | **Details** |
| -- | -- | -- |
| key | PRNGKey |  |
| idx | Array | batch indices |
| data | Array | input array with shape (B, T, …) |
| indices | Array | indices of the batches |
| num_input | int | number of time steps to slice |
| num_target | int | number of time steps to predict |
| mode | str | blocks, sequential, passthrough |
| batch_size | int | number of slices to take in multi_block mode |
| **Returns** | **Tuple** |  |
source

### split_xy

    split_xy (data, num_input, num_target)

Split the data into input and target.
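Under the `(B, T, …)` convention used throughout this page, the split is presumably a simple partition along the time axis. A hedged sketch (the helper name is hypothetical):

```python
def split_xy_sketch(data, num_input, num_target):
    # assumption: time is axis 1; x gets the first num_input steps,
    # y the following num_target steps
    return data[:, :num_input], data[:, num_input : num_input + num_target]
```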
source

### slice_tensor_single

    slice_tensor_single (key:PRNGKey, x:jax.Array, num_input, num_target,
                         split_mode:str='no_overlap')

|  | **Type** | **Default** | **Details** |
| -- | -- | -- | -- |
| key | PRNGKey |  |  |
| x | Array |  | input array with shape (B, T, …) |
| num_input |  |  | number of time steps to slice |
| num_target |  |  | number of time steps to predict |
| split_mode | str | no_overlap | no_split, overlap, no_overlap |
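A sketch of what one random slice per batch element could look like for `split_mode='no_overlap'` (a window of `num_input + num_target` steps, split so input and target are disjoint); the helper below is an assumption about the semantics, not the library code:

```python
import jax
import jax.numpy as jnp

def slice_tensor_single_sketch(key, x, num_input, num_target):
    B, T = x.shape[0], x.shape[1]
    total = num_input + num_target
    starts = jax.random.randint(key, (B,), 0, T - total + 1)
    windows = jnp.stack(
        [jax.lax.dynamic_slice_in_dim(x[b], starts[b], total, axis=0)
         for b in range(B)]
    )
    return windows[:, :num_input], windows[:, num_input:]
```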
source

### slice_tensor_multi

    slice_tensor_multi (key:PRNGKey, x:jax.Array, num_slices:int,
                        num_input:int, num_target:int,
                        split_mode:str='no_overlap')

Take many random slices of the tensor along the time axis. Warning: here the time axis is the first axis of the tensor.

|  | **Type** | **Default** | **Details** |
| -- | -- | -- | -- |
| key | PRNGKey |  |  |
| x | Array |  | input array with shape (T, …) |
| num_slices | int |  | number of slices to take |
| num_input | int |  | number of time steps as input |
| num_target | int |  | number of time steps as target |
| split_mode | str | no_overlap | no_split, overlap, no_overlap |
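Since the time axis is the first axis here, a consistent sketch draws `num_slices` random start positions from a single trajectory (again an assumption about the exact semantics, not the library implementation):

```python
import jax
import jax.numpy as jnp

def slice_tensor_multi_sketch(key, x, num_slices, num_input, num_target):
    total = num_input + num_target
    starts = jax.random.randint(key, (num_slices,), 0, x.shape[0] - total + 1)
    windows = jnp.stack(
        [jax.lax.dynamic_slice_in_dim(x, s, total, axis=0) for s in starts]
    )
    return windows[:, :num_input], windows[:, num_input:]
```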
The following examples exercise three `select_slices` modes:

```python
import jax
import jax.numpy as jnp

B, T, W, C = 16, 4000, 40, 2
num_input = 1
num_target = 199

# mode="many_random_no_overlap": input and target are disjoint slices
x = jnp.ones((B, T, W, C))
x, y = select_slices(
    key=jax.random.PRNGKey(65),
    idx=0,
    data=x,
    indices=jnp.arange(B),
    num_input=num_input,
    num_target=num_target,
    mode="many_random_no_overlap",
    batch_size=B,
)
assert x.shape == (B, num_input, W, C)
assert y.shape == (B, num_target, W, C)

# mode="many_random_overlap": input and target windows overlap,
# so both carry num_target time steps here
x = jnp.ones((B, T, W, C))
x, y = select_slices(
    key=jax.random.PRNGKey(65),
    idx=0,
    data=x,
    indices=jnp.arange(B),
    num_input=1,
    num_target=199,
    mode="many_random_overlap",
    batch_size=B,
)
assert x.shape == (B, 199, W, C)
assert y.shape == (B, 199, W, C)

# mode="single_random_no_overlap": one slice per batch index; batch_size unused
x = jnp.ones((B, T, W, C))
x, y = select_slices(
    key=jax.random.PRNGKey(65),
    idx=jnp.arange(B),
    data=x,
    indices=jnp.arange(B),
    num_input=num_input,
    num_target=num_target,
    mode="single_random_no_overlap",
    batch_size=None,
)
assert x.shape == (B, num_input, W, C)
assert y.shape == (B, num_target, W, C)
```
source

### JaxDataloader

    JaxDataloader (data:numpy.ndarray, num_input:int=-1, num_target:int=-1,
                   batch_size:int=16, shuffle:bool=True, drop_last:bool=True,
                   key:Optional[PRNGKey]=Array([0, 0], dtype=uint32),
                   mode:str='passthrough')
|  | **Type** | **Default** | **Details** |
| -- | -- | -- | -- |
| data | ndarray |  |  |
| num_input | int | -1 | length of the input segment |
| num_target | int | -1 | length of the target segment |
| batch_size | int | 16 | batch size |
| shuffle | bool | True | shuffle the data |
| drop_last | bool | True | drop the last batch if it's smaller than batch_size |
| key | Optional | [0 0] | random key |
| mode | str | passthrough | mode of the dataloader |
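A hedged usage sketch in the default `passthrough` mode, assuming the loader is iterable and yields whole samples batched along the first axis (consistent with how the `DirectoryDataModule` dataloaders are iterated below):

```python
import numpy as np
import jax

data = np.random.randn(32, 100, 20, 2)   # (B, T, grid, channels)
loader = JaxDataloader(data, batch_size=8, shuffle=True,
                       key=jax.random.PRNGKey(0), mode="passthrough")
batch = next(iter(loader))
# expected: one batch of 8 full trajectories, shape (8, 100, 20, 2)
```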
source

### make_rolling_windows

    make_rolling_windows (x:numpy.ndarray, window_size:int)

Generate rolling windows of size window_size over the time axis of x. Warning: this might generate a large amount of data if the input is large.

|  | **Type** | **Details** |
| -- | -- | -- |
| x | ndarray | (B, T, …) |
| window_size | int |  |
```python
import numpy as np

B, T, H, W, C = 5, 64, 40, 40, 2
win = 16
dummy = np.ones((B, T, H, W, C))
dummy = make_rolling_windows(dummy, win)
# stride-1 windows, flattened across the batch axis
assert dummy.shape == (B * (T - win + 1), win, H, W, C)
```
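The asserted shape fixes the semantics: stride-1 windows along the time axis, with windows from all batch elements flattened into the first axis. An equivalent NumPy sketch (not necessarily the library's implementation):

```python
import numpy as np

def make_rolling_windows_sketch(x, window_size):
    B, T = x.shape[:2]
    n = T - window_size + 1
    windows = np.stack([x[:, t : t + window_size] for t in range(n)], axis=1)
    return windows.reshape(B * n, window_size, *x.shape[2:])
```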
source

### split_data

    split_data (mapped_data, split:List[float], extract_channels:List[int])
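No docstring is shown; judging from the `DirectoryDataModule` parameters of the same names, a plausible sketch splits along the sample axis into train/val/test fractions and keeps only the requested channels. This helper is hypothetical:

```python
def split_data_sketch(data, split, extract_channels):
    # keep only the requested channels (assumed to live on the last axis)
    if extract_channels:
        data = data[..., extract_channels]
    n_train = int(split[0] * len(data))
    n_val = int(split[1] * len(data))
    return (data[:n_train],
            data[n_train : n_train + n_val],
            data[n_train + n_val :])
```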
source

### DirectoryDataModule

    DirectoryDataModule (data_array:Union[str,List[str]],
                         split:List[float]=[0.8, 0.2, 0.0], batch_size:int=1,
                         extract_channels:List[int]=[],
                         num_steps_train:Union[int,List[int],NoneType]=None,
                         num_steps_val:Union[int,List[int],NoneType]=None,
                         standardize_dataset:bool=False,
                         mean:Optional[float]=None, std:Optional[float]=None,
                         mode:str='passthrough', windowed:bool=False,
                         hankelize:int=0, shuffle_train:bool=True,
                         cache:bool=False, rolling_windows:bool=False,
                         total_num_train:Optional[int]=None,
                         total_num_val:Optional[int]=None,
                         total_num_test:Optional[int]=None)
|  | **Type** | **Default** | **Details** |
| -- | -- | -- | -- |
| data_array | Union |  |  |
| split | List | [0.8, 0.2, 0.0] | train, val, test |
| batch_size | int | 1 | batch size per device (time steps in parallel mode) |
| extract_channels | List | [] | extract only these channels (u, v) |
| num_steps_train | Union | None | number of time steps to use for training |
| num_steps_val | Union | None | number of time steps to use for validation |
| standardize_dataset | bool | False | standardize the dataset per channel |
| mean | Optional | None | mean for standardization |
| std | Optional | None | std for standardization |
| mode | str | passthrough | mode of the dataloader |
| windowed | bool | False | slice the data in a windowed way |
| hankelize | int | 0 | hankelize the data with depth (0 = no hankelization) |
| shuffle_train | bool | True | shuffle the training data |
| cache | bool | False | cache the data |
| rolling_windows | bool | False | generate rolling windows |
| total_num_train | Optional | None | total number of training samples |
| total_num_val | Optional | None | total number of validation samples |
| total_num_test | Optional | None | total number of test samples |
source

### is_list_like

    is_list_like (obj)
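A helper like this is commonly just an isinstance check; the version below is an assumption, not the library's code:

```python
def is_list_like_sketch(obj):
    # treat lists and tuples as "list-like"
    return isinstance(obj, (list, tuple))
```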
```python
import numpy as np

# generate and save a dummy dataset
num_samples = 100
num_timesteps = 120
grid_size = 20
channels = 2
data = np.random.randn(num_samples, num_timesteps, grid_size, channels)
np.save("/tmp/data.npy", data)

# many_random_no_overlap: each batch is split into input and target segments
batch_size = 2
datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    num_steps_train=[1, 19],  # 1 input step, 19 target steps
    num_steps_val=[3, 17],    # 3 input steps, 17 target steps
    cache=True,
    standardize_dataset=True,
    mode="many_random_no_overlap",
    windowed=False,
    hankelize=0,
)
train_x, train_y = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, 1, grid_size, channels)
assert train_y.shape == (batch_size, 19, grid_size, channels)
val_x, val_y = next(iter(datamodule.val_dataloader))
assert val_x.shape == (batch_size, 3, grid_size, channels)
assert val_y.shape == (batch_size, 17, grid_size, channels)
test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
```
```python
# passthrough: batches are full trajectories, with no input/target split
batch_size = 2
datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    cache=True,
    standardize_dataset=True,
    mode="passthrough",
    windowed=False,
)
train_x = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, num_timesteps, grid_size, channels)
val_x = next(iter(datamodule.val_dataloader))
assert val_x.shape == (batch_size, num_timesteps, grid_size, channels)
test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
```
```python
# windowed passthrough: trajectories are cut into windows of
# sum(num_steps_train) = 60 steps before batching
batch_size = 2
num_steps_train = [1, 59]
num_steps_val = [1, 59]
datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    num_steps_train=num_steps_train,
    num_steps_val=num_steps_val,
    cache=True,
    standardize_dataset=True,
    mode="passthrough",
    windowed=True,
)
train_x = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, sum(num_steps_train), grid_size, channels)
val_x = next(iter(datamodule.val_dataloader))
assert val_x.shape == (batch_size, sum(num_steps_val), grid_size, channels)
test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
```
```python
# many_random_no_split: random windows of num_input + num_target = 50 steps,
# returned unsplit
batch_size = 2
num_steps_train = 50
num_steps_val = 50
datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    num_steps_train=[1, 49],
    num_steps_val=[1, 49],
    cache=True,
    standardize_dataset=True,
    mode="many_random_no_split",
    windowed=False,
)
train_x = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, num_steps_train, grid_size, channels)
val_x = next(iter(datamodule.val_dataloader))
assert val_x.shape == (batch_size, num_steps_val, grid_size, channels)
test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
```
```python
# single_random_no_split: one random window per trajectory; the val split
# holds only 100 * 0.1 = 10 trajectories, so its batch is smaller than batch_size
batch_size = 16
num_steps_train = 50
num_steps_val = 50
datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    num_steps_train=[1, 49],
    num_steps_val=[1, 49],
    cache=True,
    standardize_dataset=True,
    mode="single_random_no_split",
    windowed=False,
)
train_x = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, num_steps_train, grid_size, channels)
val_x = next(iter(datamodule.val_dataloader))
assert val_x.shape == (10, num_steps_val, grid_size, channels)
test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
```
```python
# split + windowed: each windowed batch is split into (input, target) pairs
batch_size = 16
num_steps_train = [1, 59]
num_steps_val = [1, 59]
datamodule = DirectoryDataModule(
    data_array="/tmp/data.npy",
    split=[0.8, 0.1, 0.1],
    batch_size=batch_size,
    extract_channels=[0, 1],
    num_steps_train=num_steps_train,
    num_steps_val=num_steps_val,
    cache=True,
    standardize_dataset=True,
    mode="split",
    windowed=True,
)
train_x, train_y = next(iter(datamodule.train_dataloader))
assert train_x.shape == (batch_size, num_steps_train[0], grid_size, channels)
val_x, val_y = next(iter(datamodule.val_dataloader))
assert val_x.shape == (10, num_steps_val[0], grid_size, channels)
test_data = next(iter(datamodule.test_dataloader))
assert test_data.shape == (1, num_timesteps, grid_size, channels)
```
Iterating over the training dataloader is cheap; the `%timeit` measurement follows the cell:

```python
def iterate_over_dataloader():
    for x in datamodule.train_dataloader:
        pass
```

    2.29 ms ± 156 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)