reagent.gym.preprocessors package

Submodules

reagent.gym.preprocessors.default_preprocessors module

Get default preprocessors for training time.

class reagent.gym.preprocessors.default_preprocessors.RecsimObsPreprocessor(*, num_docs: int, discrete_keys: List[Tuple[str, int]], box_keys: List[Tuple[str, int]])

Bases: object

classmethod create_from_env(env: gym.core.Env, **kwargs)

reagent.gym.preprocessors.replay_buffer_inserters module

class reagent.gym.preprocessors.replay_buffer_inserters.BasicReplayBufferInserter

Bases: object

class reagent.gym.preprocessors.replay_buffer_inserters.RecSimReplayBufferInserter(*, num_docs: int, num_responses: int, discrete_keys: List[str], box_keys: List[str], response_discrete_keys: List[Tuple[str, int]], response_box_keys: List[Tuple[str, Tuple[int]]], augmentation_discrete_keys: List[str], augmentation_box_keys: List[str])

Bases: object

classmethod create_for_env(env: gym.core.Env)

reagent.gym.preprocessors.replay_buffer_inserters.make_replay_buffer_inserter(env: gym.core.Env) → Callable[[reagent.replay_memory.circular_replay_buffer.ReplayBuffer, reagent.gym.types.Transition], None]

reagent.gym.preprocessors.trainer_preprocessor module

Get default preprocessors for training time.

class reagent.gym.preprocessors.trainer_preprocessor.DiscreteDqnInputMaker(num_actions: int, trainer_preprocessor=None)

Bases: object

classmethod create_for_env(env: gym.core.Env)

class reagent.gym.preprocessors.trainer_preprocessor.MemoryNetworkInputMaker(num_actions: Optional[int] = None)

Bases: object

classmethod create_for_env(env: gym.core.Env)

class reagent.gym.preprocessors.trainer_preprocessor.ParametricDqnInputMaker(num_actions: int)

Bases: object

classmethod create_for_env(env: gym.core.Env)

class reagent.gym.preprocessors.trainer_preprocessor.PolicyGradientInputMaker(num_actions: Optional[int] = None, recsim_obs: bool = False)

Bases: object

classmethod create_for_env(env: gym.core.Env)

class reagent.gym.preprocessors.trainer_preprocessor.PolicyNetworkInputMaker(action_low: numpy.ndarray, action_high: numpy.ndarray)

Bases: object

classmethod create_for_env(env: gym.core.Env)

class reagent.gym.preprocessors.trainer_preprocessor.SlateQInputMaker

Bases: object

classmethod create_for_env(env: gym.core.Env)

reagent.gym.preprocessors.trainer_preprocessor.get_possible_actions_for_gym(batch_size: int, num_actions: int) → reagent.core.types.FeatureData

tiled_actions should have shape (batch_size * num_actions, num_actions). For every i in [0, batch_size), tiled_actions[i*num_actions:(i+1)*num_actions] should equal I[num_actions], where I[n] is the n x n identity matrix. NOTE: this holds only when a discrete action is converted to a parametric action via one-hot encoding.
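The tiling described for get_possible_actions_for_gym can be sketched without any dependencies. This is only an illustration of the layout, not the library's implementation: the helper name tiled_identity is hypothetical, and it returns plain nested lists, whereas the documented function returns a reagent.core.types.FeatureData.

```python
from typing import List


def tiled_identity(batch_size: int, num_actions: int) -> List[List[float]]:
    """Stack batch_size copies of the num_actions x num_actions identity matrix.

    Hypothetical helper for illustration only; it mirrors the layout of
    tiled_actions but uses nested lists instead of tensors.
    """
    identity = [
        [1.0 if row == col else 0.0 for col in range(num_actions)]
        for row in range(num_actions)
    ]
    # Copy each row so that the returned rows are independent lists.
    return [list(row) for _ in range(batch_size) for row in identity]


tiled_actions = tiled_identity(batch_size=2, num_actions=3)
for row in tiled_actions:
    print(row)
```

Each consecutive block of num_actions rows is one identity matrix, so row i*num_actions + j is the one-hot encoding of action j for batch element i.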

reagent.gym.preprocessors.trainer_preprocessor.make_replay_buffer_trainer_preprocessor(trainer: reagent.training.reagent_lightning_module.ReAgentLightningModule, device: torch.device, env: gym.core.Env)

reagent.gym.preprocessors.trainer_preprocessor.make_trainer_preprocessor(trainer: reagent.training.reagent_lightning_module.ReAgentLightningModule, device: torch.device, env: gym.core.Env, maker_map: Dict)

reagent.gym.preprocessors.trainer_preprocessor.make_trainer_preprocessor_online(trainer: reagent.training.reagent_lightning_module.ReAgentLightningModule, device: torch.device, env: gym.core.Env)

reagent.gym.preprocessors.trainer_preprocessor.one_hot_actions(num_actions: int, action: torch.Tensor, next_action: torch.Tensor, terminal: torch.Tensor)

One-hot encode actions and non-terminal next actions. Input shape is (batch_size, 1); output shape is (batch_size, num_actions).
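The shapes described for one_hot_actions can be illustrated with a small dependency-free sketch. The function name one_hot_actions_sketch is hypothetical, and it works on nested lists rather than torch.Tensor. It also assumes that the next action of a terminal transition is encoded as an all-zero row, which the description above does not state explicitly.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def one_hot_actions_sketch(
    num_actions: int,
    action: List[List[int]],
    next_action: List[List[int]],
    terminal: List[List[int]],
) -> Tuple[Matrix, Matrix]:
    """One-hot encode (batch_size, 1) index lists into (batch_size, num_actions).

    Hypothetical illustration only. Assumption: rows of next_action that
    belong to terminal transitions become all zeros.
    """

    def one_hot(index: int) -> List[float]:
        return [1.0 if i == index else 0.0 for i in range(num_actions)]

    encoded = [one_hot(a[0]) for a in action]
    encoded_next = [
        [0.0] * num_actions if t[0] else one_hot(na[0])
        for na, t in zip(next_action, terminal)
    ]
    return encoded, encoded_next


# Two transitions with three possible actions; the second one is terminal.
encoded, encoded_next = one_hot_actions_sketch(
    num_actions=3,
    action=[[0], [2]],
    next_action=[[1], [0]],
    terminal=[[0], [1]],
)
print(encoded)
print(encoded_next)
```

Under that assumption the second row of encoded_next is all zeros, because the stored next action of a terminal transition carries no information.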

Module contents

reagent.gym.preprocessors.make_replay_buffer_inserter(env: gym.core.Env) → Callable[[reagent.replay_memory.circular_replay_buffer.ReplayBuffer, reagent.gym.types.Transition], None]

reagent.gym.preprocessors.make_replay_buffer_trainer_preprocessor(trainer: reagent.training.reagent_lightning_module.ReAgentLightningModule, device: torch.device, env: gym.core.Env)

reagent.gym.preprocessors.make_trainer_preprocessor_online(trainer: reagent.training.reagent_lightning_module.ReAgentLightningModule, device: torch.device, env: gym.core.Env)