ml.rl.training.world_model package

Submodules

ml.rl.training.world_model.mdnrnn_trainer module

class ml.rl.training.world_model.mdnrnn_trainer.MDNRNNTrainer(mdnrnn_network: ml.rl.models.world_model.MemoryNetwork, params: ml.rl.parameters.MDNRNNParameters, cum_loss_hist: int = 100)

Bases: object

Trainer for MDN-RNN

get_loss(training_batch: ml.rl.types.PreprocessedTrainingBatch, state_dim: Optional[int] = None, batch_first: bool = False)
Compute losses:

GMMLoss(next_state, GMMPredicted) / (STATE_DIM + 2) + MSE(reward, predicted_reward) + BCE(not_terminal, logit_not_terminal)

The STATE_DIM + 2 factor counteracts the fact that the GMMLoss scales approximately linearly with STATE_DIM, the feature size of states. All losses are averaged over both the batch and the sequence dimensions (the first two dimensions).
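To make the scaling concrete, below is a minimal sketch of how the combined loss could be assembled. The function and tensor names, and the assumption that the GMM negative log-likelihood arrives pre-averaged, are illustrative and not this module's actual internals:

    import torch.nn.functional as F

    def combined_loss(gmm_nll, reward, predicted_reward,
                      not_terminal, logit_not_terminal, state_dim):
        # gmm_nll: GMM negative log-likelihood of next_state, already averaged
        # over the batch and sequence dimensions (hypothetical precomputed input).
        mse = F.mse_loss(predicted_reward, reward)
        bce = F.binary_cross_entropy_with_logits(
            logit_not_terminal, not_terminal.float()
        )
        # Dividing by STATE_DIM + 2 keeps the GMM term on roughly the same scale
        # as the scalar MSE and BCE terms, since it grows about linearly in STATE_DIM.
        loss = gmm_nll / (state_dim + 2) + mse + bce
        return {"gmm": gmm_nll, "mse": mse, "bce": bce, "loss": loss}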

Parameters
  • training_batch – training_batch.learning_input has these fields:
      - state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
      - action: (BATCH_SIZE, SEQ_LEN, ACTION_DIM) torch tensor
      - reward: (BATCH_SIZE, SEQ_LEN) torch tensor
      - not_terminal: (BATCH_SIZE, SEQ_LEN) torch tensor
      - next_state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
    The first two dimensions may be swapped depending on batch_first.

  • state_dim – the dimension of states. If provided, it is used to normalize the GMM loss.

  • batch_first – whether the data’s first dimension is the batch dimension. If False, the first two dimensions of state, action, reward, not_terminal, and next_state are SEQ_LEN and BATCH_SIZE.

Returns

a dictionary of losses, containing the GMM, MSE, and BCE losses and the averaged total loss.
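For example, with batch_first=False the expected tensor layout is sequence-first. The shapes below are illustrative; constructing a PreprocessedTrainingBatch is library-specific and left as a comment:

    import torch

    SEQ_LEN, BATCH_SIZE, STATE_DIM, ACTION_DIM = 10, 32, 4, 2

    # Sequence-first layout, matching batch_first=False (the default):
    state = torch.randn(SEQ_LEN, BATCH_SIZE, STATE_DIM)
    action = torch.randn(SEQ_LEN, BATCH_SIZE, ACTION_DIM)
    reward = torch.randn(SEQ_LEN, BATCH_SIZE)
    not_terminal = torch.ones(SEQ_LEN, BATCH_SIZE)
    next_state = torch.randn(SEQ_LEN, BATCH_SIZE, STATE_DIM)

    # Assuming `training_batch` wraps these tensors in a PreprocessedTrainingBatch:
    # losses = trainer.get_loss(training_batch, state_dim=STATE_DIM, batch_first=False)
    # losses["loss"].backward()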

train(training_batch: ml.rl.types.PreprocessedTrainingBatch, batch_first: bool = False)
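A minimal training-loop sketch, assuming an already constructed MemoryNetwork and MDNRNNParameters; the variable names and the batch iterator are placeholders, not part of this API:

    from ml.rl.training.world_model.mdnrnn_trainer import MDNRNNTrainer

    # memory_network: ml.rl.models.world_model.MemoryNetwork instance
    # mdnrnn_params: ml.rl.parameters.MDNRNNParameters instance
    trainer = MDNRNNTrainer(mdnrnn_network=memory_network, params=mdnrnn_params)
    for training_batch in batch_iterator:  # yields PreprocessedTrainingBatch
        trainer.train(training_batch, batch_first=False)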

Module contents