reagent.models package

Submodules

reagent.models.base module

class reagent.models.base.ModelBase(*args, **kwargs)

Bases: torch.nn.Module

A base class to support exporting through ONNX


Override this in DistributedDataParallel models

feature_config() → Optional[reagent.types.ModelFeatureConfig]

If the model needs additional preprocessing (e.g., for sequence features), return the config here.


Return DistributedDataParallel version of this model

This needs to be implemented explicitly because: 1) A model with an EmbeddingBag module is not compatible with vanilla DistributedDataParallel. 2) The exporting logic needs structured data, and DistributedDataParallel doesn't work with structured data.


Return a copy of this network to be used as target network

Subclasses should override this if the target network should share parameters with the network being trained.
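The difference between an independent copy and a parameter-sharing target network can be sketched without PyTorch. This is a minimal illustration only, not ReAgent's implementation: `TinyNet`, `SharedEmbeddingNet`, and the method name `target_copy` are hypothetical, and plain Python lists stand in for tensors.

```python
import copy


class TinyNet:
    """Hypothetical stand-in for a model; a list of floats plays the role of parameters."""

    def __init__(self, weights):
        self.weights = list(weights)

    def target_copy(self):
        # Default behaviour: a fully independent deep copy.
        return copy.deepcopy(self)


class SharedEmbeddingNet(TinyNet):
    """Hypothetical subclass whose target shares one parameter group with the trained net."""

    def __init__(self, weights, embedding):
        super().__init__(weights)
        self.embedding = embedding  # kept by reference so it can be shared

    def target_copy(self):
        # Override: the weights are copied, the embedding list is the same object.
        return SharedEmbeddingNet(self.weights, self.embedding)


online = SharedEmbeddingNet([0.1, 0.2], embedding=[1.0, 2.0])
target = online.target_copy()

online.weights[0] = 9.9    # not visible through the target (copied)
online.embedding[0] = 7.7  # visible through the target (shared)
```

With the default deep copy, updates to the trained network never leak into the target; the override is only needed when some parameters are meant to stay tied between the two.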

input_prototype() → Any

This function provides the input for ONNX graph tracing.

The return value should be what forward() expects.

reagent.models.bcq module

reagent.models.categorical_dqn module

reagent.models.cem_planner module

reagent.models.containers module

reagent.models.convolutional_network module

reagent.models.critic module

reagent.models.dqn module

reagent.models.dueling_q_network module

reagent.models.embedding_bag_concat module

reagent.models.fully_connected_network module

class reagent.models.fully_connected_network.FullyConnectedNetwork(layers, activations, *, use_batch_norm=False, min_std=0.0, dropout_ratio=0.0, use_layer_norm=False, normalize_output=False)

Bases: reagent.models.base.ModelBase

forward(input: torch.Tensor) → torch.Tensor

Forward pass for generic feed-forward DNNs. Assumes activation names are valid PyTorch activation names. Parameters: input – the input tensor.


input_prototype()

This function provides the input for ONNX graph tracing.

The return value should be what forward() expects.

reagent.models.fully_connected_network.gaussian_fill_w_gain(tensor, activation, dim_in, min_std=0.0) → None

Gaussian initialization with gain.
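As a rough illustration of what Gaussian initialization with a gain usually means, the sketch below draws weights from a zero-mean normal distribution whose standard deviation is `gain * sqrt(1 / dim_in)`, floored at `min_std`. The gain values (`sqrt(2)` for ReLU-like activations, 1 otherwise) and the helper names `init_std` and `gaussian_fill` are assumptions made for this example; they are not taken from the ReAgent source, and the sketch fills a nested Python list rather than a tensor in place.

```python
import math
import random


def init_std(activation, dim_in, min_std=0.0):
    """Standard deviation for the initializer (assumed gain convention)."""
    # Assumption: ReLU-like activations get a gain of sqrt(2), everything else 1.
    gain = math.sqrt(2.0) if activation in ("relu", "leaky_relu") else 1.0
    # Scale by fan-in, but never go below the requested minimum.
    return max(gain * math.sqrt(1.0 / dim_in), min_std)


def gaussian_fill(rows, cols, activation, min_std=0.0, seed=None):
    """Return a rows x cols matrix of zero-mean Gaussian samples."""
    rng = random.Random(seed)
    std = init_std(activation, dim_in=cols, min_std=min_std)
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]


# Example: a layer with 8 inputs and 3 outputs followed by a ReLU.
weights = gaussian_fill(rows=3, cols=8, activation="relu", seed=0)
```

The `min_std` floor matters for wide layers, where the fan-in scaling alone would make the initial weights very small.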

reagent.models.mdn_rnn module

reagent.models.model_feature_config_provider module

reagent.models.no_soft_update_embedding module

reagent.models.seq2reward_model module

reagent.models.seq2slate module

reagent.models.seq2slate_reward module

reagent.models.world_model module

Module contents