Usage
Installation
I highly recommend cloning and running the repo on a Linux machine, because installing jax and its related libraries is easier there.
With some extra effort you can also get it running on other compatible platforms. To start, clone the repo:
git clone https://github.com/rsanchezmo/jaxer.git
cd jaxer
Create a Python virtual environment and activate it:
python -m venv venv
source venv/bin/activate
You can install whichever jax version you need (see the official jax installation documentation for details).
Note that jax currently provides only two CUDA distributions (12.3 and 11.8).
This repo also depends on torch, which is only used for the dataloaders, so you could install torch for CPU and jax for GPU.
However, I chose to install the CUDA 11.8 builds of both libraries, as they are compatible with each other.
For example, if CUDA 11.8 is already installed on Linux (make sure your CUDA version is exported to PATH):
(venv) $ pip install "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
(venv) $ pip install torch --index-url https://download.pytorch.org/whl/cu118
If CUDA is not installed, it is better to use the pip installation and let pip install the right CUDA version for your GPU.
Before doing so, make sure that running nvcc --version reports that the CUDA toolkit still needs to be installed. If your .bashrc
already exports a CUDA release, remove that export first:
(venv) $ pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
(venv) $ pip install torch --index-url https://download.pytorch.org/whl/cu118
Then install the rest of the dependencies (which are in the requirements.txt file):
(venv) $ pip install .
You could skip the explicit jax installation, since the flax library installs jax as a dependency, but it is preferable to select
the jax version that matches your hardware.
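After the installation steps above, it can help to confirm what the active environment actually contains. The snippet below is a minimal sketch using only the standard library plus jax.devices(); the helper names find_installed and describe_jax_backend are hypothetical and not part of jaxer.

```python
import importlib.util
import sys


def in_virtualenv():
    """Return True when the interpreter runs inside a venv."""
    return sys.prefix != sys.base_prefix


def find_installed(packages):
    """Map each top-level package name to True if it can be imported."""
    return {name: importlib.util.find_spec(name) is not None for name in packages}


def describe_jax_backend():
    """Return the devices jax can see as strings, or None if jax is missing."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax  # imported lazily so the script also runs without jax
    return [str(device) for device in jax.devices()]


print("inside venv:", in_virtualenv())
print(find_installed(["jax", "flax", "torch"]))
```

Calling describe_jax_backend() afterwards lists the devices jax detected; if only CPU devices appear on a GPU machine, the CUDA build of jax was probably not the one installed.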
Running the code
Dataset
First unzip the dataset inside the datasets folder. This produces one subfolder per time resolution (e.g. 1h):
cd ./data/datasets/
unzip data.zip
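If the unzip command is not available, the same step can be done from Python. This is a minimal sketch with the standard library; unzip_dataset and list_subfolders are hypothetical helper names, and the exact layout inside data.zip is not described here, so the listing simply shows whatever directories the archive creates.

```python
import zipfile
from pathlib import Path


def unzip_dataset(zip_path, target_dir):
    """Extract the archive into target_dir (same effect as `unzip data.zip`)."""
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(target_dir)


def list_subfolders(folder):
    """Return the sorted names of the directories directly inside folder."""
    return sorted(p.name for p in Path(folder).iterdir() if p.is_dir())
```

For example, unzip_dataset("./data/datasets/data.zip", "./data/datasets/") followed by list_subfolders("./data/datasets/data/") should list the available resolutions, assuming the archive extracts into the data folder that the example configuration uses as datapath.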
Training
Training a model is as simple as creating a trainer with the experiment configuration and calling its train_and_evaluate method:
import jaxer
from training_config import config

if __name__ == '__main__':
    trainer = jaxer.run.FlaxTrainer(config=config)
    trainer.train_and_evaluate()
Configuration
An example of the experiment configuration is in the training_config.py file:
import jaxer
output_mode = 'mean'  # 'mean' or 'distribution' or 'discrete_grid'
seq_len = 100
d_model = 128
model_config = jaxer.config.ModelConfig(
    d_model=d_model,
    num_layers=4,
    head_layers=2,
    n_heads=2,
    dim_feedforward=4 * d_model,  # 4 * d_model
    dropout=0.05,
    max_seq_len=seq_len,
    flatten_encoder_output=False,
    fe_blocks=0,  # feature extractor is incremental, for instance input_shape, 128/2, 128 (d_model)
    use_time2vec=False,
    output_mode=output_mode,  # 'mean' or 'distribution' or 'discrete_grid'
    use_resblocks_in_head=False,
    use_resblocks_in_fe=True,
    average_encoder_output=False,
    norm_encoder_prev=True
)
dataset_config = jaxer.config.DatasetConfig(
    datapath='./data/datasets/data/',
    output_mode=output_mode,  # 'mean' or 'distribution' or 'discrete_grid'
    discrete_grid_levels=[-9e6, 0.0, 9e6],
    initial_date='2018-01-01',
    norm_mode="window_minmax",
    resolution='30m',
    tickers=['btc_usd', 'eth_usd', 'sol_usd'],
    indicators=None,
    seq_len=seq_len
)
synthetic_dataset_config = jaxer.config.SyntheticDatasetConfig(
    window_size=seq_len,
    output_mode=output_mode,  # 'mean' or 'distribution' or 'discrete_grid'
    normalizer_mode='window_minmax',  # 'window_meanstd' or 'window_minmax' or 'window_mean'
    add_noise=False,
    min_amplitude=0.1,
    max_amplitude=1.0,
    min_frequency=0.5,
    max_frequency=30,
    num_sinusoids=5,
)
pretrained_folder = "results/exp_synthetic_context"
pretrained_path_subfolder, pretrained_path_ckpt = jaxer.utils.get_best_model(pretrained_folder)
pretrained_model = (pretrained_folder, pretrained_path_subfolder, pretrained_path_ckpt)
config = jaxer.config.ExperimentConfig(
    model_config=model_config,
    pretrained_model=pretrained_model,
    log_dir="results",
    experiment_name="exp_both_pretrained_synthetic",
    num_epochs=1000,
    steps_per_epoch=500,  # for synthetic dataset only
    learning_rate=5e-4,
    lr_mode='cosine',  # 'cosine'
    warmup_epochs=15,
    dataset_mode='both',  # 'real' or 'synthetic' or 'both'
    dataset_config=dataset_config,
    synthetic_dataset_config=synthetic_dataset_config,
    batch_size=256,
    test_split=0.1,
    test_tickers=['btc_usd'],
    seed=0,
    save_weights=True,
    early_stopper=100
)
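The experiment configuration sets lr_mode='cosine' together with warmup_epochs=15 and learning_rate=5e-4. How jaxer builds that schedule internally is not shown in this page; the following is a minimal sketch, assuming a linear warmup to the peak learning rate followed by cosine decay to zero. The function warmup_cosine_lr is a hypothetical helper for illustration only.

```python
import math


def warmup_cosine_lr(epoch, peak_lr=5e-4, warmup_epochs=15, num_epochs=1000):
    """Learning rate at a given epoch: linear warmup, then cosine decay (assumed shape)."""
    if epoch < warmup_epochs:
        # Ramp up linearly so the peak is reached on the last warmup epoch.
        return peak_lr * (epoch + 1) / warmup_epochs
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (epoch - warmup_epochs) / max(1, num_epochs - warmup_epochs))
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))


for epoch in (0, 14, 15, 500, 1000):
    print(epoch, warmup_cosine_lr(epoch))
```

In a JAX training stack, optax.warmup_cosine_decay_schedule provides a comparable schedule expressed in optimizer steps rather than epochs.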
You can find a more detailed explanation of each parameter in the API, Dataset, Training and Model sections.
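To give an intuition for the norm_mode="window_minmax" option, here is a minimal sketch of per-window min-max scaling and its inverse. It assumes each window is scaled by its own minimum and range; the actual statistics jaxer stores in its normalizer are not shown on this page, and window_minmax and denormalize are hypothetical names.

```python
def window_minmax(window):
    """Scale one window to [0, 1] using its own min and max (assumed behaviour)."""
    low, high = min(window), max(window)
    scale = (high - low) or 1.0  # flat windows would otherwise divide by zero
    normalized = [(value - low) / scale for value in window]
    return normalized, (low, scale)


def denormalize(values, normalizer):
    """Map normalized values back to the original scale of their window."""
    low, scale = normalizer
    return [value * scale + low for value in values]


prices = [10.0, 15.0, 20.0]
normalized, stats = window_minmax(prices)
print(normalized)                      # [0.0, 0.5, 1.0]
print(denormalize(normalized, stats))  # [10.0, 15.0, 20.0]
```

Keeping the per-window statistics next to each sample is what makes it possible to plot predictions in the original price scale later on.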
Inference
An Agent class is provided so you can load a trained model and run it on a properly formatted input. The agent
performs inference through its __call__ method:
import jaxer
from torch.utils.data import DataLoader

if __name__ == '__main__':
    # load the agent with best model weights
    experiment = "exp_1"
    agent = jaxer.run.Agent(experiment=experiment, model_name=jaxer.utils.get_best_model(experiment))

    # create dataloaders
    if agent.config.dataset_mode == 'synthetic':
        dataset = jaxer.utils.SyntheticDataset(config=jaxer.config.SyntheticDatasetConfig.from_dict(agent.config.synthetic_dataset_config))
        test_dataloader = dataset.generator(batch_size=1, seed=100)
        train_dataloader = dataset.generator(batch_size=1, seed=200)
    else:
        dataset = jaxer.utils.Dataset(dataset_config=jaxer.config.DatasetConfig.from_dict(agent.config.dataset_config))
        train_ds, test_ds = dataset.get_train_test_split(test_size=agent.config.test_split,
                                                         test_tickers=agent.config.test_tickers)
        train_dataloader = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=jaxer.utils.jax_collate_fn)
        test_dataloader = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=jaxer.utils.jax_collate_fn)

    # choose which split to run inference on
    infer_test = False
    if infer_test:
        dataloader = test_dataloader
    else:
        dataloader = train_dataloader

    if agent.config.dataset_mode == 'synthetic':
        # the synthetic generator is endless, so draw a fixed number of windows
        for i in range(30):
            x, y_true, normalizer, window_info = next(dataloader)
            y_pred = agent(x)
            jaxer.utils.plot_predictions(x=x, y_true=y_true, y_pred=y_pred, normalizer=normalizer,
                                         window_info=window_info, denormalize_values=True)
    else:
        for batch in dataloader:
            x, y_true, normalizer, window_info = batch
            y_pred = agent(x)
            jaxer.utils.plot_predictions(x=x, y_true=y_true, y_pred=y_pred, normalizer=normalizer,
                                         window_info=window_info[0], denormalize_values=True)
In this example, a jaxer agent is created with the best weights of the experiment exp_1.
The plot_entire_dataset flag is used to plot over the entire dataset (train and test), which is useful for judging model performance (for example, to check for overfitting or poor generalization).
Finally, the agent predicts on individual windows from the selected split (the test set when infer_test is True) to produce a more detailed prediction plot.
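The agent(x) call works because Python lets any object define __call__. As a minimal sketch of that pattern, the hypothetical PersistenceAgent below stands in for a trained agent by predicting that the next value equals the last value of each window. It is not jaxer's Agent and takes plain lists rather than jaxer batches.

```python
class PersistenceAgent:
    """Hypothetical stand-in for a trained agent (naive last-value baseline)."""

    def __call__(self, windows):
        # One prediction per window: repeat the most recent observation.
        return [window[-1] for window in windows]


agent = PersistenceAgent()
batch = [[1.0, 2.0, 3.0], [5.0, 4.0, 3.5]]
y_pred = agent(batch)  # same call syntax as agent(x) in the inference script
print(y_pred)          # [3.0, 3.5]
```

A baseline like this can also serve as a sanity check: a trained model should at least match the error of simply repeating the last observed value.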