reagent.model_utils package

Submodules

reagent.model_utils.seq2slate_utils module

class reagent.model_utils.seq2slate_utils.Seq2SlateMode(value)

Bases: enum.Enum

An enumeration.

DECODE_ONE_STEP_MODE = 'decode_one_step'
ENCODER_SCORE_MODE = 'encoder_score_mode'
PER_SEQ_LOG_PROB_MODE = 'per_sequence_log_prob'
PER_SYMBOL_LOG_PROB_DIST_MODE = 'per_symbol_log_prob_dist'
RANK_MODE = 'rank'
class reagent.model_utils.seq2slate_utils.Seq2SlateOutputArch(value)

Bases: enum.Enum

An enumeration.

AUTOREGRESSIVE = 'autoregressive'
ENCODER_SCORE = 'encoder_score'
FRECHET_SORT = 'frechet_sort'
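
A quick usage sketch (not taken from the library's docs): since both classes are plain enum.Enum subclasses, their members can be looked up by value, as the (value) signatures above suggest.

    from reagent.model_utils.seq2slate_utils import Seq2SlateMode, Seq2SlateOutputArch

    # Look up members by their documented string values.
    assert Seq2SlateMode("rank") is Seq2SlateMode.RANK_MODE
    assert Seq2SlateOutputArch("autoregressive") is Seq2SlateOutputArch.AUTOREGRESSIVE
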
reagent.model_utils.seq2slate_utils.attention(query, key, value, mask, d_k)

Scaled Dot Product Attention
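
A minimal sketch of scaled dot-product attention using the same argument names; the masking convention (positions with mask == 0 are ignored) and the returned (output, weights) pair are assumptions, not a guaranteed match to the library's implementation.

    import math

    import torch
    import torch.nn.functional as F

    def scaled_dot_product_attention(query, key, value, mask, d_k):
        # Dot-product scores scaled by sqrt(d_k) to keep softmax inputs well-scaled.
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        # Assumed convention: mask == 0 marks positions that must not be attended to.
        scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, value), weights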

reagent.model_utils.seq2slate_utils.clones(module, N)

Produce N identical layers.

Parameters
  • module – nn.Module class

  • N – number of copies
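
A common implementation of this helper (in the style of the annotated Transformer) deep-copies the module into an nn.ModuleList; this is a sketch, not necessarily the exact code used here.

    import copy

    import torch.nn as nn

    def clones(module, N):
        # Deep copies give each layer its own, independently trained parameters.
        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

    # Example: six independent feed-forward layers.
    layers = clones(nn.Linear(64, 64), N=6)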

reagent.model_utils.seq2slate_utils.mask_logits_by_idx(logits, tgt_in_idx)
reagent.model_utils.seq2slate_utils.per_symbol_to_per_seq_log_probs(per_symbol_log_probs, tgt_out_idx)

Gather per-symbol log probabilities into per-seq log probabilities
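
A sketch of the gather-and-sum pattern this describes, assuming per_symbol_log_probs has shape (batch_size, tgt_seq_len, candidate_size) and tgt_out_idx has shape (batch_size, tgt_seq_len) with indices already aligned to the last dimension (the +2 index offset used elsewhere in this module is ignored here):

    import torch

    def per_symbol_to_per_seq_log_probs(per_symbol_log_probs, tgt_out_idx):
        # Pick the log probability of the chosen item at each decoding step.
        step_log_probs = torch.gather(
            per_symbol_log_probs, 2, tgt_out_idx.unsqueeze(2)
        ).squeeze(2)
        # log p(sequence) = sum over steps of log p(item_t | previously picked items).
        return step_log_probs.sum(dim=1, keepdim=True)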

reagent.model_utils.seq2slate_utils.per_symbol_to_per_seq_probs(per_symbol_probs, tgt_out_idx)

Gather per-symbol probabilities into per-seq probabilities
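
The probability-space counterpart multiplies the gathered per-step probabilities instead of summing their logs; the clamp below is an assumed numerical-stability guard, not confirmed from the source.

    import torch

    def per_symbol_to_per_seq_probs(per_symbol_probs, tgt_out_idx):
        # p(sequence) = product over steps of p(item_t | previously picked items).
        step_probs = torch.gather(
            per_symbol_probs, 2, tgt_out_idx.unsqueeze(2)
        ).squeeze(2)
        return torch.clamp(step_probs.prod(dim=1, keepdim=True), min=1e-40)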

reagent.model_utils.seq2slate_utils.print_model_info(seq2slate)
reagent.model_utils.seq2slate_utils.pytorch_decoder_mask(memory: torch.Tensor, tgt_in_idx: torch.Tensor, num_heads: int)

Compute the masks used in the PyTorch Transformer-based decoder for self-attention and attention over encoder outputs

mask_ijk = 1 if the item should be ignored; 0 if the item should be attended to

Input:

memory shape: batch_size, src_seq_len, dim_model

tgt_in_idx (offset by +2) shape: batch_size, tgt_seq_len

Returns

tgt_tgt_mask shape: batch_size * num_heads, tgt_seq_len, tgt_seq_len

tgt_src_mask shape: batch_size * num_heads, tgt_seq_len, src_seq_len
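
An illustration (not from the library) of the mask convention described above, for a single head: in torch.nn.Transformer-style masks, True/1 marks positions the decoder must ignore, so the causal part of tgt_tgt_mask is strictly upper-triangular.

    import torch

    tgt_seq_len = 3
    # Step i may not look at steps j > i, so positions above the diagonal are masked.
    causal_part = torch.triu(
        torch.ones(tgt_seq_len, tgt_seq_len, dtype=torch.bool), diagonal=1
    )
    print(causal_part)
    # tensor([[False,  True,  True],
    #         [False, False,  True],
    #         [False, False, False]])
    # tgt_src_mask presumably hides source items that may not be attended to at a
    # given step (e.g. items already placed in the slate); its exact construction
    # is not shown here.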

reagent.model_utils.seq2slate_utils.subsequent_and_padding_mask(tgt_in_idx)

Create a mask to hide padding and future items
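
A sketch of how such a mask can be built by combining a padding mask with a lower-triangular "no future items" mask; PADDING_SYMBOL is a placeholder constant, and the exact padding index and dtype used by the library are assumptions.

    import torch

    PADDING_SYMBOL = 0  # assumed padding index

    def subsequent_and_padding_mask(tgt_in_idx):
        # Hide padded positions: shape (batch_size, 1, seq_len).
        padding_mask = (tgt_in_idx != PADDING_SYMBOL).unsqueeze(-2).type(torch.int8)
        # Hide future positions: lower-triangular mask of shape (1, seq_len, seq_len).
        seq_len = tgt_in_idx.size(-1)
        future_mask = torch.tril(
            torch.ones(1, seq_len, seq_len, dtype=torch.int8, device=tgt_in_idx.device)
        )
        # 1 = attend, 0 = ignore; both constraints must allow attention.
        return padding_mask & future_mask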

reagent.model_utils.seq2slate_utils.subsequent_mask(size: int, device: torch.device)

Mask out subsequent positions. Mainly used in the decoding process, in which an item should not attend to subsequent items.

mask_ijk = 0 if the item should be ignored; 1 if the item should be attended to
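
A minimal sketch consistent with that convention (the leading broadcast dimension and the integer dtype are assumptions):

    import torch

    def subsequent_mask(size: int, device: torch.device):
        # Lower-triangular matrix: position i may attend only to positions j <= i.
        return torch.tril(
            torch.ones(1, size, size, dtype=torch.int8, device=device)
        )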

Module contents