API
This documentation provides an overview of the API for jaxer.
Modules Overview
Run
- class jaxer.run.AgentBase(experiment: str, model_name: Tuple[Optional[str], str])
Bases: object
Agent base class to load a model and perform inference (jax, torch, …)
- Parameters:
- Raises:
FileNotFoundError – if the experiment or model does not exist
- class jaxer.run.FlaxAgent(experiment: str, model_name: Tuple[Optional[str], str])
Bases: AgentBase
Agent class to load a model and perform inference
- class jaxer.run.FlaxTrainer(config: ExperimentConfig)
Bases: TrainerBase
Trainer class for training jaxer using flax, optax and jax
- Parameters:
config (ExperimentConfig) – training config for running an experiment
- class jaxer.run.TrainerBase(config: ExperimentConfig)
Bases: object
Trainer base class. All trainers should inherit from this (jax, torch, …)
- Parameters:
config (ExperimentConfig) – the configuration for the experiment
Utils
- class jaxer.utils.Dataset(dataset_config: DatasetConfig)
Bases: Dataset
Finance dataset class for training jaxer (PyTorch dataset)
- Parameters:
dataset_config (DatasetConfig) – DatasetConfig object
- Raises:
ValueError – if the norm_mode is not one of [‘window_minmax’, ‘window_meanstd’, ‘window_mean’, ‘global_minmax’, ‘global_meanstd’, ‘none’]
ValueError – if the indicators are not in the dataset [‘rsi’, ‘bb_upper’, ‘bb_lower’, ‘bb_middle’]
ValueError – if the discrete_grid_levels are not provided and the output_mode is discrete_grid
ValueError – if the ticker is not in the dataset
- INDICATORS = ['rsi', 'bb_upper', 'bb_lower', 'bb_middle']
- INDICATORS_TO_NORMALIZE = ['bb_upper', 'bb_lower', 'bb_middle']
- NORM_MODES = ['window_minmax', 'window_meanstd', 'window_mean', 'global_minmax', 'global_meanstd', 'none']
- OHLC = ['open', 'high', 'low', 'close']
- static encode_tokens(tokens: ndarray) -> ndarray
Encodes the tokens into integers (tokens are expected to lie in [0, 1])
- Parameters:
tokens (np.ndarray) – tokens to encode
- Returns:
encoded tokens
- Return type:
np.ndarray
- get_random_input()
Returns a random input from the dataset
- Returns:
sequence_tokens, extra_tokens
- Return type:
Tuple[jnp.ndarray, jnp.ndarray]
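The `norm_mode` values listed in `NORM_MODES` are not defined in this reference. As a rough illustration only, the sketch below shows what min-max and mean/std normalization over a single window conventionally look like. The helper names `window_minmax` and `window_meanstd` are hypothetical, and jaxer's own implementation may differ (for example in which columns share a normalizer).

```python
from statistics import fmean, pstdev
from typing import List, Tuple


def window_minmax(window: List[float]) -> Tuple[List[float], Tuple[float, float]]:
    """Scale one window to [0, 1]; also return (min, max) so it can be undone."""
    lo, hi = min(window), max(window)
    span = (hi - lo) or 1.0  # avoid dividing by zero on a flat window
    return [(v - lo) / span for v in window], (lo, hi)


def window_meanstd(window: List[float]) -> Tuple[List[float], Tuple[float, float]]:
    """Standardize one window; also return (mean, std) so it can be undone."""
    mean = fmean(window)
    std = pstdev(window) or 1.0  # avoid dividing by zero on a flat window
    return [(v - mean) / std for v in window], (mean, std)


closes = [100.0, 102.0, 101.0, 104.0]  # made-up close prices
scaled, (lo, hi) = window_minmax(closes)
restored = [v * (hi - lo) + lo for v in scaled]  # denormalize
standardized, (mean, std) = window_meanstd(closes)
```

Returning the normalizer alongside the values mirrors the way the plotting and metric helpers in this module take a `normalizer` argument to denormalize predictions.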
- class jaxer.utils.EarlyStopper(max_epochs: int)
Bases: object
Early stopper class. Stops the training if the optimization metric does not improve for a number of epochs
- Parameters:
max_epochs (int) – max number of epochs without improvement
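The description above amounts to a patience counter. A minimal sketch of that idea follows; the class name `PatienceStopper`, the method `should_stop`, and the assumption that a lower metric is better are all hypothetical and are not taken from jaxer.

```python
class PatienceStopper:
    """Hypothetical stand-in for an early stopper; not jaxer's implementation."""

    def __init__(self, max_epochs: int) -> None:
        self.max_epochs = max_epochs  # epochs allowed without improvement
        self.best = float("inf")      # assumes a lower metric is better
        self.epochs_without_improvement = 0

    def should_stop(self, metric: float) -> bool:
        """Record one epoch's metric and report whether training should stop."""
        if metric < self.best:
            self.best = metric
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.max_epochs


stopper = PatienceStopper(max_epochs=2)
losses = [0.9, 0.7, 0.8, 0.75, 0.6]  # made-up validation losses
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.should_stop(loss):
        stopped_at = epoch
        break
```

With a patience of 2, the loop stops once two consecutive epochs fail to beat the best loss seen so far, so the final value in `losses` is never reached.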
- class jaxer.utils.SyntheticDataset(config: SyntheticDatasetConfig)
Bases: object
Synthetic dataset generator. It generates windows of historical data built from sinusoidal signals, with the same shape as the real dataset. Only historical prices are generated, because volume and trades are not easy to simulate.
- Parameters:
config (SyntheticDatasetConfig) – configuration for the synthetic dataset
- get_random_input()
Get a random input for the synthetic dataset
- Returns:
random input for the synthetic dataset
- static softmax(x: array)
Compute softmax values for each set of scores in x.
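For reference, a dependency-free sketch of a numerically stable softmax over one set of scores is shown below. It works on plain Python sequences rather than arrays, and it makes no claim about how the documented method handles axes or batches.

```python
import math
from typing import List, Sequence


def softmax(scores: Sequence[float]) -> List[float]:
    """Softmax over one set of scores, shifted by the max for numerical stability."""
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([1.0, 2.0, 3.0])
large = softmax([1000.0, 1001.0])  # math.exp(1000.0) alone would overflow
```

Subtracting the maximum leaves the result unchanged, because the common factor cancels between numerator and denominator.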
- jaxer.utils.compute_metrics(x: Array, y_pred: Array, y_true: Array, normalizer: Array, denormalize_values: bool = True)
Compute metrics for a batch of data
- jaxer.utils.get_best_model(experiment_name: str) -> Tuple[Optional[str], str]
Returns the best model from the experiment
- Parameters:
experiment_name (str) – name of the experiment
- Returns:
subfolder and checkpoint of the best model
- Return type:
Tuple[Optional[str], str]
- Raises:
FileNotFoundError – if the file is not found
- jaxer.utils.get_logger()
Provides the logger named LOGGER_NAME. If it is not the default logger, the user has already configured it through the ‘LOGGER_NAME’ environment variable. If it is the default logger, it is configured with a StreamHandler and a default formatter. If the logger is already configured, it is returned as is (as a singleton).
- Returns:
logging.Logger
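The behaviour described for `get_logger` can be approximated with the standard `logging` module, as in the sketch below. The default name `"jaxer"` and the format string are assumptions; only the `LOGGER_NAME` environment variable comes from the description above.

```python
import logging
import os

DEFAULT_LOGGER_NAME = "jaxer"  # assumed default; the real default is not documented here


def get_logger() -> logging.Logger:
    """Return the logger named by LOGGER_NAME, configuring the default one once."""
    name = os.environ.get("LOGGER_NAME", DEFAULT_LOGGER_NAME)
    logger = logging.getLogger(name)  # logging caches loggers by name
    if name == DEFAULT_LOGGER_NAME and not logger.handlers:
        # Default logger, not configured yet: attach a stream handler once.
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger


first = get_logger()
second = get_logger()  # same object, no second handler attached
```

Because `logging.getLogger` returns the same object for the same name, the singleton behaviour comes for free; the `handlers` check only prevents duplicate output.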
- jaxer.utils.jax_collate_fn(batch: List[ndarray]) -> Tuple
Collate function for the jax dataset
- Parameters:
batch (List[jnp.ndarray]) – batch of data
- Returns:
batched data (sequence_tokens, extra_tokens), labels, norms, window_info
- Return type:
Tuple
- jaxer.utils.plot_predictions(x: Tuple[Array, Array], y_true: Array, y_pred: Array, normalizer: Array, window_info: Dict, folder_name: Optional[str] = None, image_name: str = 'pred', denormalize_values: bool = True) -> Dict
Function to plot a window prediction. Batch size must be 1
- Parameters:
x (Tuple[jnp.ndarray, jnp.ndarray]) – input data (sequence, extra tokens)
y_true (jnp.ndarray) – true values
y_pred (jnp.ndarray) – predicted values
normalizer (jnp.ndarray) – normalizer values
window_info (Dict) – window information
folder_name (Optional[str]) – folder name to save the image
image_name (str) – image name
denormalize_values (bool) – denormalize the values on the plot
Config
- class jaxer.config.DatasetConfig(datapath: str, seq_len: int, norm_mode: str, output_mode: str, resolution: str, tickers: List[str], initial_date: Optional[str] = None, indicators: Optional[List[str]] = None, discrete_grid_levels: Optional[List[float]] = None, ohlc_only: bool = False, close_only: bool = False, return_mode: bool = False)
Bases: object
Configuration class for the dataset
- Parameters:
datapath (str) – path to the dataset
seq_len (int) – sequence length
norm_mode (str) – normalization mode
initial_date (Optional[str]) – initial date to start the dataset
output_mode (str) – output mode of the model (mean, distribution or discrete_grid)
discrete_grid_levels (Optional[List[float]]) – levels of the discrete grid (in percentage: e.g. [-9.e6, -2., 0.0, 2., 9.e6])
resolution (str) – resolution of the dataset (30m, 1h, 4h, all)
tickers (List[str]) – list of tickers (e.g. [‘btc_usd’, ‘eth_usd’])
indicators (Optional[List[str]]) – list of indicators (e.g. [‘rsi’, ‘bb_upper’, ‘bb_lower’, ‘bb_middle’, ‘ema_2h’, ‘ema_4h’])
ohlc_only (bool) – whether to use only ohlc data (pad everything else with -1)
close_only (bool) – whether to use only the close price
return_mode (bool) – whether to return the dataset
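To make the `discrete_grid_levels` example concrete, the sketch below treats the listed values as bin edges (in percent) and maps a percentage change to the index of the interval it falls in. Reading the levels as edges is an assumption, and `grid_bin` is a hypothetical helper, not part of jaxer.

```python
from bisect import bisect_right
from typing import List


def grid_bin(pct_change: float, levels: List[float]) -> int:
    """Index of the interval [levels[i], levels[i + 1]) containing pct_change."""
    idx = bisect_right(levels, pct_change) - 1
    # Clamp so values on or beyond the outer edges land in the first or last bin.
    return min(max(idx, 0), len(levels) - 2)


levels = [-9.0e6, -2.0, 0.0, 2.0, 9.0e6]  # edges from the parameter description
bins = [grid_bin(p, levels) for p in (-5.0, -1.0, 0.5, 3.2)]
```

Under this reading, five edges define four classes: below -2%, between -2% and 0%, between 0% and 2%, and above 2%. The very large outer edges act as catch-all bounds.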
- class jaxer.config.ExperimentConfig(model_config: ModelConfig, log_dir: str, experiment_name: str, num_epochs: int, learning_rate: float, lr_mode: str, weight_decay: float, warmup_epochs: int, dataset_mode: str, dataset_config: Optional[DatasetConfig], synthetic_dataset_config: Optional[SyntheticDatasetConfig], batch_size: int, test_split: float, test_tickers: List[str], seed: int, save_weights: bool, early_stopper: int, pretrained_model: Optional[Tuple[str, str, str]] = None, steps_per_epoch: int = 100, real_proportion: float = 0.3)
Bases: object
Configuration class for a training experiment
- Parameters:
model_config (ModelConfig) – model configuration (transformer architecture)
pretrained_model (Optional[Tuple[str, str, str]]) – experiment path and best model
log_dir (str) – directory to save the logs
experiment_name (str) – name of the experiment
num_epochs (int) – number of epochs to train
steps_per_epoch (int) – number of steps per epoch (for synthetic datasets)
learning_rate (float) – learning rate
lr_mode (str) – learning rate mode (cosine, linear or none)
weight_decay (float) – weight decay
warmup_epochs (int) – number of warmup epochs (for learning rate)
dataset_mode (str) – dataset mode (real or synthetic or both)
real_proportion (float) – proportion of real data to use (only for both mode)
dataset_config (DatasetConfig) – dataset configuration
synthetic_dataset_config (SyntheticDatasetConfig) – synthetic dataset configuration
batch_size (int) – batch size for training
test_split (float) – test split (between 0 and 1)
test_tickers (List[str]) – list of tickers to test
seed (int) – seed for reproducibility
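The interaction between `learning_rate`, `lr_mode`, `warmup_epochs` and `num_epochs` is not spelled out above. One common arrangement, shown here purely as an assumption, is a linear warmup followed by cosine or linear decay. The function `learning_rate_at` is hypothetical, and the trainer may build its schedule differently.

```python
import math


def learning_rate_at(epoch: int, base_lr: float, num_epochs: int,
                     warmup_epochs: int, lr_mode: str = "cosine") -> float:
    """Learning rate for one epoch: linear warmup, then the chosen decay."""
    if warmup_epochs > 0 and epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    if lr_mode == "none":
        return base_lr
    # Fraction of the post-warmup phase completed so far, in [0, 1).
    progress = (epoch - warmup_epochs) / max(1, num_epochs - warmup_epochs)
    if lr_mode == "cosine":
        return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
    if lr_mode == "linear":
        return base_lr * (1.0 - progress)
    raise ValueError(f"unknown lr_mode: {lr_mode}")


schedule = [learning_rate_at(e, 1e-3, num_epochs=10, warmup_epochs=2)
            for e in range(10)]
```

Here the rate climbs to `base_lr` over the first two epochs and then decays along a half cosine for the remaining eight.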
- class jaxer.config.ModelConfig(precision: str, d_model: int, num_layers: int, head_layers: int, n_heads: int, dim_feedforward: int, dropout: float, max_seq_len: int, flatten_encoder_output: bool, fe_blocks: int, use_time2vec: bool, output_mode: str, use_resblocks_in_head: bool, use_resblocks_in_fe: bool, use_extra_tokens: bool, average_encoder_output: bool, norm_encoder_prev: bool)
Bases: object
Configuration class for the model
- Parameters:
precision (str) – precision of the model (fp32 or fp16)
d_model (int) – dimension of the model
num_layers (int) – number of encoder layers in the transformer
head_layers (int) – number of layers in the prediction head
n_heads (int) – number of heads in the multihead attention
dim_feedforward (int) – dimension of the feedforward network
dropout (float) – dropout rate
max_seq_len (int) – maximum sequence length (context window)
flatten_encoder_output (bool) – whether to flatten the encoder output or get the last token
fe_blocks (int) – number of blocks in the feature extractor
use_time2vec (bool) – whether to use time2vec in the feature extractor
output_mode (str) – output mode of the model (mean, distribution or discrete_grid)
use_resblocks_in_head (bool) – whether to use residual blocks in the head
use_resblocks_in_fe (bool) – whether to use residual blocks in the feature extractor
use_extra_tokens (bool) – whether to use extra tokens in the model
average_encoder_output (bool) – whether to average the encoder output (if not flattened)
norm_encoder_prev (bool) – whether to apply the encoder normalization before the attention instead of at the end
- class jaxer.config.SyntheticDatasetConfig(window_size: int, return_mode: bool = False, output_mode: str = 'mean', normalizer_mode: str = 'window_minmax', add_noise: bool = False, min_amplitude: float = 0.1, max_amplitude: float = 1.0, min_frequency: float = 0.5, max_frequency: float = 30, num_sinusoids: int = 3, max_linear_trend: float = 0.5, max_exp_trend: float = 0.01, precision: str = 'fp32', close_only: bool = False)
Bases: object
Configuration class for a synthetic dataset
- Parameters:
window_size (int) – size of the window
output_mode (str) – output mode (mean, distribution, discrete_grid)
normalizer_mode (str) – normalizer mode (window_meanstd, window_minmax)
add_noise (bool) – add noise to the signal
return_mode (bool) – return mode. If true, the normalizer cannot be window_mean
min_amplitude (float) – minimum amplitude for the sinusoids
max_amplitude (float) – maximum amplitude for the sinusoids
min_frequency (float) – minimum frequency for the sinusoids
max_frequency (float) – maximum frequency for the sinusoids
num_sinusoids (int) – number of sinusoids to generate
max_linear_trend (float) – maximum linear trend (for short term trends)
max_exp_trend (float) – maximum exponential trend. Must be low as it grows exponentially (for long term trends)
precision (str) – precision of the model (fp32 or fp16)
close_only (bool) – whether to use only the close price
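As an illustration of how the amplitude, frequency and trend parameters above could combine, the sketch below sums a few random sinusoids and adds a linear and an exponential trend. The way the terms are combined, the time scaling, and the name `synthetic_window` are assumptions; the actual generator may differ.

```python
import math
import random
from typing import List


def synthetic_window(window_size: int, num_sinusoids: int = 3,
                     min_amplitude: float = 0.1, max_amplitude: float = 1.0,
                     min_frequency: float = 0.5, max_frequency: float = 30.0,
                     max_linear_trend: float = 0.5, max_exp_trend: float = 0.01,
                     seed: int = 0) -> List[float]:
    """One synthetic price-like window: sinusoids plus linear and exponential trends."""
    rng = random.Random(seed)
    # One (amplitude, frequency, phase) triple per sinusoid.
    waves = [(rng.uniform(min_amplitude, max_amplitude),
              rng.uniform(min_frequency, max_frequency),
              rng.uniform(0.0, 2.0 * math.pi)) for _ in range(num_sinusoids)]
    slope = rng.uniform(-max_linear_trend, max_linear_trend)
    rate = rng.uniform(-max_exp_trend, max_exp_trend)
    signal = []
    for i in range(window_size):
        t = i / window_size  # normalized time in [0, 1)
        value = sum(a * math.sin(2.0 * math.pi * f * t + p) for a, f, p in waves)
        # Linear trend uses normalized time; exponential trend uses the step index.
        value += slope * t + math.exp(rate * i)
        signal.append(value)
    return signal


window = synthetic_window(window_size=64, seed=42)
```

Applying the exponential rate per step shows why `max_exp_trend` has to stay small: even 0.01 per step compounds to a factor of nearly 2 over 64 steps.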
Models
- class jaxer.model.AddPositionalEncoding(config: TransformerConfig, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
Bases: Module
Add Positional Encoding Module (absolute positional encoding)
- Parameters:
config (TransformerConfig) – transformer configuration
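The reference only states that this module adds an absolute positional encoding. Assuming the standard fixed sinusoidal scheme (which may not be what the module uses), the encoding table and the addition step can be sketched as follows; both function names are hypothetical.

```python
import math
from typing import List


def positional_encoding(max_seq_len: int, d_model: int) -> List[List[float]]:
    """Table of shape (max_seq_len, d_model); even columns use sin, odd use cos."""
    table = []
    for pos in range(max_seq_len):
        row = []
        for i in range(d_model):
            # Each sin/cos column pair shares one wavelength.
            angle = pos / (10000.0 ** (2 * (i // 2) / d_model))
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


def add_positional_encoding(x: List[List[float]],
                            table: List[List[float]]) -> List[List[float]]:
    """Add the encoding of each position to the token embedding at that position."""
    return [[v + p for v, p in zip(token, pe)] for token, pe in zip(x, table)]


table = positional_encoding(max_seq_len=4, d_model=6)
tokens = [[0.0] * 6 for _ in range(4)]  # dummy embeddings
encoded = add_positional_encoding(tokens, table)
```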
- class jaxer.model.Encoder(config: TransformerConfig, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
Bases: Module
Encoder Module (L * encoder blocks). Uses time2vec or positional encoding to encode the input sequence and calls L encoder blocks
- Parameters:
config (TransformerConfig) – transformer configuration
- class jaxer.model.EncoderBlock(config: TransformerConfig, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
Bases: Module
Encoder Block Module (self attention followed by feed forward). Normalization can be applied to the input or at the end (layer norm)
- Parameters:
config (TransformerConfig) – transformer configuration
- class jaxer.model.FeatureExtractor(config: TransformerConfig, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
Bases: Module
Feature Extractor Module based on residual MLP networks (of increasing shape) to get to d_model
- Parameters:
config (TransformerConfig) – transformer configuration
- class jaxer.model.FeedForwardBlock(config: TransformerConfig, out_dim: int, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
Bases: Module
Feed Forward Block Module (dense, gelu, dropout, dense, gelu, dropout)
- Parameters:
config (TransformerConfig) – transformer configuration
out_dim (int) – output dimension
- class jaxer.model.PredictionHead(config: TransformerConfig, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
Bases: Module
Prediction Head Module. It can output a mean, a mean and a variance, or categorical probabilities. Residual blocks can be used in the head.
- Parameters:
config (TransformerConfig) – transformer configuration
- class jaxer.model.ResidualBlock(dtype: dtype, feature_dim: int, kernel_init: Callable, bias_init: Callable, config: TransformerConfig, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
Bases: Module
Residual Block Module with optional normalization of the input or at the end (layer norm)
- Parameters:
dtype (jnp.dtype) – data type
feature_dim (int) – feature dimension
kernel_init (Callable) – kernel initializer
bias_init (Callable) – bias initializer
config (TransformerConfig) – transformer configuration
- class jaxer.model.Time2Vec(dtype: dtype, kernel_init: Callable, bias_init: Callable, max_seq_len: int, d_model: int, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
Bases: Module
Time2Vec Module (from the paper Time2Vec: Learning a Vector Representation of Time)
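In the Time2Vec paper, a scalar time value is mapped to a vector whose first component is linear in time and whose remaining components are periodic. The sketch below evaluates that formula with fixed, made-up frequencies and phases. In the module these would be learned parameters, and how jaxer shapes them to `d_model` is not documented here.

```python
import math
from typing import List


def time2vec(tau: float, omega: List[float], phi: List[float]) -> List[float]:
    """First component is linear in tau; the remaining ones are periodic (sine)."""
    linear = omega[0] * tau + phi[0]
    periodic = [math.sin(w * tau + p) for w, p in zip(omega[1:], phi[1:])]
    return [linear] + periodic


omega = [0.5, 1.0, 2.0]        # illustrative frequencies (learned in practice)
phi = [0.1, 0.0, math.pi / 2]  # illustrative phases (learned in practice)
embedding = time2vec(2.0, omega, phi)
```

The linear term lets the model represent non-periodic progression of time, while the sine terms capture repeating patterns at different frequencies.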
- class jaxer.model.Transformer(config: TransformerConfig, parent: Optional[Union[Type[Module], Scope, Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
Bases: Module
Transformer model:
- Encoder
- Flatten/Average/Last Element of the output of the encoder
- Prediction Head -> mean, variance, or categorical probabilities
- Parameters:
config (TransformerConfig) – transformer configuration
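The Flatten/Average/Last Element step reduces the encoder output of shape (seq_len, d_model) to a single vector for the prediction head. A plain-Python sketch of the three options follows; the function `pool_encoder_output` is hypothetical, and giving flattening precedence over averaging is inferred from the `average_encoder_output` description ("if not flattened").

```python
from typing import List


def pool_encoder_output(encoded: List[List[float]],
                        flatten_encoder_output: bool = False,
                        average_encoder_output: bool = False) -> List[float]:
    """Reduce a (seq_len, d_model) encoder output to one vector."""
    if flatten_encoder_output:
        # Concatenate every token: length seq_len * d_model.
        return [v for token in encoded for v in token]
    if average_encoder_output:
        # Mean over the sequence axis: length d_model.
        n = len(encoded)
        return [sum(column) / n for column in zip(*encoded)]
    # Default: keep only the last token, length d_model.
    return list(encoded[-1])


encoded = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # seq_len=3, d_model=2
flat = pool_encoder_output(encoded, flatten_encoder_output=True)
mean = pool_encoder_output(encoded, average_encoder_output=True)
last = pool_encoder_output(encoded)
```

Flattening makes the head's input size depend on the sequence length, whereas averaging or taking the last token keeps it at `d_model`.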
- class jaxer.model.TransformerConfig(d_model: int = 512, n_heads: int = 8, num_layers: int = 6, head_layers: int = 2, dim_feedforward: int = 2048, max_seq_len: int = 256, dropout: float = 0.0, dtype: dtype = <class 'jax.numpy.float32'>, kernel_init: Callable = <function variance_scaling.<locals>.init>, bias_init: Callable = <function normal.<locals>.init>, deterministic: bool = False, flatten_encoder_output: bool = False, fe_blocks: int = 2, use_time2vec: bool = True, output_mode: str = 'distribution', discrete_grid_levels: int = 2, use_resblocks_in_head: bool = True, use_resblocks_in_fe: bool = True, use_extra_tokens: bool = False, average_encoder_output: bool = False, norm_encoder_prev: bool = False, univariate: bool = False)
Bases: object
Transformer model configuration
- Parameters:
d_model (int) – model embedding size
n_heads (int) – number of attention heads
num_layers (int) – number of layers
head_layers (int) – number of layers in the head
dim_feedforward (int) – feedforward dimension
max_seq_len (int) – maximum sequence length
dropout (float) – dropout rate
dtype (jnp.dtype) – data type
kernel_init (Callable) – kernel initializer
bias_init (Callable) – bias initializer
deterministic (bool) – whether the model is deterministic
flatten_encoder_output (bool) – whether to flatten the encoder output
fe_blocks (int) – number of feature extractor blocks
use_time2vec (bool) – whether to use time2vec
output_mode (str) – output mode
discrete_grid_levels (int) – number of discrete grid levels
use_resblocks_in_head (bool) – whether to use residual blocks in the head
use_resblocks_in_fe (bool) – whether to use residual blocks in the feature extractor
use_extra_tokens (bool) – whether to use extra tokens
average_encoder_output (bool) – whether to average the encoder output
norm_encoder_prev (bool) – whether to apply the encoder normalization before the attention instead of at the end
univariate (bool) – whether the input is univariate
- replace(**updates)
Returns a new object replacing the specified fields with new values.
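The `replace` method follows the same copy-with-changes pattern as `dataclasses.replace` in the standard library. The sketch below demonstrates that pattern on a hypothetical `TinyConfig` stand-in that borrows two of the documented field names; it does not import jaxer.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TinyConfig:
    """Hypothetical stand-in with two of the documented field names."""
    d_model: int = 512
    dropout: float = 0.0


train_cfg = TinyConfig(dropout=0.1)
eval_cfg = replace(train_cfg, dropout=0.0)  # new object; train_cfg is unchanged
```

This is convenient for deriving an evaluation configuration from a training one, for example by switching off dropout, without mutating the original.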