reagent.gym.policies.samplers package
Submodules
reagent.gym.policies.samplers.continuous_sampler module
- class reagent.gym.policies.samplers.continuous_sampler.GaussianSampler(actor_network)
Bases: reagent.gym.types.Sampler
- log_prob(scores: reagent.gym.types.GaussianSamplerScore, squashed_action: torch.Tensor) → torch.Tensor
- sample_action(scores: reagent.gym.types.GaussianSamplerScore) → reagent.core.types.ActorOutput
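The `squashed_action` parameter suggests the usual pattern of sampling from the actor's Gaussian and squashing through tanh to bound the action, with a change-of-variables correction in the log-probability. A minimal pure-Python sketch of that logic, assuming standard tanh squashing (function names are illustrative, not ReAgent's API):

```python
import math
import random

def gaussian_log_prob(x, loc, scale):
    # Log-density of a univariate Normal(loc, scale) at x.
    return (-0.5 * math.log(2 * math.pi) - math.log(scale)
            - 0.5 * ((x - loc) / scale) ** 2)

def sample_squashed_action(loc, scale, eps=1e-6):
    # Sample u ~ N(loc, scale), squash through tanh so the action lies in
    # (-1, 1), and correct the log-probability for the tanh change of
    # variables: log p(a) = log N(u) - log(1 - tanh(u)^2).
    u = random.gauss(loc, scale)
    a = math.tanh(u)
    log_prob = gaussian_log_prob(u, loc, scale) - math.log(1 - a * a + eps)
    return a, log_prob

action, lp = sample_squashed_action(0.0, 1.0)
```

The `eps` term guards against `log(0)` when the squashed action saturates near ±1.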
reagent.gym.policies.samplers.discrete_sampler module
- class reagent.gym.policies.samplers.discrete_sampler.EpsilonGreedyActionSampler(epsilon: float, epsilon_decay: float = 1.0, minimum_epsilon: float = 0.0)
Bases: reagent.gym.types.Sampler
Epsilon-Greedy Policy
With probability epsilon, a random action is sampled; otherwise the highest-scoring (greedy) action is chosen.
Call update() to decay the amount of exploration: epsilon is multiplied by epsilon_decay (<= 1) at each call until it reaches minimum_epsilon.
- log_prob(scores: torch.Tensor, action: torch.Tensor) → torch.Tensor
- sample_action(scores: torch.Tensor) → reagent.core.types.ActorOutput
- update() → None
Call to update internal parameters (e.g. decay epsilon)
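The epsilon-greedy rule and its decay schedule can be sketched in plain Python (the names below are illustrative, not ReAgent's API):

```python
import random

def epsilon_greedy(scores, epsilon):
    # With probability epsilon, pick a uniformly random action index;
    # otherwise pick the highest-scoring (greedy) action.
    if random.random() < epsilon:
        return random.randrange(len(scores))
    return max(range(len(scores)), key=lambda i: scores[i])

class EpsilonSchedule:
    # Multiplicative decay of epsilon down to a floor, mirroring update().
    def __init__(self, epsilon, epsilon_decay=1.0, minimum_epsilon=0.0):
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.minimum_epsilon = minimum_epsilon

    def update(self):
        # Decay, but never below the minimum.
        self.epsilon = max(self.epsilon * self.epsilon_decay,
                           self.minimum_epsilon)

sched = EpsilonSchedule(1.0, epsilon_decay=0.5, minimum_epsilon=0.2)
sched.update()  # epsilon: 1.0 -> 0.5
```

With the default `epsilon_decay=1.0`, epsilon stays constant, matching the constructor defaults above.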
- class reagent.gym.policies.samplers.discrete_sampler.GreedyActionSampler
Bases: reagent.gym.types.Sampler
Return the highest scoring action.
- log_prob(scores: torch.Tensor, action: torch.Tensor) → torch.Tensor
- sample_action(scores: torch.Tensor) → reagent.core.types.ActorOutput
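Greedy selection is simply an argmax over the scores; a one-line sketch (illustrative, not the ReAgent implementation):

```python
def greedy_action(scores):
    # Deterministically choose the index of the highest score.
    return max(range(len(scores)), key=lambda i: scores[i])

chosen = greedy_action([0.2, 0.7, 0.1])
```

Since the policy is deterministic, the chosen action has probability 1, so its log-probability is 0.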
- class reagent.gym.policies.samplers.discrete_sampler.SoftmaxActionSampler(temperature: float = 1.0, temperature_decay: float = 1.0, minimum_temperature: float = 0.1)
Bases: reagent.gym.types.Sampler
Softmax sampler (see http://incompleteideas.net/book/first/ebook/node17.html). The action scores are interpreted as logits. Supports decaying the temperature over time.
- Parameters
temperature – A measure of how uniformly random the distribution looks. The higher the temperature, the more uniform the sampling.
temperature_decay – A multiplier by which temperature is reduced at each .update() call
minimum_temperature – Minimum temperature, below which the temperature is not decayed further
- entropy(scores: torch.Tensor) → torch.Tensor
Returns the average policy entropy: a simple unweighted average across the batch.
- log_prob(scores: torch.Tensor, action: torch.Tensor) → torch.Tensor
- sample_action(scores: torch.Tensor) → reagent.core.types.ActorOutput
- update() → None
Call to update internal parameters (e.g. decay the temperature)
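The temperature-scaled softmax, sampling, and entropy computations can be sketched in plain Python (illustrative helper names, not ReAgent's API):

```python
import math
import random

def softmax(scores, temperature=1.0):
    # Convert logits to probabilities; the higher the temperature,
    # the closer the distribution is to uniform.
    z = [s / temperature for s in scores]
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def sample_softmax(scores, temperature=1.0):
    # Inverse-CDF sampling from the softmax distribution.
    probs = softmax(scores, temperature)
    r = random.random()
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if r < cumulative:
            return i
    return len(probs) - 1

def entropy(scores, temperature=1.0):
    # Shannon entropy of the softmax distribution.
    probs = softmax(scores, temperature)
    return -sum(p * math.log(p) for p in probs if p > 0)
```

Raising the temperature flattens the distribution, which shows up as higher entropy; as temperature decays toward `minimum_temperature`, sampling approaches greedy selection.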
reagent.gym.policies.samplers.top_k_sampler module
- class reagent.gym.policies.samplers.top_k_sampler.TopKSampler(k: int)
Bases: reagent.gym.types.Sampler
- log_prob(scores: torch.Tensor, action: torch.Tensor) → torch.Tensor
- sample_action(scores: torch.Tensor) → reagent.core.types.ActorOutput
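A top-k sampler selects the k highest-scoring actions (useful, for example, when an action is a slate of items). A minimal sketch of the selection logic, with a hypothetical `top_k_indices` helper:

```python
def top_k_indices(scores, k):
    # Return the indices of the k highest scores, best first.
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

top2 = top_k_indices([0.1, 0.9, 0.5, 0.7], 2)
```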