reagent.model_managers package

Subpackages

Submodules

reagent.model_managers.actor_critic_base module

class reagent.model_managers.actor_critic_base.ActorCriticBase(state_preprocessing_options: Optional[reagent.workflow.types.PreprocessingOptions] = None, action_preprocessing_options: Optional[reagent.workflow.types.PreprocessingOptions] = None, action_feature_override: Optional[str] = None, state_feature_config_provider: reagent.workflow.types.ModelFeatureConfigProvider__Union = <factory>, action_float_features: List[Tuple[int, str]] = <factory>, reader_options: Optional[reagent.workflow.types.ReaderOptions] = None, eval_parameters: reagent.core.parameters.EvaluationParameters = <factory>, save_critic_bool: bool = True)

Bases: reagent.model_managers.model_manager.ModelManager

property action_feature_config: reagent.core.types.ModelFeatureConfig
action_feature_override: Optional[str] = None
action_float_features: List[Tuple[int, str]]
action_preprocessing_options: Optional[reagent.workflow.types.PreprocessingOptions] = None
create_policy(trainer_module: reagent.training.reagent_lightning_module.ReAgentLightningModule, serving: bool = False, normalization_data_map: Optional[Dict[str, reagent.core.parameters.NormalizationData]] = None) reagent.gym.policies.policy.Policy

Create an online actor-critic policy.

eval_parameters: reagent.core.parameters.EvaluationParameters
get_action_preprocessing_options() reagent.workflow.types.PreprocessingOptions
get_data_module(*, input_table_spec: Optional[reagent.workflow.types.TableSpec] = None, reward_options: Optional[reagent.workflow.types.RewardOptions] = None, reader_options: Optional[reagent.workflow.types.ReaderOptions] = None, setup_data: Optional[Dict[str, bytes]] = None, saved_setup_data: Optional[Dict[str, bytes]] = None, resource_options: Optional[reagent.workflow.types.ResourceOptions] = None) Optional[reagent.data.reagent_data_module.ReAgentDataModule]

Return the data module. If this is not None, then run_feature_identification & query_data will not be run.

get_reporter()
get_state_preprocessing_options() reagent.workflow.types.PreprocessingOptions
reader_options: Optional[reagent.workflow.types.ReaderOptions] = None
save_critic_bool: bool = True
property state_feature_config: reagent.core.types.ModelFeatureConfig
state_feature_config_provider: reagent.workflow.types.ModelFeatureConfigProvider__Union
state_preprocessing_options: Optional[reagent.workflow.types.PreprocessingOptions] = None
class reagent.model_managers.actor_critic_base.ActorCriticDataModule(*args: Any, **kwargs: Any)

Bases: reagent.data.manual_data_module.ManualDataModule

build_batch_preprocessor() reagent.preprocessing.batch_preprocessor.BatchPreprocessor
query_data(input_table_spec: reagent.workflow.types.TableSpec, sample_range: Optional[Tuple[float, float]], reward_options: reagent.workflow.types.RewardOptions, data_fetcher: reagent.data.data_fetcher.DataFetcher) reagent.workflow.types.Dataset

Massage the input table into the format expected by the trainer.

run_feature_identification(input_table_spec: reagent.workflow.types.TableSpec) Dict[str, reagent.core.parameters.NormalizationData]

Derive preprocessing parameters from data.

property should_generate_eval_dataset: bool
class reagent.model_managers.actor_critic_base.ActorPolicyWrapper(actor_network)

Bases: reagent.gym.policies.policy.Policy

The actor network’s forward function serves as this policy’s act().

act(obs: reagent.core.types.FeatureData, possible_actions_mask: Optional[torch.Tensor] = None) reagent.core.types.ActorOutput

Performs the composition described above. These are the actions that are put into the replay buffer, not necessarily the actions taken in the environment!
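A minimal usage sketch (not from the source): actor_network is a placeholder for a trained actor model, and it is assumed here that FeatureData can be constructed directly from a float_features tensor.

    import torch

    from reagent.core.types import FeatureData
    from reagent.model_managers.actor_critic_base import ActorPolicyWrapper

    # actor_network is assumed to be a trained actor model whose forward pass
    # produces an ActorOutput; the wrapper simply delegates act() to it.
    policy = ActorPolicyWrapper(actor_network)
    obs = FeatureData(float_features=torch.randn(1, 4))  # batch of one state
    actor_output = policy.act(obs)  # what gets stored in the replay buffer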

reagent.model_managers.discrete_dqn_base module

class reagent.model_managers.discrete_dqn_base.DiscreteDQNBase(target_action_distribution: Optional[List[float]] = None, state_feature_config_provider: reagent.workflow.types.ModelFeatureConfigProvider__Union = <factory>, preprocessing_options: Optional[reagent.workflow.types.PreprocessingOptions] = None, reader_options: Optional[reagent.workflow.types.ReaderOptions] = None, eval_parameters: reagent.core.parameters.EvaluationParameters = <factory>)

Bases: reagent.model_managers.model_manager.ModelManager

abstract property action_names: List[str]
create_policy(trainer_module: reagent.training.reagent_lightning_module.ReAgentLightningModule, serving: bool = False, normalization_data_map: Optional[Dict[str, reagent.core.parameters.NormalizationData]] = None) reagent.gym.policies.policy.Policy

Create an online DiscreteDQN Policy.

eval_parameters: reagent.core.parameters.EvaluationParameters
get_data_module(*, input_table_spec: Optional[reagent.workflow.types.TableSpec] = None, reward_options: Optional[reagent.workflow.types.RewardOptions] = None, reader_options: Optional[reagent.workflow.types.ReaderOptions] = None, setup_data: Optional[Dict[str, bytes]] = None, saved_setup_data: Optional[Dict[str, bytes]] = None, resource_options: Optional[reagent.workflow.types.ResourceOptions] = None) Optional[reagent.data.reagent_data_module.ReAgentDataModule]

Return the data module. If this is not None, then run_feature_identification & query_data will not be run.

get_reporter()
get_state_preprocessing_options() reagent.workflow.types.PreprocessingOptions
property multi_steps: Optional[int]
preprocessing_options: Optional[reagent.workflow.types.PreprocessingOptions] = None
reader_options: Optional[reagent.workflow.types.ReaderOptions] = None
abstract property rl_parameters: reagent.core.parameters.RLParameters
property state_feature_config: reagent.core.types.ModelFeatureConfig
state_feature_config_provider: reagent.workflow.types.ModelFeatureConfigProvider__Union
target_action_distribution: Optional[List[float]] = None
class reagent.model_managers.discrete_dqn_base.DiscreteDqnDataModule(*args: Any, **kwargs: Any)

Bases: reagent.data.manual_data_module.ManualDataModule

build_batch_preprocessor() reagent.preprocessing.batch_preprocessor.BatchPreprocessor
query_data(input_table_spec: reagent.workflow.types.TableSpec, sample_range: Optional[Tuple[float, float]], reward_options: reagent.workflow.types.RewardOptions, data_fetcher: reagent.data.data_fetcher.DataFetcher) reagent.workflow.types.Dataset

Massage the input table into the format expected by the trainer.

run_feature_identification(input_table_spec: reagent.workflow.types.TableSpec) Dict[str, reagent.core.parameters.NormalizationData]

Derive preprocessing parameters from data.

property should_generate_eval_dataset: bool

reagent.model_managers.model_manager module

class reagent.model_managers.model_manager.ModelManager

Bases: object

ModelManager manages how to train models.

Each type of model can have its own config type, implemented as the config_type() class method. The __init__() of the concrete class must take this type.

To integrate a training algorithm into the standard training workflow, you need:

1. build_trainer(): builds the ReAgentLightningModule

2. get_data_module(): defines how to create the data module for this algorithm

3. build_serving_modules(): creates the TorchScript modules for serving

4. get_reporter(): returns the reporter that collects training/evaluation metrics

5. create_policy(): (optional) creates the Policy object used to interact with Gym
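A minimal sketch (not from the source) of what such an integration could look like; MyAlgorithm and its stub bodies are hypothetical, and only two of the hooks are shown:

    from typing import Dict, Optional

    from reagent.core.parameters import NormalizationData
    from reagent.model_managers.model_manager import ModelManager
    from reagent.training.reagent_lightning_module import ReAgentLightningModule
    from reagent.workflow.types import RewardOptions


    class MyAlgorithm(ModelManager):
        """Hypothetical manager illustrating the hooks listed above."""

        def build_trainer(
            self,
            normalization_data_map: Dict[str, NormalizationData],
            use_gpu: bool,
            reward_options: Optional[RewardOptions] = None,
        ) -> ReAgentLightningModule:
            # Hook 1: build and return the LightningModule that trains this
            # algorithm, using normalization data derived from the input table.
            raise NotImplementedError

        def get_reporter(self):
            # Hook 4: return the reporter that collects training/evaluation metrics.
            raise NotImplementedError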

build_serving_module(trainer_module: reagent.training.reagent_lightning_module.ReAgentLightningModule, normalization_data_map: Dict[str, reagent.core.parameters.NormalizationData]) torch.nn.modules.module.Module

Optionally, implement this method if you have only one model to serve.

build_serving_modules(trainer_module: reagent.training.reagent_lightning_module.ReAgentLightningModule, normalization_data_map: Dict[str, reagent.core.parameters.NormalizationData]) Dict[str, torch.nn.modules.module.Module]

Returns the TorchScript modules for serving in production.

abstract build_trainer(normalization_data_map: Dict[str, reagent.core.parameters.NormalizationData], use_gpu: bool, reward_options: Optional[reagent.workflow.types.RewardOptions] = None) reagent.training.reagent_lightning_module.ReAgentLightningModule

Implement this to build the trainer, given the config.

TODO: this function should return the ReAgentLightningModule and the dictionary of modules created

create_policy(trainer_module: reagent.training.reagent_lightning_module.ReAgentLightningModule, serving: bool = False, normalization_data_map: Optional[Dict[str, reagent.core.parameters.NormalizationData]] = None)
get_data_module(*, input_table_spec: Optional[reagent.workflow.types.TableSpec] = None, reward_options: Optional[reagent.workflow.types.RewardOptions] = None, setup_data: Optional[Dict[str, bytes]] = None, saved_setup_data: Optional[Dict[str, bytes]] = None, reader_options: Optional[reagent.workflow.types.ReaderOptions] = None, resource_options: Optional[reagent.workflow.types.ResourceOptions] = None) Optional[reagent.data.reagent_data_module.ReAgentDataModule]

Return the data module. If this is not None, then run_feature_identification & query_data will not be run.

abstract get_reporter() reagent.reporting.reporter_base.ReporterBase
serving_module_names() List[str]

Returns the keys that build_serving_modules() would return. This method is required because entity IDs for these serving modules must be reserved before training starts.
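As an illustrative consistency check (not from the source; manager, trainer_module, and normalization_data_map are placeholders built elsewhere), the names reserved up front should match the keys of the modules eventually built for serving:

    names = manager.serving_module_names()
    modules = manager.build_serving_modules(trainer_module, normalization_data_map)
    assert set(names) == set(modules.keys())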

train(trainer_module: reagent.training.reagent_lightning_module.ReAgentLightningModule, train_dataset: Optional[reagent.workflow.types.Dataset], eval_dataset: Optional[reagent.workflow.types.Dataset], test_dataset: Optional[reagent.workflow.types.Dataset], data_module: Optional[reagent.data.reagent_data_module.ReAgentDataModule], num_epochs: int, reader_options: reagent.workflow.types.ReaderOptions, resource_options: reagent.workflow.types.ResourceOptions, checkpoint_path: Optional[str] = None) Tuple[reagent.workflow.types.RLTrainingOutput, pytorch_lightning.trainer.trainer.Trainer]

Train the model.

Returns a partially filled RLTrainingOutput. The only field intentionally left unfilled is output_path.

Parameters
  • train_dataset/eval_dataset/test_dataset – the training, evaluation, and test datasets

  • data_module – [PyTorch Lightning only] a Lightning data module that replaces the use of the train/eval datasets

  • num_epochs – number of training epochs

  • reader_options – options for the data reader

  • resource_options – options for training resources (currently only used to set num_nodes in the PyTorch Lightning trainer)
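A rough usage sketch (assumed, not from the source): manager is a concrete ModelManager, trainer_module comes from build_trainer(), data_module comes from get_data_module(), and default-constructed reader/resource options are assumed to be valid here.

    from reagent.workflow.types import ReaderOptions, ResourceOptions

    training_output, pl_trainer = manager.train(
        trainer_module=trainer_module,
        train_dataset=None,        # unused when a data module is supplied
        eval_dataset=None,
        test_dataset=None,
        data_module=data_module,   # Lightning data module replacing the datasets
        num_epochs=10,
        reader_options=ReaderOptions(),
        resource_options=ResourceOptions(),
        checkpoint_path=None,
    )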

reagent.model_managers.parametric_dqn_base module

class reagent.model_managers.parametric_dqn_base.ParametricDQNBase(state_preprocessing_options: Optional[reagent.workflow.types.PreprocessingOptions] = None, action_preprocessing_options: Optional[reagent.workflow.types.PreprocessingOptions] = None, state_float_features: Optional[List[Tuple[int, str]]] = None, action_float_features: Optional[List[Tuple[int, str]]] = None, reader_options: Optional[reagent.workflow.types.ReaderOptions] = None, eval_parameters: reagent.core.parameters.EvaluationParameters = <factory>)

Bases: reagent.model_managers.model_manager.ModelManager

property action_feature_config: reagent.core.types.ModelFeatureConfig
action_float_features: Optional[List[Tuple[int, str]]] = None
action_preprocessing_options: Optional[reagent.workflow.types.PreprocessingOptions] = None
create_policy(trainer_module: reagent.training.reagent_lightning_module.ReAgentLightningModule, serving: bool = False, normalization_data_map: Optional[Dict[str, reagent.core.parameters.NormalizationData]] = None)

Create an online ParametricDQN Policy.

eval_parameters: reagent.core.parameters.EvaluationParameters
reader_options: Optional[reagent.workflow.types.ReaderOptions] = None
property state_feature_config: reagent.core.types.ModelFeatureConfig
state_float_features: Optional[List[Tuple[int, str]]] = None
state_preprocessing_options: Optional[reagent.workflow.types.PreprocessingOptions] = None
class reagent.model_managers.parametric_dqn_base.ParametricDqnDataModule(*args: Any, **kwargs: Any)

Bases: reagent.data.manual_data_module.ManualDataModule

build_batch_preprocessor() reagent.preprocessing.batch_preprocessor.BatchPreprocessor
query_data(input_table_spec: reagent.workflow.types.TableSpec, sample_range: Optional[Tuple[float, float]], reward_options: reagent.workflow.types.RewardOptions, data_fetcher: reagent.data.data_fetcher.DataFetcher) reagent.workflow.types.Dataset

Massage the input table into the format expected by the trainer.

run_feature_identification(input_table_spec: reagent.workflow.types.TableSpec) Dict[str, reagent.core.parameters.NormalizationData]

Derive preprocessing parameters from data.

property should_generate_eval_dataset: bool

reagent.model_managers.slate_q_base module

class reagent.model_managers.slate_q_base.SlateQBase(slate_feature_id: int = 0, slate_score_id: Tuple[int, int] = (0, 0), item_preprocessing_options: Optional[reagent.workflow.types.PreprocessingOptions] = None, state_preprocessing_options: Optional[reagent.workflow.types.PreprocessingOptions] = None, state_float_features: Optional[List[Tuple[int, str]]] = None, item_float_features: Optional[List[Tuple[int, str]]] = None)

Bases: reagent.model_managers.model_manager.ModelManager

create_policy(trainer_module: reagent.training.reagent_lightning_module.ReAgentLightningModule, serving: bool = False, normalization_data_map: Optional[Dict[str, reagent.core.parameters.NormalizationData]] = None)
get_reporter()
property item_feature_config: reagent.core.types.ModelFeatureConfig
item_float_features: Optional[List[Tuple[int, str]]] = None
item_preprocessing_options: Optional[reagent.workflow.types.PreprocessingOptions] = None
slate_feature_id: int = 0
slate_score_id: Tuple[int, int] = (0, 0)
property state_feature_config: reagent.core.types.ModelFeatureConfig
state_float_features: Optional[List[Tuple[int, str]]] = None
state_preprocessing_options: Optional[reagent.workflow.types.PreprocessingOptions] = None

reagent.model_managers.union module

Registers all ModelManagers. They must be imported before the union is filled.

class reagent.model_managers.union.ModelManager__Union(SAC: Optional[reagent.model_managers.actor_critic.sac.SAC] = None, TD3: Optional[reagent.model_managers.actor_critic.td3.TD3] = None, DiscreteC51DQN: Optional[reagent.model_managers.discrete.discrete_c51dqn.DiscreteC51DQN] = None, DiscreteCRR: Optional[reagent.model_managers.discrete.discrete_crr.DiscreteCRR] = None, DiscreteDQN: Optional[reagent.model_managers.discrete.discrete_dqn.DiscreteDQN] = None, DiscreteQRDQN: Optional[reagent.model_managers.discrete.discrete_qrdqn.DiscreteQRDQN] = None, CrossEntropyMethod: Optional[reagent.model_managers.model_based.cross_entropy_method.CrossEntropyMethod] = None, Seq2RewardModel: Optional[reagent.model_managers.model_based.seq2reward_model.Seq2RewardModel] = None, WorldModel: Optional[reagent.model_managers.model_based.world_model.WorldModel] = None, SyntheticReward: Optional[reagent.model_managers.model_based.synthetic_reward.SyntheticReward] = None, ParametricDQN: Optional[reagent.model_managers.parametric.parametric_dqn.ParametricDQN] = None, PPO: Optional[reagent.model_managers.policy_gradient.ppo.PPO] = None, Reinforce: Optional[reagent.model_managers.policy_gradient.reinforce.Reinforce] = None, SlateQ: Optional[reagent.model_managers.ranking.slate_q.SlateQ] = None)

Bases: reagent.core.tagged_union.TaggedUnion

CrossEntropyMethod: Optional[reagent.model_managers.model_based.cross_entropy_method.CrossEntropyMethod] = None
DiscreteC51DQN: Optional[reagent.model_managers.discrete.discrete_c51dqn.DiscreteC51DQN] = None
DiscreteCRR: Optional[reagent.model_managers.discrete.discrete_crr.DiscreteCRR] = None
DiscreteDQN: Optional[reagent.model_managers.discrete.discrete_dqn.DiscreteDQN] = None
DiscreteQRDQN: Optional[reagent.model_managers.discrete.discrete_qrdqn.DiscreteQRDQN] = None
PPO: Optional[reagent.model_managers.policy_gradient.ppo.PPO] = None
ParametricDQN: Optional[reagent.model_managers.parametric.parametric_dqn.ParametricDQN] = None
Reinforce: Optional[reagent.model_managers.policy_gradient.reinforce.Reinforce] = None
SAC: Optional[reagent.model_managers.actor_critic.sac.SAC] = None
Seq2RewardModel: Optional[reagent.model_managers.model_based.seq2reward_model.Seq2RewardModel] = None
SlateQ: Optional[reagent.model_managers.ranking.slate_q.SlateQ] = None
SyntheticReward: Optional[reagent.model_managers.model_based.synthetic_reward.SyntheticReward] = None
TD3: Optional[reagent.model_managers.actor_critic.td3.TD3] = None
WorldModel: Optional[reagent.model_managers.model_based.world_model.WorldModel] = None
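A minimal sketch (not from the source) of selecting one manager through the tagged union. DiscreteDQN is shown with default construction only (real configs would set its fields), and it is assumed that TaggedUnion exposes the selected member through a value property.

    from reagent.model_managers.discrete.discrete_dqn import DiscreteDQN
    from reagent.model_managers.union import ModelManager__Union

    # Exactly one member of the union is set; all other members stay None.
    manager_union = ModelManager__Union(DiscreteDQN=DiscreteDQN())
    manager = manager_union.value  # assumed accessor for the selected ModelManager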

reagent.model_managers.world_model_base module

class reagent.model_managers.world_model_base.WorldModelBase(reward_boost: Optional[Dict[str, float]] = None)

Bases: reagent.model_managers.model_manager.ModelManager

reward_boost: Optional[Dict[str, float]] = None
class reagent.model_managers.world_model_base.WorldModelDataModule(*args: Any, **kwargs: Any)

Bases: reagent.data.manual_data_module.ManualDataModule

build_batch_preprocessor() reagent.preprocessing.batch_preprocessor.BatchPreprocessor
query_data(input_table_spec: reagent.workflow.types.TableSpec, sample_range: Optional[Tuple[float, float]], reward_options: reagent.workflow.types.RewardOptions, data_fetcher: reagent.data.data_fetcher.DataFetcher) reagent.workflow.types.Dataset

Massage the input table into the format expected by the trainer.

run_feature_identification(input_table_spec: reagent.workflow.types.TableSpec) Dict[str, reagent.core.parameters.NormalizationData]

Derive preprocessing parameters from data.

property should_generate_eval_dataset: bool

Module contents