ml.rl.prediction package

Submodules

ml.rl.prediction.dqn_torch_predictor module

class ml.rl.prediction.dqn_torch_predictor.ActorTorchPredictor(model, action_feature_ids: List[int])

Bases: object

actor_prediction(float_state_features: List[Dict[int, float]]) → List[Dict[str, float]]
predict(state_features: List[Dict[int, float]]) → List[Dict[str, float]]
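
A minimal usage sketch (not from the source): it assumes ml.rl is importable, "actor_predictor.pt" is a hypothetical path to an ActorPredictorWrapper module saved with torch.jit.save, and the feature ids 1, 2, 100, and 101 are placeholders.

    import torch
    from ml.rl.prediction.dqn_torch_predictor import ActorTorchPredictor

    # Hypothetical path to a saved ActorPredictorWrapper TorchScript module
    model = torch.jit.load("actor_predictor.pt")
    predictor = ActorTorchPredictor(model, action_feature_ids=[100, 101])

    # State features are sparse {feature_id: value} dicts, one per example
    actions = predictor.predict([{1: 0.5, 2: -1.0}])
    # -> a list with one {action_feature_id_as_str: value} dict per example
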
class ml.rl.prediction.dqn_torch_predictor.DiscreteDqnTorchPredictor(model)

Bases: object

discrete_action() → bool
policy(state: torch.Tensor, state_feature_presence: Optional[torch.Tensor] = None, possible_actions_presence: Optional[torch.Tensor] = None) → ml.rl.types.DqnPolicyActionSet
static policy_given_q_values(q_scores: torch.Tensor, action_names: List[str], softmax_temperature: float, possible_actions_presence: Optional[torch.Tensor] = None) → ml.rl.types.DqnPolicyActionSet
policy_net() → bool
predict(state_features: List[Dict[int, float]]) → List[Dict[str, float]]
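
Because policy_given_q_values is a static method, it can be exercised without a serialized model. A small sketch, assuming ml.rl is installed; the q-values, action names, and temperature below are made up.

    import torch
    from ml.rl.prediction.dqn_torch_predictor import DiscreteDqnTorchPredictor

    q_scores = torch.tensor([[1.2, 0.4, -0.3]])  # one state, three actions
    action_set = DiscreteDqnTorchPredictor.policy_given_q_values(
        q_scores,
        action_names=["left", "stay", "right"],
        softmax_temperature=0.5,
    )
    # action_set is an ml.rl.types.DqnPolicyActionSet holding the greedy and
    # softmax-sampled action choices
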
class ml.rl.prediction.dqn_torch_predictor.ParametricDqnTorchPredictor(model)

Bases: object

discrete_action() → bool
policy(tiled_states: torch.Tensor, possible_actions_with_presence: Tuple[torch.Tensor, torch.Tensor])
static policy_given_q_values(q_scores: torch.Tensor, softmax_temperature: float, possible_actions_presence: torch.Tensor) → ml.rl.types.DqnPolicyActionSet
policy_net() → bool
predict(state_features: List[Dict[int, float]], action_features: List[Dict[int, float]]) → List[Dict[str, float]]
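
A parallel sketch for the parametric case, again with a hypothetical "parametric_dqn.pt" path and placeholder feature ids; each candidate (state, action) pair contributes one entry to both input lists.

    import torch
    from ml.rl.prediction.dqn_torch_predictor import ParametricDqnTorchPredictor

    # Hypothetical path to a saved ParametricDqnPredictorWrapper TorchScript module
    predictor = ParametricDqnTorchPredictor(torch.jit.load("parametric_dqn.pt"))

    # Score two candidate actions against the same state: the state dict is
    # repeated once per candidate
    q_values = predictor.predict(
        state_features=[{1: 0.2, 2: 1.0}, {1: 0.2, 2: 1.0}],
        action_features=[{10: 1.0}, {11: 1.0}],
    )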

ml.rl.prediction.predictor_wrapper module

class ml.rl.prediction.predictor_wrapper.ActorPredictorWrapper(actor_with_preprocessor: ml.rl.prediction.predictor_wrapper.ActorWithPreprocessor)

Bases: torch.jit.ScriptModule

forward(state_with_presence: Tuple[torch.Tensor, torch.Tensor]) → torch.Tensor
state_sorted_features() → List[int]

This interface is used by ActorTorchPredictor

class ml.rl.prediction.predictor_wrapper.ActorWithPreprocessor(model: ml.rl.models.base.ModelBase, state_preprocessor: ml.rl.preprocessing.preprocessor.Preprocessor, action_postprocessor: Optional[ml.rl.preprocessing.postprocessor.Postprocessor] = None)

Bases: ml.rl.models.base.ModelBase

This is separate from ActorPredictorWrapper so that we can pass typed inputs into the model. This works because JIT tracing only records tensor operations; JIT scripting, in contrast, compiles the code and therefore cannot recognize custom Python types.

forward(state_with_presence: Tuple[torch.Tensor, torch.Tensor])
input_prototype()

This function provides the input for ONNX graph tracing.

The return value should be what forward() expects.

property sorted_features
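
A sketch of how the two actor classes fit together, assuming a trained actor ModelBase, a state Preprocessor, and an optional action Postprocessor are already constructed elsewhere; build_actor_predictor_wrapper is a hypothetical helper, not part of the module.

    from ml.rl.prediction.predictor_wrapper import (
        ActorPredictorWrapper,
        ActorWithPreprocessor,
    )

    def build_actor_predictor_wrapper(actor_model, state_preprocessor, action_postprocessor=None):
        # Bundle the model with its preprocessing so the whole pipeline can be
        # JIT-traced as one module
        actor_with_preprocessor = ActorWithPreprocessor(
            actor_model, state_preprocessor, action_postprocessor
        )
        # The wrapper is a ScriptModule; serving code calls forward() with a
        # (value_tensor, presence_tensor) pair
        return ActorPredictorWrapper(actor_with_preprocessor)
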
class ml.rl.prediction.predictor_wrapper.DiscreteDqnPredictorWrapper(dqn_with_preprocessor: ml.rl.prediction.predictor_wrapper.DiscreteDqnWithPreprocessor, action_names: List[str])

Bases: torch.jit.ScriptModule

forward(state_with_presence: Tuple[torch.Tensor, torch.Tensor]) → Tuple[List[str], torch.Tensor]
state_sorted_features() → List[int]

This interface is used by DiscreteDqnTorchPredictor

class ml.rl.prediction.predictor_wrapper.DiscreteDqnWithPreprocessor(model: ml.rl.models.base.ModelBase, state_preprocessor: ml.rl.preprocessing.preprocessor.Preprocessor)

Bases: ml.rl.models.base.ModelBase

This is separate from DiscreteDqnPredictorWrapper so that we can pass typed inputs into the model. This works because JIT tracing only records tensor operations; JIT scripting, in contrast, compiles the code and therefore cannot recognize custom Python types.

forward(state_with_presence: Tuple[torch.Tensor, torch.Tensor])
input_prototype()

This function provides the input for ONNX graph tracing.

The return value should be what forward() expects.

property sorted_features
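
The discrete-action pipeline follows the same pattern. In this hypothetical helper, dqn_model and state_preprocessor are assumed to come out of training, and action_names gives the order of the returned q-values.

    from typing import List

    from ml.rl.prediction.predictor_wrapper import (
        DiscreteDqnPredictorWrapper,
        DiscreteDqnWithPreprocessor,
    )

    def build_discrete_dqn_predictor_wrapper(dqn_model, state_preprocessor, action_names: List[str]):
        dqn_with_preprocessor = DiscreteDqnWithPreprocessor(dqn_model, state_preprocessor)
        # forward() takes a (value_tensor, presence_tensor) pair and returns
        # (action_names, q_values), matching the signature above
        return DiscreteDqnPredictorWrapper(dqn_with_preprocessor, action_names)
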
class ml.rl.prediction.predictor_wrapper.ParametricDqnPredictorWrapper(dqn_with_preprocessor: ml.rl.prediction.predictor_wrapper.ParametricDqnWithPreprocessor)

Bases: torch.jit.ScriptModule

action_sorted_features() → List[int]

This interface is used by ParametricDqnTorchPredictor

forward(state_with_presence: Tuple[torch.Tensor, torch.Tensor], action_with_presence: Tuple[torch.Tensor, torch.Tensor]) → Tuple[List[str], torch.Tensor]
state_sorted_features() → List[int]

This interface is used by ParametricDqnTorchPredictor

class ml.rl.prediction.predictor_wrapper.ParametricDqnWithPreprocessor(model: ml.rl.models.base.ModelBase, state_preprocessor: ml.rl.preprocessing.preprocessor.Preprocessor, action_preprocessor: ml.rl.preprocessing.preprocessor.Preprocessor)

Bases: ml.rl.models.base.ModelBase

property action_sorted_features
forward(state_with_presence: Tuple[torch.Tensor, torch.Tensor], action_with_presence: Tuple[torch.Tensor, torch.Tensor])
input_prototype()

This function provides the input for ONNX graph tracing.

The return value should be what forward() expects.

property state_sorted_features
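
The parametric variant additionally needs an action Preprocessor; as before, all three arguments to this hypothetical helper are assumed to be constructed during training.

    from ml.rl.prediction.predictor_wrapper import (
        ParametricDqnPredictorWrapper,
        ParametricDqnWithPreprocessor,
    )

    def build_parametric_dqn_predictor_wrapper(model, state_preprocessor, action_preprocessor):
        dqn_with_preprocessor = ParametricDqnWithPreprocessor(
            model, state_preprocessor, action_preprocessor
        )
        # forward() takes (state_with_presence, action_with_presence) tuples,
        # matching the signature above
        return ParametricDqnPredictorWrapper(dqn_with_preprocessor)
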
class ml.rl.prediction.predictor_wrapper.Seq2SlatePredictorWrapper(seq2slate_with_preprocessor: ml.rl.prediction.predictor_wrapper.Seq2SlateWithPreprocessor)

Bases: torch.jit.ScriptModule

candidate_sorted_features() → List[int]

This interface is used by Seq2SlateTorchPredictor

forward(state_with_presence: Tuple[torch.Tensor, torch.Tensor], candidate_with_presence: Tuple[torch.Tensor, torch.Tensor]) → Tuple[torch.Tensor, torch.Tensor]
state_sorted_features() → List[int]

This interface is used by Seq2SlateTorchPredictor

class ml.rl.prediction.predictor_wrapper.Seq2SlateWithPreprocessor(model: ml.rl.models.seq2slate.Seq2SlateTransformerNet, state_preprocessor: ml.rl.preprocessing.preprocessor.Preprocessor, candidate_preprocessor: ml.rl.preprocessing.preprocessor.Preprocessor, greedy: bool)

Bases: ml.rl.models.base.ModelBase

property candidate_sorted_features
forward(state_with_presence: Tuple[torch.Tensor, torch.Tensor], candidate_with_presence: Tuple[torch.Tensor, torch.Tensor])
input_prototype()

This function provides the input for ONNX graph tracing.

The return value should be what forward() expects.

property state_sorted_features
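
A final sketch for the slate-ranking case; seq2slate_net, the two preprocessors, and the greedy flag are assumed inputs, and the comment on the return value only paraphrases the Tuple[torch.Tensor, torch.Tensor] annotation above.

    from ml.rl.prediction.predictor_wrapper import (
        Seq2SlatePredictorWrapper,
        Seq2SlateWithPreprocessor,
    )

    def build_seq2slate_predictor_wrapper(
        seq2slate_net, state_preprocessor, candidate_preprocessor, greedy=True
    ):
        seq2slate_with_preprocessor = Seq2SlateWithPreprocessor(
            seq2slate_net, state_preprocessor, candidate_preprocessor, greedy
        )
        # forward() takes (state_with_presence, candidate_with_presence) tuples
        # and returns a pair of tensors describing the chosen slate
        return Seq2SlatePredictorWrapper(seq2slate_with_preprocessor)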

Module contents