MDNRNNTrainer(mdnrnn_network: ml.rl.models.world_model.MemoryNetwork, params: ml.rl.parameters.MDNRNNParameters, cum_loss_hist: int = 100)¶
Trainer for the MDN-RNN (Mixture Density Network RNN) world model.
get_loss(training_batch: ml.rl.types.PreprocessedTrainingBatch, state_dim: Optional[int] = None, batch_first: bool = False)¶
- Compute losses:
GMMLoss(next_state, GMMPredicted) / (STATE_DIM + 2) + MSE(reward, predicted_reward) + BCE(not_terminal, logit_not_terminal)
- The STATE_DIM + 2 factor counteracts the fact that the GMM loss scales approximately linearly with STATE_DIM, the feature size of states. All losses are averaged over both the batch and sequence dimensions (the first two dimensions).
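The combined loss above can be sketched in PyTorch as follows. This is a minimal illustration, not the library's exact implementation: `gmm_loss` here is a generic negative log-likelihood under a diagonal-Gaussian mixture, and all function and argument names are illustrative.

```python
import torch
import torch.nn.functional as F


def gmm_loss(next_state, mus, sigmas, logpi):
    """Negative log-likelihood of next_state under a diagonal Gaussian mixture.

    next_state: (SEQ_LEN, BATCH_SIZE, STATE_DIM)
    mus, sigmas: (SEQ_LEN, BATCH_SIZE, N_GAUSS, STATE_DIM)
    logpi: (SEQ_LEN, BATCH_SIZE, N_GAUSS) log mixture weights
    """
    dist = torch.distributions.Normal(mus, sigmas)
    # Log-prob of the target under each mixture component, summed over STATE_DIM.
    log_probs = dist.log_prob(next_state.unsqueeze(-2)).sum(dim=-1)
    # Log-sum-exp over components, then average over sequence and batch dims.
    return -torch.logsumexp(logpi + log_probs, dim=-1).mean()


def combined_loss(next_state, mus, sigmas, logpi,
                  reward, pred_reward,
                  not_terminal, logit_not_terminal,
                  state_dim):
    gmm = gmm_loss(next_state, mus, sigmas, logpi)
    mse = F.mse_loss(pred_reward, reward)
    bce = F.binary_cross_entropy_with_logits(logit_not_terminal, not_terminal)
    # GMMLoss / (STATE_DIM + 2) + MSE(reward) + BCE(not_terminal)
    loss = gmm / (state_dim + 2) + mse + bce
    return {"gmm": gmm, "mse": mse, "bce": bce, "loss": loss}
```

The `(state_dim + 2)` divisor keeps the GMM term on roughly the same scale as the scalar reward and termination losses, since the GMM log-likelihood sums over all STATE_DIM state features.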
training_batch – training_batch.learning_input has these fields:
- state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
- action: (BATCH_SIZE, SEQ_LEN, ACTION_DIM) torch tensor
- reward: (BATCH_SIZE, SEQ_LEN) torch tensor
- not_terminal: (BATCH_SIZE, SEQ_LEN) torch tensor
- next_state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
The first two dimensions may be swapped depending on batch_first.
state_dim – the dimension of states. If provided, it is used to normalize the GMM loss.
batch_first – whether the data's first dimension is the batch size. If False, the first two dimensions of state, action, reward, not_terminal, and next_state are SEQ_LEN and BATCH_SIZE.
Returns a dictionary of losses containing the GMM, MSE, BCE, and averaged total losses.
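The batch_first convention can be illustrated with a small sketch; the tensor names are illustrative:

```python
import torch

SEQ_LEN, BATCH_SIZE, STATE_DIM = 10, 4, 6

# batch_first=True layout: (BATCH_SIZE, SEQ_LEN, STATE_DIM)
state_bf = torch.randn(BATCH_SIZE, SEQ_LEN, STATE_DIM)

# batch_first=False layout (the default): (SEQ_LEN, BATCH_SIZE, STATE_DIM),
# i.e. the first two dimensions are swapped
state_sf = state_bf.transpose(0, 1)

print(state_bf.shape)  # torch.Size([4, 10, 6])
print(state_sf.shape)  # torch.Size([10, 4, 6])
```

Sequence-first layout is the usual default for PyTorch RNN modules, which is presumably why it is the default here as well.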
train(training_batch: ml.rl.types.PreprocessedTrainingBatch, batch_first: bool = False)¶