reagent.prediction package

Subpackages

Submodules

reagent.prediction.predictor_wrapper module

reagent.prediction.predictor_wrapper.ActorPredictorUnwrapper

alias of reagent.prediction.predictor_wrapper.OSSSparsePredictorUnwrapper

class reagent.prediction.predictor_wrapper.ActorPredictorWrapper(actor_with_preprocessor: reagent.prediction.predictor_wrapper.ActorWithPreprocessor, state_feature_config: reagent.core.types.ModelFeatureConfig, action_feature_ids: List[int] = [])

Bases: torch.jit._script.ScriptModule

forward: Callable[[...], Any]
training: bool

class reagent.prediction.predictor_wrapper.ActorWithPreprocessor(model: reagent.models.base.ModelBase, state_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, state_feature_config: reagent.core.types.ModelFeatureConfig, action_postprocessor: Optional[reagent.preprocessing.postprocessor.Postprocessor] = None, serve_mean_policy: bool = False)

Bases: reagent.models.base.ModelBase

This is separate from ActorPredictorWrapper so that we can pass typed inputs into the model. This is possible because JIT tracing only records tensor operations. In contrast, JIT scripting needs to compile the code, so it won’t recognize any custom Python type.
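As a minimal sketch of why the split works (the `TinyActor` module below is invented for illustration, not a ReAgent model): `jit.trace` records only the tensor operations executed on an example input, so custom container types can be unpacked in plain Python before the traced module is called, and they never reach a compiler.

```python
import torch
import torch.nn as nn

class TinyActor(nn.Module):
    """Invented stand-in for a ReAgent model: tensor in, tensor out."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)

# jit.trace records only the tensor ops run on the example input, so no
# custom Python types ever reach the compiler -- unlike jit.script, which
# would have to understand every type annotation in the source.
actor = TinyActor()
traced = torch.jit.trace(actor, torch.randn(1, 4))
out = traced(torch.ones(1, 4))
```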

forward(state: reagent.core.types.ServingFeatureData)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
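The difference matters in practice: `Module.__call__` runs any registered hooks around `forward()`, while calling `forward()` directly skips them. A small illustration with an invented `Doubler` module:

```python
import torch
import torch.nn as nn

class Doubler(nn.Module):
    """Invented toy module for this sketch."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 2 * x

calls = []
m = Doubler()
m.register_forward_hook(lambda mod, inp, out: calls.append("hook"))

y1 = m(torch.tensor(3.0))          # __call__ runs the registered hook
y2 = m.forward(torch.tensor(3.0))  # bypasses hooks silently
# Only the first invocation appended to `calls`.
```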

input_prototype()

This function provides the input for ONNX graph tracing.

The return value should be what is expected by forward().
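A hedged sketch of the contract, with an invented `Scorer` module: `input_prototype()` returns a tuple of dummy arguments shaped like a real request, so `model(*model.input_prototype())` exercises the same tensor operations that graph tracing needs to see.

```python
import torch
import torch.nn as nn

class Scorer(nn.Module):
    """Invented module: scores a (value, presence) feature pair."""
    def forward(self, state_with_presence):
        value, presence = state_with_presence
        return (value * presence).sum(dim=1)

    def input_prototype(self):
        # Dummy arguments with the same nesting, shapes, and dtypes as a
        # real request, so tracing exercises every tensor op in forward().
        return ((torch.randn(1, 3), torch.ones(1, 3)),)

m = Scorer()
out = m(*m.input_prototype())          # the prototype must satisfy forward()
traced = torch.jit.trace(m, m.input_prototype())
```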

training: bool

class reagent.prediction.predictor_wrapper.BinaryDifferenceScorerPredictorWrapper(binary_difference_scorer_with_preprocessor: reagent.prediction.predictor_wrapper.BinaryDifferenceScorerWithPreprocessor, state_feature_config: reagent.core.types.ModelFeatureConfig)

Bases: torch.jit._script.ScriptModule

forward: Callable[[...], Any]
training: bool

class reagent.prediction.predictor_wrapper.BinaryDifferenceScorerWithPreprocessor(model: reagent.models.base.ModelBase, state_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, state_feature_config: reagent.core.types.ModelFeatureConfig)

Bases: reagent.models.base.ModelBase

This is separated from DiscreteDqnPredictorWrapper so that we can pass typed inputs into the model. This is possible because JIT tracing only records tensor operations. In contrast, JIT scripting needs to compile the code, so it won’t recognize any custom Python type.

forward(state: reagent.core.types.ServingFeatureData)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

input_prototype()

This function provides the input for ONNX graph tracing.

The return value should be what is expected by forward().

training: bool

class reagent.prediction.predictor_wrapper.CompressModelWithPreprocessor(model: reagent.models.base.ModelBase, state_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, state_feature_config: reagent.core.types.ModelFeatureConfig)

Bases: reagent.prediction.predictor_wrapper.DiscreteDqnWithPreprocessor

forward(state: reagent.core.types.ServingFeatureData)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool

reagent.prediction.predictor_wrapper.DiscreteDqnPredictorUnwrapper

alias of reagent.prediction.predictor_wrapper.OSSSparsePredictorUnwrapper

class reagent.prediction.predictor_wrapper.DiscreteDqnPredictorWrapper(dqn_with_preprocessor: reagent.prediction.predictor_wrapper.DiscreteDqnWithPreprocessor, action_names: List[str], state_feature_config: reagent.core.types.ModelFeatureConfig)

Bases: torch.jit._script.ScriptModule

forward: Callable[[...], Any]
training: bool

class reagent.prediction.predictor_wrapper.DiscreteDqnWithPreprocessor(model: reagent.models.base.ModelBase, state_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, state_feature_config: reagent.core.types.ModelFeatureConfig)

Bases: reagent.models.base.ModelBase

This is separated from DiscreteDqnPredictorWrapper so that we can pass typed inputs into the model. This is possible because JIT tracing only records tensor operations. In contrast, JIT scripting needs to compile the code, so it won’t recognize any custom Python type.

forward(state: reagent.core.types.ServingFeatureData)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

input_prototype()

This function provides the input for ONNX graph tracing.

The return value should be what is expected by forward().

training: bool

class reagent.prediction.predictor_wrapper.LearnVMSlateWithPreprocessor(mlp: torch.nn.modules.module.Module, state_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, candidate_preprocessor: reagent.preprocessing.preprocessor.Preprocessor)

Bases: reagent.models.base.ModelBase

forward(state_vp, candidate_vp)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

input_prototype()

This function provides the input for ONNX graph tracing.

The return value should be what is expected by forward().

training: bool

class reagent.prediction.predictor_wrapper.MDNRNNWithPreprocessor(model: reagent.models.base.ModelBase, state_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, seq_len: int, num_action: int, state_feature_config: Optional[reagent.core.types.ModelFeatureConfig] = None)

Bases: reagent.models.base.ModelBase

forward(state_with_presence: Tuple[torch.Tensor, torch.Tensor], action: torch.Tensor)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

input_prototype()

This function provides the input for ONNX graph tracing.

The return value should be what is expected by forward().

training: bool

class reagent.prediction.predictor_wrapper.OSSPredictorUnwrapper(model: torch.nn.modules.module.Module)

Bases: torch.nn.modules.module.Module

forward(*args, **kwargs) → Tuple[List[str], torch.Tensor]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool

class reagent.prediction.predictor_wrapper.OSSSparsePredictorUnwrapper(model: torch.nn.modules.module.Module)

Bases: torch.nn.modules.module.Module

forward(state_with_presence: Tuple[torch.Tensor, torch.Tensor], state_id_list_features: Dict[int, Tuple[torch.Tensor, torch.Tensor]], state_id_score_list_features: Dict[int, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]) → Tuple[List[str], torch.Tensor]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool

reagent.prediction.predictor_wrapper.ParametricDqnPredictorUnwrapper

alias of reagent.prediction.predictor_wrapper.OSSPredictorUnwrapper

class reagent.prediction.predictor_wrapper.ParametricDqnPredictorWrapper(dqn_with_preprocessor: reagent.prediction.predictor_wrapper.ParametricDqnWithPreprocessor)

Bases: torch.jit._script.ScriptModule

forward: Callable[[...], Any]
training: bool

class reagent.prediction.predictor_wrapper.ParametricDqnWithPreprocessor(model: reagent.models.base.ModelBase, state_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, action_preprocessor: reagent.preprocessing.preprocessor.Preprocessor)

Bases: reagent.models.base.ModelBase

forward(state_with_presence: Tuple[torch.Tensor, torch.Tensor], action_with_presence: Tuple[torch.Tensor, torch.Tensor])

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

input_prototype()

This function provides the input for ONNX graph tracing.

The return value should be what is expected by forward().

training: bool

class reagent.prediction.predictor_wrapper.RankingActorPredictorWrapper(actor_with_preprocessor: reagent.prediction.predictor_wrapper.RankingActorWithPreprocessor, action_feature_ids: List[int])

Bases: torch.jit._script.ScriptModule

forward: Callable[[...], Any]
training: bool

class reagent.prediction.predictor_wrapper.RankingActorWithPreprocessor(model: reagent.models.base.ModelBase, state_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, candidate_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, num_candidates: int, action_postprocessor: Optional[reagent.preprocessing.postprocessor.Postprocessor] = None)

Bases: reagent.models.base.ModelBase

forward(state_with_presence: Tuple[torch.Tensor, torch.Tensor], candidate_with_presence_list: List[Tuple[torch.Tensor, torch.Tensor]])

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

input_prototype()

This function provides the input for ONNX graph tracing.

The return value should be what is expected by forward().

training: bool

class reagent.prediction.predictor_wrapper.Seq2RewardPlanShortSeqWithPreprocessor(model: reagent.models.base.ModelBase, step_model: reagent.models.base.ModelBase, state_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, seq_len: int, num_action: int)

Bases: reagent.prediction.predictor_wrapper.DiscreteDqnWithPreprocessor

forward(state: reagent.core.types.ServingFeatureData)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool

class reagent.prediction.predictor_wrapper.Seq2RewardWithPreprocessor(model: reagent.models.base.ModelBase, state_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, seq_len: int, num_action: int)

Bases: reagent.prediction.predictor_wrapper.DiscreteDqnWithPreprocessor

forward(state: reagent.core.types.ServingFeatureData)

This serving module only takes in the current state. We need to simulate all multi-step action sequences and predict the accumulated reward for each of them. After that, we group the action sequences by their first action and take the maximum reward within each group as the predicted categorical reward for that first action. Return: categorical reward for each first action
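The enumerate-then-max step can be sketched in plain Python under invented names (`reward_fn` stands in for the accumulated-reward model; this is an illustration of the grouping logic, not the serving implementation):

```python
import itertools
import torch

def categorical_reward(reward_fn, num_action: int, seq_len: int) -> torch.Tensor:
    """Enumerate every length-seq_len action sequence, then keep the best
    accumulated reward among sequences sharing the same first action."""
    best = torch.full((num_action,), float("-inf"))
    for seq in itertools.product(range(num_action), repeat=seq_len):
        r = reward_fn(seq)            # stand-in for the seq2reward model
        if r > best[seq[0]]:
            best[seq[0]] = r
    return best

# Toy reward = sum of actions, so the best length-2 sequence starting with
# action a is (a, num_action - 1).
best = categorical_reward(lambda s: float(sum(s)), num_action=3, seq_len=2)
```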

training: bool

class reagent.prediction.predictor_wrapper.Seq2SlatePredictorWrapper(seq2slate_with_preprocessor: reagent.prediction.predictor_wrapper.Seq2SlateWithPreprocessor)

Bases: torch.jit._script.ScriptModule

forward: Callable[[...], Any]
training: bool

class reagent.prediction.predictor_wrapper.Seq2SlateRewardWithPreprocessor(model: reagent.models.seq2slate_reward.Seq2SlateRewardNetBase, state_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, candidate_preprocessor: reagent.preprocessing.preprocessor.Preprocessor)

Bases: reagent.models.base.ModelBase

property candidate_sorted_features: List[int]
forward(state_with_presence: Tuple[torch.Tensor, torch.Tensor], candidate_with_presence: Tuple[torch.Tensor, torch.Tensor])

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

input_prototype()

This function provides the input for ONNX graph tracing.

The return value should be what is expected by forward().

property state_sorted_features: List[int]
training: bool

class reagent.prediction.predictor_wrapper.Seq2SlateWithPreprocessor(model: reagent.models.seq2slate.Seq2SlateTransformerNet, state_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, candidate_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, greedy: bool)

Bases: torch.nn.modules.module.Module

can_be_traced()

Whether this module can be serialized by jit.trace. In production, we find jit.trace may have faster performance than jit.script. The models that can be traced are those without for-loops in inference, since we want to deal with inputs of variable lengths. The models that can’t be traced are those with an iterative decoder, i.e., autoregressive or non-greedy Frechet sort.
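A small sketch of the pitfall that motivates this check (the `CumulativeSum` module is invented for illustration): tracing unrolls a Python for-loop for the one example input it sees, freezing that loop length into the graph.

```python
import torch
import torch.nn as nn

class CumulativeSum(nn.Module):
    """Invented module with a data-dependent Python for-loop."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        total = torch.zeros(())
        for i in range(x.size(0)):  # loop length depends on the input
            total = total + x[i]
        return total

# jit.trace unrolls the loop for the single example it sees, freezing
# three additions into the graph; jit.script would instead compile the
# loop and keep it dynamic, at some cost in serving speed.
traced = torch.jit.trace(CumulativeSum(), torch.ones(3))
```

Calling `traced` on a length-5 input still only sums the first three elements, which is why loop-bearing (iterative-decoder) models must be scripted instead.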

forward(state_with_presence: Tuple[torch.Tensor, torch.Tensor], candidate_with_presence: Tuple[torch.Tensor, torch.Tensor])

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

input_prototype()
training: bool

class reagent.prediction.predictor_wrapper.SlateRankingPreprocessor(state_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, candidate_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, candidate_size: int)

Bases: reagent.models.base.ModelBase

forward(state_with_presence: Tuple[torch.Tensor, torch.Tensor], candidate_with_presence: Tuple[torch.Tensor, torch.Tensor])

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

input_prototype()

This function provides the input for ONNX graph tracing.

The return value should be what is expected by forward().

training: bool

reagent.prediction.predictor_wrapper.serving_to_feature_data(serving: reagent.core.types.ServingFeatureData, dense_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, sparse_preprocessor: reagent.preprocessing.sparse_preprocessor.SparsePreprocessor) → reagent.core.types.FeatureData

reagent.prediction.predictor_wrapper.sparse_input_prototype(model: reagent.models.base.ModelBase, state_preprocessor: reagent.preprocessing.preprocessor.Preprocessor, state_feature_config: reagent.core.types.ModelFeatureConfig)

Module contents