API#

This documentation provides an overview of the API for jaxer.

Modules Overview#

  • Run: contains the trainer and agent classes.

  • Utils: provides utility functions (e.g. dataset, plotting, etc.).

  • Config: contains configuration settings for jaxer.

  • Models: contains the flax transformer model and its internal blocks.

Run#

class jaxer.run.AgentBase(experiment: str, model_name: Tuple[Optional[str], str])#

Bases: object

Agent base class to load a model and perform inference (jax, torch, …)

Parameters:
  • experiment (str) – the name of the experiment

  • model_name (Tuple[Optional[str], str]) – the name of the model to load. If the model is in a subfolder, provide a tuple with the subfolder name and the model name

Raises:

FileNotFoundError – if the experiment or model does not exist

forward(x: Any) Any#

Inference function (you can use __call__ instead)

Parameters:

x (Any) – input data

Returns:

model output

Return type:

Any
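The note that `__call__` can be used instead of `forward` can be pictured with a minimal sketch. `SketchAgentBase` and `DoublingAgent` are hypothetical stand-ins, not jaxer classes; only the delegation pattern is illustrated, and the real agent may implement it differently.

```python
from typing import Any


class SketchAgentBase:
    """Hypothetical stand-in for an agent base class (not the jaxer implementation)."""

    def forward(self, x: Any) -> Any:
        raise NotImplementedError

    def __call__(self, x: Any) -> Any:
        # Calling the agent object simply delegates to forward
        return self.forward(x)


class DoublingAgent(SketchAgentBase):
    """Toy agent whose 'model' doubles every input value."""

    def forward(self, x: Any) -> Any:
        return [2 * value for value in x]


agent = DoublingAgent()
direct = agent.forward([1, 2, 3])  # explicit call
called = agent([1, 2, 3])          # same result through __call__
```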

class jaxer.run.FlaxAgent(experiment: str, model_name: Tuple[Optional[str], str])#

Bases: AgentBase

Agent class to load a model and perform inference

Parameters:
  • experiment (str) – the name of the experiment

  • model_name (Tuple[Optional[str], str]) – the name of the model to load. If the model is in a subfolder, provide a tuple with the subfolder name and the model name

class jaxer.run.FlaxTrainer(config: ExperimentConfig)#

Bases: TrainerBase

Trainer class for training jaxer using flax, optax and jax

Parameters:

config (ExperimentConfig) – training config for running an experiment

train_and_evaluate() None#

Runs the training loop with evaluation

class jaxer.run.TrainerBase(config: ExperimentConfig)#

Bases: object

Trainer base class. All trainers should inherit from this (jax, torch, …)

Parameters:

config (ExperimentConfig) – the configuration for the experiment

abstract train_and_evaluate() None#

Runs a training loop
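A trainer for another backend would subclass the base class and implement the abstract method. The sketch below uses hypothetical names (`SketchTrainerBase`, `CountingTrainer`) and a plain dict in place of ExperimentConfig; it only illustrates the abstract-method contract, not the jaxer training loop.

```python
from abc import ABC, abstractmethod


class SketchTrainerBase(ABC):
    """Hypothetical stand-in for a trainer base class."""

    def __init__(self, config: dict) -> None:
        self.config = config

    @abstractmethod
    def train_and_evaluate(self) -> None:
        """Every concrete trainer must provide its own loop."""


class CountingTrainer(SketchTrainerBase):
    """Toy trainer that only counts the epochs it 'runs'."""

    def __init__(self, config: dict) -> None:
        super().__init__(config)
        self.epochs_run = 0

    def train_and_evaluate(self) -> None:
        for _ in range(self.config["num_epochs"]):
            self.epochs_run += 1  # a real trainer would train and evaluate here


trainer = CountingTrainer({"num_epochs": 3})
trainer.train_and_evaluate()
```

Instantiating the base class directly raises a TypeError, because its abstract method has no implementation.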

Utils#

class jaxer.utils.Dataset(dataset_config: DatasetConfig)#

Bases: Dataset

Finance dataset class for training jaxer (pytorch dataset)

Parameters:

dataset_config (DatasetConfig) – DatasetConfig object

Raises:
  • ValueError – if the norm_mode is not one of ['window_minmax', 'window_meanstd', 'window_mean', 'global_minmax', 'global_meanstd', 'none']

  • ValueError – if the indicators are not among those available in the dataset ['rsi', 'bb_upper', 'bb_lower', 'bb_middle']

  • ValueError – if the discrete_grid_levels are not provided and the output_mode is discrete_grid

  • ValueError – if the ticker is not in the dataset
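The normalization mode names suggest statistics taken either per window or over the whole series. The sketch below shows what 'window_minmax' and 'window_meanstd' could look like; the exact formulas used by jaxer are not stated here, so the function names and formulas are assumptions.

```python
import numpy as np


def window_minmax(window: np.ndarray):
    """Scale one window to [0, 1] using its own min and max (assumed formula)."""
    lo, hi = window.min(), window.max()
    # A constant window (hi == lo) would need special handling
    return (window - lo) / (hi - lo), (lo, hi)


def window_meanstd(window: np.ndarray):
    """Standardize one window using its own mean and std (assumed formula)."""
    mean, std = window.mean(), window.std()
    return (window - mean) / std, (mean, std)


prices = np.array([100.0, 102.0, 98.0, 104.0])
scaled, (lo, hi) = window_minmax(prices)
standardized, (mean, std) = window_meanstd(prices)

# The stored statistics allow mapping predictions back to price units
restored = scaled * (hi - lo) + lo
```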

INDICATORS = ['rsi', 'bb_upper', 'bb_lower', 'bb_middle']#
INDICATORS_TO_NORMALIZE = ['bb_upper', 'bb_lower', 'bb_middle']#
NORM_MODES = ['window_minmax', 'window_meanstd', 'window_mean', 'global_minmax', 'global_meanstd', 'none']#
OHLC = ['open', 'high', 'low', 'close']#

static encode_tokens(tokens: ndarray) ndarray#

Encodes the tokens as integers (tokens are expected to be in [0, 1])

Parameters:

tokens (np.ndarray) – tokens to encode

Returns:

encoded tokens

Return type:

np.ndarray

get_random_input()#

Returns a random input from the dataset

Returns:

sequence_tokens, extra_tokens

Return type:

Tuple[jnp.ndarray, jnp.ndarray]

get_train_test_split(test_size: float = 0.1, test_tickers: Optional[List[str]] = None) Tuple[Dataset, Dataset]#

Returns a train and test set from the dataset

Parameters:
  • test_size (float) – test size

  • test_tickers (Optional[List[str]]) – tickers to include in the test set. If None, all tickers are included

Returns:

train and test dataset

Return type:

Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]

class jaxer.utils.EarlyStopper(max_epochs: int)#

Bases: object

Early stopper class. Stops the training if the optimization metric does not improve for a number of epochs

Parameters:

max_epochs (int) – max number of epochs without improvement
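The stopping rule described above can be sketched as follows. The class name, the `should_stop` method and the "lower is better" convention are assumptions made for illustration; the methods of the jaxer class are not listed here.

```python
class SketchEarlyStopper:
    """Hypothetical early stopper: stop after max_epochs epochs without improvement."""

    def __init__(self, max_epochs: int) -> None:
        self.max_epochs = max_epochs
        self.best = float("inf")
        self.epochs_without_improvement = 0

    def should_stop(self, metric: float) -> bool:
        if metric < self.best:
            # Improvement: remember it and reset the patience counter
            self.best = metric
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.max_epochs


stopper = SketchEarlyStopper(max_epochs=2)
losses = [1.0, 0.8, 0.9, 0.85, 0.7]
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.should_stop(loss):
        stopped_at = epoch
        break
```

With these losses the loop stops at epoch index 3, after two consecutive epochs that fail to beat 0.8.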

class jaxer.utils.SyntheticDataset(config: SyntheticDatasetConfig)#

Bases: object

Synthetic dataset generator. It generates windows of historical data (of length window_size) built from sinusoidal signals, with the same shape as the real dataset. Only historical prices are generated, because volume and trades are not easy to simulate.

Parameters:

config (SyntheticDatasetConfig) – configuration for the synthetic dataset

generator(batch_size: int, seed: Optional[int] = None)#

Generator for the synthetic dataset

Parameters:
  • batch_size (int) – batch size

  • seed (Optional[int]) – seed for reproducibility

Returns:

generator for the synthetic dataset

get_random_input()#

Get a random input for the synthetic dataset

Returns:

random input for the synthetic dataset

static softmax(x: array)#

Compute softmax values for each set of scores in x.
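For reference, a numerically stable softmax can be sketched in NumPy as below. The function name and the choice of normalizing over the last axis are assumptions; the axis used by the jaxer method is not documented here.

```python
import numpy as np


def stable_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Softmax over one axis, shifting by the max to avoid overflow in exp."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)


scores = np.array([[1.0, 2.0, 3.0],
                   [1000.0, 1000.0, 1000.0]])  # large values would overflow a naive exp
probs = stable_softmax(scores)
```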

jaxer.utils.compute_metrics(x: Array, y_pred: Array, y_true: Array, normalizer: Array, denormalize_values: bool = True)#

Compute metrics for a batch of data
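The metrics themselves are not listed here. The `mape` and `acc_dir` arguments of plot_metrics suggest a mean absolute percentage error and a directional accuracy; the sketch below shows one way such metrics could be computed. The function names and formulas are assumptions, not the jaxer implementation.

```python
import numpy as np


def mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean absolute percentage error, in percent."""
    return float(np.mean(np.abs((y_true - y_pred) / y_true)) * 100.0)


def directional_accuracy(last_value: np.ndarray, y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Percentage of predictions that move in the same direction as the truth."""
    true_direction = np.sign(y_true - last_value)
    pred_direction = np.sign(y_pred - last_value)
    return float(np.mean(true_direction == pred_direction) * 100.0)


last_close = np.array([100.0, 100.0, 100.0, 100.0])  # last observed value of each window
y_true = np.array([110.0, 90.0, 105.0, 95.0])
y_pred = np.array([105.0, 95.0, 95.0, 99.0])

error = mape(y_true, y_pred)
accuracy = directional_accuracy(last_close, y_true, y_pred)
```

In this toy batch three of the four predictions move in the right direction, so the directional accuracy is 75%.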

jaxer.utils.get_best_model(experiment_name: str) Tuple[Optional[str], str]#

Returns the best model from the experiment

Parameters:

experiment_name (str) – name of the experiment

Returns:

subfolder and checkpoint of the best model

Return type:

Tuple[Optional[str], str]

Raises:

FileNotFoundError – if the file is not found

jaxer.utils.get_logger()#

Provides the logger named LOGGER_NAME. If it is not the default logger, the user has already configured it through the 'LOGGER_NAME' environment variable. If it is the default logger, it is configured with a StreamHandler and a default formatter. If the logger is already configured, it is returned as is (as a singleton).

Returns:

logging.Logger
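The behaviour described above can be approximated with the standard logging module. In this sketch the default logger name, the format string and the function name are all hypothetical; only the 'LOGGER_NAME' environment variable comes from the description.

```python
import logging
import os

DEFAULT_LOGGER_NAME = "sketch_default_logger"  # hypothetical default name


def get_sketch_logger() -> logging.Logger:
    """Return a logger, configuring it only once and only if it is the default one."""
    name = os.environ.get("LOGGER_NAME", DEFAULT_LOGGER_NAME)
    logger = logging.getLogger(name)  # logging.getLogger returns the same object per name
    if name == DEFAULT_LOGGER_NAME and not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger


first = get_sketch_logger()
second = get_sketch_logger()  # no second handler is attached
```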

jaxer.utils.jax_collate_fn(batch: List[ndarray]) Tuple#

Collate function for the jax dataset

Parameters:

batch (List[jnp.ndarray]) – batch of data

Returns:

batched data (sequence_tokens, extra_tokens), labels, norms, window_info

Return type:

Tuple

jaxer.utils.plot_metrics(mape: List[float], acc_dir: List[float], window_size: int = 10) None#

jaxer.utils.plot_predictions(x: Tuple[Array, Array], y_true: Array, y_pred: Array, normalizer: Array, window_info: Dict, folder_name: Optional[str] = None, image_name: str = 'pred', denormalize_values: bool = True) Dict#

Function to plot a window prediction. Batch size must be 1

Parameters:
  • x (Tuple[jnp.ndarray, jnp.ndarray]) – input data (sequence, extra tokens)

  • y_true (jnp.ndarray) – true values

  • y_pred (jnp.ndarray) – predicted values

  • normalizer (jnp.ndarray) – normalizer values

  • window_info (Dict) – window information

  • folder_name (Optional[str]) – folder name to save the image

  • image_name (str) – image name

  • denormalize_values (bool) – denormalize the values on the plot

jaxer.utils.plot_tensorboard_experiment(exp_path: str, save_path: Optional[str] = None, window_size: int = 5)#

Plot tensorboard experiment for documentation

Parameters:
  • exp_path (str) – path to the experiment

  • save_path (Optional[str]) – path to save the plot

  • window_size (int) – window size for the moving average
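The moving average controlled by window_size can be sketched with NumPy. How jaxer treats the edges of the curve is not stated, so the 'valid' mode below (which shortens the output) is an assumption, as is the function name.

```python
import numpy as np


def moving_average(values: np.ndarray, window_size: int = 5) -> np.ndarray:
    """Simple moving average; output has len(values) - window_size + 1 points."""
    kernel = np.ones(window_size) / window_size
    return np.convolve(values, kernel, mode="valid")


noisy = np.array([1.0, 3.0, 2.0, 4.0, 3.0, 5.0, 4.0])
smoothed = moving_average(noisy, window_size=5)
```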

Config#

class jaxer.config.DatasetConfig(datapath: str, seq_len: int, norm_mode: str, output_mode: str, resolution: str, tickers: List[str], initial_date: Optional[str] = None, indicators: Optional[List[str]] = None, discrete_grid_levels: Optional[List[float]] = None, ohlc_only: bool = False, close_only: bool = False, return_mode: bool = False)#

Bases: object

Configuration class for the dataset

Parameters:
  • datapath (str) – path to the dataset

  • seq_len (int) – sequence length

  • norm_mode (str) – normalization mode

  • initial_date (Optional[str]) – initial date to start the dataset

  • output_mode (str) – output mode of the model (mean, distribution or discrete_grid)

  • discrete_grid_levels (Optional[List[float]]) – levels of the discrete grid (in percentage: e.g. [-9.e6, -2., 0.0, 2., 9.e6])

  • resolution (str) – resolution of the dataset (30m, 1h, 4h, all)

  • tickers (List[str]) – list of tickers (e.g. ['btc_usd', 'eth_usd'])

  • indicators (Optional[List[str]]) – list of indicators (e.g. ['rsi', 'bb_upper', 'bb_lower', 'bb_middle', 'ema_2h', 'ema_4h'])

  • ohlc_only (bool) – whether to use only ohlc data (pad everything else with -1)

  • close_only (bool) – whether to use only the close price

  • return_mode (bool) – whether to return the dataset
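To make the discrete_grid_levels example concrete, the sketch below maps a percentage change to the index of the grid interval that contains it. The half-open interval convention and the function name are assumptions; the way jaxer builds its class labels is not described here.

```python
from bisect import bisect_right
from typing import List


def grid_bucket(pct_change: float, levels: List[float]) -> int:
    """Index i such that levels[i] <= pct_change < levels[i + 1] (assumed convention)."""
    return bisect_right(levels, pct_change) - 1


levels = [-9.e6, -2., 0.0, 2., 9.e6]  # 5 levels define 4 intervals
buckets = [grid_bucket(change, levels) for change in (-3.5, -1.0, 0.5, 7.0)]
num_classes = len(levels) - 1
```

The very large outer levels act as open-ended bounds, so every realistic percentage change falls into one of the four intervals.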

class jaxer.config.ExperimentConfig(model_config: ModelConfig, log_dir: str, experiment_name: str, num_epochs: int, learning_rate: float, lr_mode: str, weight_decay: float, warmup_epochs: int, dataset_mode: str, dataset_config: Optional[DatasetConfig], synthetic_dataset_config: Optional[SyntheticDatasetConfig], batch_size: int, test_split: float, test_tickers: List[str], seed: int, save_weights: bool, early_stopper: int, pretrained_model: Optional[Tuple[str, str, str]] = None, steps_per_epoch: int = 100, real_proportion: float = 0.3)#

Bases: object

Configuration class for a training experiment

Parameters:
  • model_config (ModelConfig) – model configuration (transformer architecture)

  • pretrained_model (Optional[Tuple[str, str, str]]) – experiment path and best model

  • log_dir (str) – directory to save the logs

  • experiment_name (str) – name of the experiment

  • num_epochs (int) – number of epochs to train

  • steps_per_epoch (int) – number of steps per epoch (for synthetic datasets)

  • learning_rate (float) – learning rate

  • lr_mode (str) – learning rate mode (cosine, linear or none)

  • weight_decay (float) – weight decay

  • warmup_epochs (int) – number of warmup epochs (for learning rate)

  • dataset_mode (str) – dataset mode (real, synthetic or both)

  • real_proportion (float) – proportion of real data to use (only used when dataset_mode is both)

  • dataset_config (Optional[DatasetConfig]) – dataset configuration

  • synthetic_dataset_config (Optional[SyntheticDatasetConfig]) – synthetic dataset configuration

  • batch_size (int) – batch size for training

  • test_split (float) – test split (between 0 and 1)

  • test_tickers (List[str]) – list of tickers to test

  • seed (int) – seed for reproducibility
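The interplay of learning_rate, lr_mode and warmup_epochs can be pictured with a plain Python schedule. This is only a sketch: the trainer relies on optax, and the exact warmup shape and decay end value used by jaxer are assumptions here, as is the function name.

```python
import math


def sketch_learning_rate(epoch: int, num_epochs: int, warmup_epochs: int,
                         base_lr: float, lr_mode: str = "cosine") -> float:
    """Linear warmup followed by a cosine, linear or constant schedule (assumed shapes)."""
    if warmup_epochs > 0 and epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    if lr_mode == "none":
        return base_lr
    # Fraction of the post-warmup phase that has elapsed, in [0, 1)
    progress = (epoch - warmup_epochs) / max(1, num_epochs - warmup_epochs)
    if lr_mode == "linear":
        return base_lr * (1.0 - progress)
    if lr_mode == "cosine":
        return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
    raise ValueError(f"unknown lr_mode: {lr_mode}")


schedule = [sketch_learning_rate(epoch, num_epochs=10, warmup_epochs=2, base_lr=1e-3)
            for epoch in range(10)]
```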

class jaxer.config.ModelConfig(precision: str, d_model: int, num_layers: int, head_layers: int, n_heads: int, dim_feedforward: int, dropout: float, max_seq_len: int, flatten_encoder_output: bool, fe_blocks: int, use_time2vec: bool, output_mode: str, use_resblocks_in_head: bool, use_resblocks_in_fe: bool, use_extra_tokens: bool, average_encoder_output: bool, norm_encoder_prev: bool)#

Bases: object

Configuration class for the model

Parameters:
  • precision (str) – precision of the model (fp32 or fp16)

  • d_model (int) – dimension of the model

  • num_layers (int) – number of encoder layers in the transformer

  • head_layers (int) – number of layers in the prediction head

  • n_heads (int) – number of heads in the multihead attention

  • dim_feedforward (int) – dimension of the feedforward network

  • dropout (float) – dropout rate

  • max_seq_len (int) – maximum sequence length (context window)

  • flatten_encoder_output (bool) – whether to flatten the encoder output or get the last token

  • fe_blocks (int) – number of blocks in the feature extractor

  • use_time2vec (bool) – whether to use time2vec in the feature extractor

  • output_mode (str) – output mode of the model (mean, distribution or discrete_grid)

  • use_resblocks_in_head (bool) – whether to use residual blocks in the head

  • use_resblocks_in_fe (bool) – whether to use residual blocks in the feature extractor

  • use_extra_tokens (bool) – whether to use extra tokens in the model

  • average_encoder_output (bool) – whether to average the encoder output (if not flattened)

  • norm_encoder_prev (bool) – whether to apply normalization to the encoder block input (before the attention) instead of at the end

class jaxer.config.SyntheticDatasetConfig(window_size: int, return_mode: bool = False, output_mode: str = 'mean', normalizer_mode: str = 'window_minmax', add_noise: bool = False, min_amplitude: float = 0.1, max_amplitude: float = 1.0, min_frequency: float = 0.5, max_frequency: float = 30, num_sinusoids: int = 3, max_linear_trend: float = 0.5, max_exp_trend: float = 0.01, precision: str = 'fp32', close_only: bool = False)#

Bases: object

Configuration class for a synthetic dataset

Parameters:
  • window_size (int) – size of the window

  • output_mode (str) – output mode (mean, distribution, discrete_grid)

  • normalizer_mode (str) – normalizer mode (window_meanstd, window_minmax)

  • add_noise (bool) – add noise to the signal

  • return_mode (bool) – return mode. If true, the normalizer cannot be window_mean

  • min_amplitude (float) – minimum amplitude for the sinusoids

  • max_amplitude (float) – maximum amplitude for the sinusoids

  • min_frequency (float) – minimum frequency for the sinusoids

  • max_frequency (float) – maximum frequency for the sinusoids

  • num_sinusoids (int) – number of sinusoids to generate

  • max_linear_trend (float) – maximum linear trend (for short term trends)

  • max_exp_trend (float) – maximum exponential trend. Must be low as it grows exponentially (for long term trends)

  • precision (str) – precision of the model (fp32 or fp16)

  • close_only (bool) – whether to use only the close price
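The parameters above suggest a signal made of random sinusoids plus linear and exponential trends. The sketch below is one possible reading of that recipe; how jaxer actually samples and combines the components (and how noise is added) is not documented here, so every formula and the function name are assumptions.

```python
import numpy as np


def synthetic_window(window_size: int, rng: np.random.Generator,
                     num_sinusoids: int = 3,
                     min_amplitude: float = 0.1, max_amplitude: float = 1.0,
                     min_frequency: float = 0.5, max_frequency: float = 30.0,
                     max_linear_trend: float = 0.5,
                     max_exp_trend: float = 0.01) -> np.ndarray:
    """Sum of random sinusoids with a linear and an exponential trend (assumed recipe)."""
    t = np.linspace(0.0, 1.0, window_size)
    signal = np.zeros(window_size)
    for _ in range(num_sinusoids):
        amplitude = rng.uniform(min_amplitude, max_amplitude)
        frequency = rng.uniform(min_frequency, max_frequency)
        phase = rng.uniform(0.0, 2.0 * np.pi)
        signal += amplitude * np.sin(2.0 * np.pi * frequency * t + phase)
    # Short-term trend: a random slope across the window
    signal += rng.uniform(-max_linear_trend, max_linear_trend) * t
    # Long-term trend: a small exponential growth rate per step
    signal += np.exp(rng.uniform(0.0, max_exp_trend) * np.arange(window_size)) - 1.0
    return signal


window_a = synthetic_window(64, np.random.default_rng(42))
window_b = synthetic_window(64, np.random.default_rng(42))  # same seed, same window
```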

Models#

class jaxer.model.AddPositionalEncoding(config: TransformerConfig, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)#

Bases: Module

Add Positional Encoding Module (absolute positional encoding)

Parameters:

config (TransformerConfig) – transformer configuration
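A common form of absolute positional encoding is the fixed sinusoidal table. The sketch below assumes that form and an even d_model; whether the jaxer module uses exactly this table is not stated here, and the function name is hypothetical.

```python
import numpy as np


def sinusoidal_positional_encoding(max_seq_len: int, d_model: int) -> np.ndarray:
    """Fixed sinusoidal table of shape (max_seq_len, d_model); d_model must be even."""
    positions = np.arange(max_seq_len)[:, None]
    div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
    encoding = np.zeros((max_seq_len, d_model))
    encoding[:, 0::2] = np.sin(positions * div_term)  # even feature indices
    encoding[:, 1::2] = np.cos(positions * div_term)  # odd feature indices
    return encoding


table = sinusoidal_positional_encoding(max_seq_len=16, d_model=8)
tokens = np.zeros((10, 8))               # a sequence shorter than max_seq_len
encoded = tokens + table[: tokens.shape[0]]  # the encoding is added to the embeddings
```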

class jaxer.model.Encoder(config: TransformerConfig, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)#

Bases: Module

Encoder Module (L * encoder blocks). Uses time2vec or positional encoding to encode the input sequence and calls L encoder blocks

Parameters:

config (TransformerConfig) – transformer configuration

class jaxer.model.EncoderBlock(config: TransformerConfig, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)#

Bases: Module

Encoder Block Module (self attention followed by feed forward). Normalization can be applied to the input or at the end (layer norm)

Parameters:

config (TransformerConfig) – transformer configuration

class jaxer.model.FeatureExtractor(config: TransformerConfig, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)#

Bases: Module

Feature Extractor Module based on residual MLP networks (of increasing shape) to get to d_model

Parameters:

config (TransformerConfig) – transformer configuration

class jaxer.model.FeedForwardBlock(config: TransformerConfig, out_dim: int, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)#

Bases: Module

Feed Forward Block Module (dense, gelu, dropout, dense, gelu, dropout)

Parameters:
  • config (TransformerConfig) – transformer configuration

  • out_dim (int) – output dimension

class jaxer.model.PredictionHead(config: TransformerConfig, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)#

Bases: Module

Prediction Head Module. It can output a mean, a mean and a variance, or categorical probabilities. Residual blocks can be used in the head.

Parameters:

config (TransformerConfig) – transformer configuration

class jaxer.model.ResidualBlock(dtype: numpy.dtype, feature_dim: int, kernel_init: Callable, bias_init: Callable, config: TransformerConfig, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)#

Bases: Module

Residual Block Module with optional normalization of the input or at the end (layer norm)

Parameters:
  • dtype (jnp.dtype) – data type

  • feature_dim (int) – feature dimension

  • kernel_init (Callable) – kernel initializer

  • bias_init (Callable) – bias initializer

  • config (TransformerConfig) – transformer configuration

class jaxer.model.Time2Vec(dtype: numpy.dtype, kernel_init: Callable, bias_init: Callable, max_seq_len: int, d_model: int, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)#

Bases: Module

Time2Vec Module (from the paper Time2Vec: Learning a Vector Representation of Time)

Parameters:
  • dtype (jnp.dtype) – data type

  • kernel_init (Callable) – kernel initializer

  • bias_init (Callable) – bias initializer

  • max_seq_len (int) – maximum sequence length

  • d_model (int) – model embedding size
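The Time2Vec representation from the cited paper has one linear component and d_model - 1 periodic (sine) components. The sketch below computes it in NumPy with random frequencies and phases; in the Flax module these would be learned parameters, and how the result is combined with the input tokens is not covered here. The function name is hypothetical.

```python
import numpy as np


def time2vec(tau: np.ndarray, omega: np.ndarray, phi: np.ndarray) -> np.ndarray:
    """Time2Vec embedding: first feature linear in tau, remaining features sinusoidal."""
    projected = tau[:, None] * omega[None, :] + phi[None, :]  # (seq_len, d_model)
    linear = projected[:, :1]             # omega_0 * tau + phi_0
    periodic = np.sin(projected[:, 1:])   # sin(omega_i * tau + phi_i), i >= 1
    return np.concatenate([linear, periodic], axis=-1)


rng = np.random.default_rng(0)
d_model = 4
omega = rng.normal(size=d_model)  # stands in for learned frequencies
phi = rng.normal(size=d_model)    # stands in for learned phases
tau = np.arange(6, dtype=float)   # time steps 0..5
embedding = time2vec(tau, omega, phi)
```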

class jaxer.model.Transformer(config: TransformerConfig, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)#

Bases: Module

Transformer model

  1. Encoder

  2. Flatten/Average/Last Element of the output of the encoder

  3. Prediction Head -> mean, variance, or categorical probabilities

Parameters:

config (TransformerConfig) – transformer configuration
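Step 2 above reduces the encoder output of shape (batch, seq_len, d_model) to one vector per sample. The three options can be sketched in NumPy as follows; the function name and the precedence of flatten over average are assumptions based on the configuration flags.

```python
import numpy as np


def reduce_encoder_output(encoded: np.ndarray, flatten: bool = False,
                          average: bool = False) -> np.ndarray:
    """Flatten, average over the sequence, or keep the last element."""
    batch, seq_len, d_model = encoded.shape
    if flatten:
        return encoded.reshape(batch, seq_len * d_model)
    if average:
        return encoded.mean(axis=1)
    return encoded[:, -1, :]


encoded = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)  # toy encoder output
flat = reduce_encoder_output(encoded, flatten=True)
avg = reduce_encoder_output(encoded, average=True)
last = reduce_encoder_output(encoded)
```

Flattening keeps all positions but ties the head's input size to the sequence length, while averaging or taking the last element keeps it at d_model.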

class jaxer.model.TransformerConfig(d_model: int = 512, n_heads: int = 8, num_layers: int = 6, head_layers: int = 2, dim_feedforward: int = 2048, max_seq_len: int = 256, dropout: float = 0.0, dtype: numpy.dtype = <class 'jax.numpy.float32'>, kernel_init: Callable = <function variance_scaling.<locals>.init>, bias_init: Callable = <function normal.<locals>.init>, deterministic: bool = False, flatten_encoder_output: bool = False, fe_blocks: int = 2, use_time2vec: bool = True, output_mode: str = 'distribution', discrete_grid_levels: int = 2, use_resblocks_in_head: bool = True, use_resblocks_in_fe: bool = True, use_extra_tokens: bool = False, average_encoder_output: bool = False, norm_encoder_prev: bool = False, univariate: bool = False)#

Bases: object

Transformer model configuration

Parameters:
  • d_model (int) – model embedding size

  • n_heads (int) – number of attention heads

  • num_layers (int) – number of layers

  • head_layers (int) – number of layers in the head

  • dim_feedforward (int) – feedforward dimension

  • max_seq_len (int) – maximum sequence length

  • dropout (float) – dropout rate

  • dtype (jnp.dtype) – data type

  • kernel_init (Callable) – kernel initializer

  • bias_init (Callable) – bias initializer

  • deterministic (bool) – whether the model is deterministic

  • flatten_encoder_output (bool) – whether to flatten the encoder output

  • fe_blocks (int) – number of feature extractor blocks

  • use_time2vec (bool) – whether to use time2vec

  • output_mode (str) – output mode

  • discrete_grid_levels (int) – number of discrete grid levels

  • use_resblocks_in_head (bool) – whether to use residual blocks in the head

  • use_resblocks_in_fe (bool) – whether to use residual blocks in the feature extractor

  • use_extra_tokens (bool) – whether to use extra tokens

  • average_encoder_output (bool) – whether to average the encoder output

  • norm_encoder_prev (bool) – whether to apply normalization to the encoder block input (before the attention) instead of at the end

  • univariate (bool) – whether the input is univariate

replace(**updates)#

Returns a new object replacing the specified fields with new values.
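The behaviour of `replace` resembles `dataclasses.replace` from the standard library, which the sketch below uses with a hypothetical `SketchConfig`. It illustrates the copy-with-changes pattern only and is not the TransformerConfig class itself.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical, reduced stand-in for a model configuration."""
    d_model: int = 512
    dropout: float = 0.0
    deterministic: bool = False


train_config = SketchConfig(dropout=0.1)
# Build an evaluation config without mutating the training one
eval_config = replace(train_config, deterministic=True)
```

This pattern is handy for deriving an inference configuration (for example with deterministic set to True) from the one used in training.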