Training

This section describes the optimization process of the model. I want to go a bit deeper into training with jax and flax, because it differs in several ways from the usual torch workflow. I assume you are familiar with deep learning, so this is not a tutorial.

First of all, we need an optimization library for jax: optax. It provides learning rate schedules as well as optimizers such as adamw. In the configuration file you can choose between a linear learning rate schedule and a warmup cosine schedule; the latter is widely used when training transformer models. Recent optax versions also ship this schedule as optax.warmup_cosine_decay_schedule, but it is easy to build by joining a linear warmup schedule with a cosine decay schedule:

import optax


def create_warmup_cosine_schedule(learning_rate: float,
                                  warmup_epochs: int,
                                  num_epochs: int,
                                  steps_per_epoch: int) -> optax.Schedule:
    """Creates a linear warmup + cosine decay learning rate schedule

    :param learning_rate: peak learning rate, reached at the end of warmup
    :type learning_rate: float

    :param warmup_epochs: number of warmup epochs
    :type warmup_epochs: int

    :param num_epochs: total number of epochs
    :type num_epochs: int

    :param steps_per_epoch: number of steps per epoch
    :type steps_per_epoch: int

    :return: learning rate schedule
    :rtype: optax.Schedule
    """
    warmup_fn = optax.linear_schedule(
        init_value=0., end_value=learning_rate,
        transition_steps=warmup_epochs * steps_per_epoch)
    cosine_epochs = max(num_epochs - warmup_epochs, 1)
    cosine_fn = optax.cosine_decay_schedule(
        init_value=learning_rate,
        decay_steps=cosine_epochs * steps_per_epoch)
    schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, cosine_fn],
        boundaries=[warmup_epochs * steps_per_epoch])
    return schedule_fn

Although optax also implements several loss functions, I coded my own (the code is very similar). The main difference from torch is that flax does not store the model parameters in the model itself. Instead, they are kept in a separate params pytree, which is passed to the model's apply method. Calling the model (inference) therefore takes two steps:

  1. params = model.init(rng, dummy_input) to initialize the model parameters (dummy_input is any array with the expected input shape)

  2. output = model.apply(params, input) to make a forward pass

In practice, flax provides a TrainState class (flax.training.train_state.TrainState) that keeps the apply function, the model parameters, the optimizer (including its learning rate schedule) and the optimizer state in one place, together with a step counter. I found this really helpful; more details are in the official flax documentation. A TrainState is created with:

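import optax
from flax.training.train_state import TrainState

# learning_rate_scheduler: e.g. the schedule returned by create_warmup_cosine_schedule
optimizer = optax.adamw(learning_rate=learning_rate_scheduler)

train_state = TrainState.create(
    apply_fn=model.apply,
    params=initial_params,  # typically model.init(rng, dummy_input)['params']
    tx=optimizer
)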

Another key difference is how jax handles randomness. jax has no global random state: we start from jax.random.PRNGKey(seed) and split it into a new key and a subkey for every random operation. This is required when the model contains stochastic layers such as dropout. In torch we set the seed once and the global generator handles the rest. jax's approach is more explicit, but it addresses the following issue:

Note

The problem with torch magic PRNG state is that it’s hard to reason about how it’s being used and updated across different threads, processes, and devices, and it’s very easy to screw up when the details of entropy production and consumption are hidden from the end user.

I also found orbax, the checkpointing library recommended by flax, convenient for saving and loading checkpoints. For instance, we can define a checkpoint manager that keeps at most 5 checkpoints; since a checkpoint is only saved when the test metric improves, these are the 5 best models so far:

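import orbax.checkpoint
from flax.training import orbax_utils

self._orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
# max_to_keep=5 keeps the 5 most recently saved checkpoints
options = orbax.checkpoint.CheckpointManagerOptions(create=True, max_to_keep=5)
self._checkpoint_manager = orbax.checkpoint.CheckpointManager(
    str(self._ckpts_dir), self._orbax_checkpointer, options)

# to save the model (only when the test metric improves)
if test_metric < best_test_metric:
    best_test_metric = test_metric
    ckpt = {'model': trained_state}
    save_args = orbax_utils.save_args_from_target(ckpt)
    self._checkpoint_manager.save(epoch, ckpt, save_kwargs={'save_args': save_args})

# to load the model ('model_path' is the directory of one saved checkpoint);
# without a target, the state comes back as nested dicts of arrays
restored_state = self._orbax_checkpointer.restore('model_path')['model']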

To visualize the training process later, I log the train and test metrics to tensorboard, which makes it easy to compare both curves over epochs. There is also an early stopper class that stops training when the test metric has not improved for a given number of consecutive epochs:

from dataclasses import dataclass


@dataclass
class EarlyStopper:
    """Early stopper class

    :param max_epochs: max number of epochs without improvement
    :type max_epochs: int

    :param n_epochs: number of epochs without improvement
    :type n_epochs: int

    :param optim_value: best optimization value
    :type optim_value: float
    """
    max_epochs: int
    n_epochs: int = 0
    optim_value: float = float('inf')  # any first value counts as an improvement

    def __call__(self, optim_value: float) -> bool:
        """ Returns True if the training should stop """
        if optim_value < self.optim_value:
            self.optim_value = optim_value
            self.n_epochs = 0
            return False

        self.n_epochs += 1

        if self.n_epochs >= self.max_epochs:
            return True

        return False
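The sketch below shows how the stopper behaves inside an epoch loop. It repeats a compact copy of the class so it runs on its own, and feeds it a made-up sequence of test losses with a patience of 3: the counter resets whenever the loss improves, and training stops after three consecutive epochs without improvement.

```python
from dataclasses import dataclass


@dataclass
class EarlyStopper:
    # compact copy of the class above so this snippet is self-contained
    max_epochs: int
    n_epochs: int = 0
    optim_value: float = float('inf')

    def __call__(self, optim_value: float) -> bool:
        if optim_value < self.optim_value:
            self.optim_value = optim_value
            self.n_epochs = 0
            return False
        self.n_epochs += 1
        return self.n_epochs >= self.max_epochs


# Hypothetical test losses, one per epoch
test_losses = [0.9, 0.7, 0.72, 0.65, 0.66, 0.67, 0.68, 0.5]
stopper = EarlyStopper(max_epochs=3)
stopped_at = None
for epoch, loss in enumerate(test_losses):
    if stopper(loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.optim_value)
```

Training stops at epoch 6, after 0.66, 0.67 and 0.68 all fail to beat 0.65; the later 0.5 is never seen, which is the usual trade-off of a small patience.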

Configuration

The training configuration is defined in its own dataclass with the following fields:

model_config: ModelConfig  # model configuration (transformer)
log_dir: str  # directory to save logs
experiment_name: str  # experiment name (logs will be saved in log_dir/experiment_name)
num_epochs: int  # number of epochs
learning_rate: float  # initial learning rate
lr_mode: str  # learning rate scheduler mode (linear or cosine)
warmup_epochs: int  # number of warmup epochs
dataset_config: DatasetConfig  # dataset configuration
batch_size: int  # batch size
test_split: float  # test split (between 0 and 1)
test_tickers: List[str]  # tickers to test
seed: int  # initial seed for reproducibility
save_weights: bool  # save weights during training
early_stopper: int  # early stopper patience (number of epochs without improvement)

Metrics

To properly evaluate the model, we need some metrics. Since there are two main approaches, classification and regression, the following table lists the metrics used for each:

Task           | Metric
-------------- | ------
Classification | Class Accuracy (acc_class), Direction Accuracy (acc_dir)
Regression     | Mean Squared Error (mse), Mean Absolute Percentage Error (mape), R2 Score (r2), Mean Absolute Error (mae), Direction Accuracy (acc_dir)

Note

Metrics were initially computed on normalized data, which made it impossible to compare results across normalization methods (mape was the only normalization-independent metric). To allow such comparisons, I denormalize the predictions and compute the metrics on the original price scale. This way, metrics can be compared across models and normalization methods. Even so, absolute-scale metrics such as mse and mae are hard to interpret: an error of 2$ is large when the price is around 1$ but negligible when it is around 20000$.
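The scale problem described in the note can be made concrete. The sketch below implements the regression metrics from the table with jax.numpy in a hypothetical regression_metrics helper (the project's own implementations may differ; for instance, mape is returned here as a fraction rather than a percentage) and evaluates the same 2% relative error on two made-up assets priced around 1$ and 20000$.

```python
import jax
import jax.numpy as jnp


@jax.jit
def regression_metrics(y_pred: jnp.ndarray, y_true: jnp.ndarray) -> dict:
    """Hypothetical helper computing mse, mae, mape (as a fraction) and r2."""
    err = y_pred - y_true
    mse = jnp.mean(err ** 2)
    mae = jnp.mean(jnp.abs(err))
    mape = jnp.mean(jnp.abs(err / y_true))  # multiply by 100 for a percentage
    ss_res = jnp.sum(err ** 2)
    ss_tot = jnp.sum((y_true - jnp.mean(y_true)) ** 2)
    r2 = 1.0 - ss_res / ss_tot
    return {'mse': mse, 'mae': mae, 'mape': mape, 'r2': r2}


# Same 2% relative error on two hypothetical assets priced around 1$ and 20000$
y_true_small = jnp.array([1.00, 1.10, 0.90, 1.05])
y_true_big = 20000.0 * y_true_small
m_small = regression_metrics(1.02 * y_true_small, y_true_small)
m_big = regression_metrics(1.02 * y_true_big, y_true_big)
for name in m_small:
    print(f"{name}: {float(m_small[name]):.4g} vs {float(m_big[name]):.4g}")
```

mae grows by the price ratio (20000) and mse by its square, while mape and r2 stay the same, which is why relative metrics are the ones that compare well across assets.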

I designed a custom relative metric, acc_dir, that measures how often the model predicts the correct direction of the price move. It is motivated by the need for relative comparison metrics (independent of the price magnitude) that also give intuition about the model's ability to ride the price waves. Since it does not depend on the price scale, it can be computed for each of the three time series prediction approaches developed in this work. acc_dir is defined as:

import jax
import jax.numpy as jnp


@jax.jit
def acc_dir(y_pred: jnp.ndarray, y_true: jnp.ndarray, last_price: jnp.ndarray) -> jnp.ndarray:
    """Direction accuracy metric

    :param y_pred: predicted values
    :type y_pred: jnp.ndarray

    :param y_true: true values
    :type y_true: jnp.ndarray

    :param last_price: last close price
    :type last_price: jnp.ndarray

    :return: direction accuracy
    :rtype: jnp.ndarray
    """

    y_true = jnp.sign(y_true[:, 0] - last_price)  # y_true is (batch, 1) and last_price is (batch, )
    y_pred = jnp.sign(y_pred[:, 0] - last_price)

    return jnp.mean(y_true == y_pred)
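A small worked example helps to read the metric: a prediction counts as correct when it lies on the same side of the last close price as the true value. The sketch repeats acc_dir compactly so it runs on its own; the prices are made up.

```python
import jax
import jax.numpy as jnp


@jax.jit
def acc_dir(y_pred: jnp.ndarray, y_true: jnp.ndarray, last_price: jnp.ndarray) -> jnp.ndarray:
    # compact copy of acc_dir above: compare the sign of the move w.r.t. the last close
    true_dir = jnp.sign(y_true[:, 0] - last_price)
    pred_dir = jnp.sign(y_pred[:, 0] - last_price)
    return jnp.mean(true_dir == pred_dir)


last_price = jnp.array([100.0, 100.0, 50.0, 50.0])
y_true = jnp.array([[105.0], [95.0], [55.0], [48.0]])  # up, down, up, down
y_pred = jnp.array([[102.0], [97.0], [45.0], [49.0]])  # up, down, down, down
print(acc_dir(y_pred, y_true, last_price))  # 3 of 4 directions match -> 0.75
```

Note that when both the prediction and the true value equal the last price, both signs are 0 and the sample counts as correct.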