DeBERTa

DeBERTa Model

class DeBERTa.deberta.DeBERTa(config=None, pre_trained=None)[source]

DeBERTa encoder. This module is composed of the input embedding layer and stacked transformer layers with disentangled attention.

Parameters:
  • config – A model config class instance with the configuration to build a new model. The schema is similar to BertConfig; for more details, please refer to ModelConfig
  • pre_trained – The pre-trained DeBERTa model; it can be the physical path of a pre-trained DeBERTa model or the name of a released model, i.e. [base, large, base_mnli, large_mnli]
apply_state(state=None)[source]

Load state from a previously loaded model state dictionary.

Parameters:state (dict, optional) – State dictionary as returned by torch.nn.Module.state_dict(), default: None. If it’s None, the pre-trained state loaded via the constructor will be used to re-initialize the DeBERTa model.
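
Example (a minimal sketch; the checkpoint path below is a placeholder, not part of the API):

import torch
from DeBERTa.deberta import DeBERTa

# Build the encoder, then overwrite its weights with a previously saved state dictionary
state = torch.load('/path/to/pytorch_model.bin', map_location='cpu')
deberta = DeBERTa(pre_trained='base')
deberta.apply_state(state)
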
forward(input_ids, attention_mask=None, token_type_ids=None, output_all_encoded_layers=True, position_ids=None, return_att=False)[source]
Parameters:
  • input_ids – a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary
  • attention_mask

    An optional parameter for the input mask or attention mask.

    • If it’s an input mask, it will be a torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It’s a mask to be used when the input sequence length is smaller than the maximum input sequence length in the current batch; it’s the mask we typically use for attention when a batch has sentences of varying lengths.
    • If it’s an attention mask, it will be a torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. In this case, it’s a mask indicating which tokens in the sequence should be attended to by other tokens in the sequence.
  • token_type_ids – an optional torch.LongTensor of shape [batch_size, sequence_length] with the token type indices selected in [0, 1]. Type 0 corresponds to a sentence A token and type 1 corresponds to a sentence B token (see the BERT paper for more details).
  • output_all_encoded_layers – whether to output the results of all encoder layers, default: True
Returns:

  • The output of the stacked transformer layers if output_all_encoded_layers=True, else the last layer of stacked transformer layers
  • Attention matrix of self-attention layers if return_att=True

Example:

import torch
from DeBERTa.deberta import DeBERTa

# Batch of WordPiece token ids.
# Each sample was padded with zeros to the maximum length of the batch.
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
# Mask of valid input ids
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

# DeBERTa model initialized with pretrained base model
bert = DeBERTa(pre_trained='base')

encoder_layers = bert(input_ids, attention_mask=attention_mask)

NNModule

class DeBERTa.deberta.NNModule(config, *inputs, **kwargs)[source]

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Parameters:config (ModelConfig) – The model config for the module
init_weights(module)[source]

Apply Gaussian(mean=0, std=`config.initializer_range`) initialization to the module.

Parameters:module (torch.nn.Module) – The module to apply the initialization to.

Example:

from DeBERTa.deberta import NNModule, DeBERTa

class MyModule(NNModule):
  def __init__(self, config):
    super().__init__(config)
    # Add construction instructions
    self.bert = DeBERTa(config)

    # Add other modules
    ...

    # Apply initialization
    self.apply(self.init_weights)
classmethod load_model(model_path, model_config=None, tag=None, no_cache=False, cache_dir=None, *inputs, **kwargs)[source]

Instantiate a sub-class of NNModule from a pre-trained model file.

Parameters:
  • model_path (str) –

    Path or name of the pre-trained model, which can be either:

    • The path of a pre-trained model
    • The pre-trained DeBERTa model name in DeBERTa GitHub releases, i.e. [base, base_mnli, large, large_mnli].

    If model_path is None or '-', the method will create a new sub-class instance without initializing it from a pre-trained model.

  • model_config (str) –

    The path of the model config file. If it’s None, the method will try to find the config in the following order:

    1. [‘config’] in the model state dictionary.
    2. model_config.json next to model_path.

    If no config can be found, the method will fail.

  • tag (str, optional) – The release tag of DeBERTa, default: None.
  • no_cache (bool, optional) – Disable local cache of downloaded models, default: False.
  • cache_dir (str, optional) – The cache directory used to save the downloaded models, default: None. If it’s None, then the models will be saved at $HOME/.~DeBERTa
Returns:

The sub-class object.

Return type:

NNModule
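
Example (a hedged sketch; MyModule mirrors the NNModule example above and the local paths are placeholders):

from DeBERTa.deberta import NNModule, DeBERTa

class MyModule(NNModule):
  def __init__(self, config, *inputs, **kwargs):
    super().__init__(config)
    self.bert = DeBERTa(config)
    self.apply(self.init_weights)

# Load the released pre-trained base model by name
model = MyModule.load_model('base')
# Or load a local checkpoint with an explicit config file
model = MyModule.load_model('/path/to/pytorch_model.bin', model_config='/path/to/model_config.json')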

DisentangledSelfAttention

class DeBERTa.deberta.DisentangledSelfAttention(config)[source]

Disentangled self-attention module

Parameters:config (ModelConfig) – A model config class instance with the configuration to build a new model. The schema is similar to BertConfig; for more details, please refer to ModelConfig
forward(hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None)[source]

Call the module

Parameters:
  • hidden_states (torch.FloatTensor) – Input states to the module, usually the output from the previous layer; it will be the Q, K and V in Attention(Q, K, V)
  • attention_mask (torch.ByteTensor) – An attention mask matrix of shape [B, N, N] where B is the batch size and N is the maximum sequence length, in which element [i, j] = 1 means the i-th token in the input can attend to the j-th token.
  • return_att (bool, optional) – Whether to return the attention matrix.
  • query_states (torch.FloatTensor, optional) – The Q state in Attention(Q,K,V).
  • relative_pos (torch.LongTensor) – The relative position encoding between the tokens in the sequence. It’s of shape [B, N, N] with values ranging in [-max_relative_positions, max_relative_positions].
  • rel_embeddings (torch.FloatTensor) – The embedding of relative distances. It’s a tensor of shape [\(2 \times \text{max_relative_positions}\), hidden_size].
DeBERTa.deberta.build_relative_position(query_size, key_size, device)[source]

Build relative position according to the query and key

We assume the absolute position of the query \(P_q\) ranges over (0, query_size) and the absolute position of the key \(P_k\) ranges over (0, key_size). The relative position from query to key is

\(R_{q \rightarrow k} = P_q - P_k\)

Parameters:
  • query_size (int) – the length of query
  • key_size (int) – the length of key
Returns:

A tensor with shape [1, query_size, key_size]

Return type:

torch.LongTensor
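
Example (a sketch of the layout implied by the formula above, assuming device accepts a torch device or device string):

from DeBERTa.deberta import build_relative_position

rel_pos = build_relative_position(3, 3, device='cpu')
# Per R_{q -> k} = P_q - P_k, element [0, q, k] equals q - k:
# tensor([[[ 0, -1, -2],
#          [ 1,  0, -1],
#          [ 2,  1,  0]]])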

ContextPooler

class DeBERTa.deberta.ContextPooler(config)[source]

BertEncoder

class DeBERTa.deberta.BertEncoder(config)[source]

Modified BertEncoder with relative position bias support

BertLayerNorm

class DeBERTa.deberta.BertLayerNorm(size, eps=1e-12)[source]

LayerNorm module in the TF style (epsilon inside the square root).
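
Example (a brief sketch, assuming the standard LayerNorm call interface over the last dimension):

import torch
from DeBERTa.deberta import BertLayerNorm

layer_norm = BertLayerNorm(100, eps=1e-12)
y = layer_norm(torch.randn(2, 10, 100))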

XSoftmax

class DeBERTa.deberta.XSoftmax

Masked Softmax which is optimized for saving memory

Parameters:
  • input (torch.tensor) – The input tensor to which softmax will be applied.
  • mask (torch.IntTensor) – The mask matrix where 0 indicates that the element will be ignored in the softmax calculation.
  • dim (int) – The dimension along which softmax will be applied.

Example:

import torch
from DeBERTa.deberta import XSoftmax
# Make a tensor
x = torch.randn([4,20,100])
# Create a mask
mask = (x>0).int()
y = XSoftmax.apply(x, mask, dim=-1)

StableDropout

class DeBERTa.deberta.StableDropout(drop_prob)[source]

Optimized dropout module for stabilizing the training

Parameters:drop_prob (float) – the dropout probability
forward(x)[source]

Call the module

Parameters:x (torch.tensor) – The input tensor to apply dropout to
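
Example (a brief sketch, assuming StableDropout follows the standard dropout interface, i.e. active only in training mode):

import torch
from DeBERTa.deberta import StableDropout

dropout = StableDropout(0.1)   # drop probability of 0.1
x = torch.randn(2, 10, 100)
y = dropout(x)                 # elements are randomly zeroed only in training mode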

MaskedLayerNorm

DeBERTa.deberta.MaskedLayerNorm(layerNorm, input, mask=None)[source]

Masked LayerNorm which will apply a mask over the output of LayerNorm to avoid inaccurate updates to the LayerNorm module.

Parameters:
  • layerNorm (BertLayerNorm) – LayerNorm module or function
  • input (torch.tensor) – The input tensor
  • mask (torch.IntTensor) – The mask to be applied to the output of LayerNorm, where 0 indicates the output of that element will be ignored, i.e. set to 0

Example:

import torch
from DeBERTa.deberta import BertLayerNorm, MaskedLayerNorm

# Create a tensor b x n x d
x = torch.randn([1,10,100])
m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int)
LayerNorm = BertLayerNorm(100)
y = MaskedLayerNorm(LayerNorm, x, m)

GPT2Tokenizer

class DeBERTa.deberta.GPT2Tokenizer(vocab_file=None, do_lower_case=True, special_tokens=None)[source]

A wrapper of the GPT2 tokenizer with a similar interface to the BERT tokenizer

Parameters:
  • vocab_file (str, optional) –

    The local path of the vocabulary package or the release name of the vocabulary in DeBERTa GitHub releases, e.g. “bpe_encoder”, default: None.

    If it’s None, the vocabulary from the latest GitHub release will be downloaded. The vocabulary file is a state dictionary with three items, “dict_map”, “vocab” and “encoder”, which correspond to the three files used in RoBERTa, i.e. dict.txt, vocab.txt and encoder.json. The differences between our wrapped GPT2 tokenizer and RoBERTa’s wrapped tokenizer are:
    • Special tokens: unlike RoBERTa, which uses <s> and </s> as the start and end tokens of a sentence, we use [CLS] and [SEP] as the start and end tokens of the input sentence, the same as BERT.
    • We remapped the token ids in our dictionary with regard to the new special tokens: [PAD] => 0, [CLS] => 1, [SEP] => 2, [UNK] => 3, [MASK] => 50264
  • do_lower_case (bool, optional) – Whether to convert inputs to lower case. Not used in the GPT2 tokenizer.
  • special_tokens (list, optional) – List of special tokens to be added to the end of the vocabulary.
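
Example (a hedged construction sketch; the local path below is a placeholder):

from DeBERTa.deberta import GPT2Tokenizer

# With no arguments the default vocabulary is downloaded from the latest release
tokenizer = GPT2Tokenizer()
# Or point to a locally downloaded vocabulary package
tokenizer = GPT2Tokenizer(vocab_file='/path/to/bpe_encoder')
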
add_special_token(token)[source]

Adds a special token to the dictionary.

Parameters:token (str) – The new token/word to be added to the vocabulary.
Returns:The id of new token in the vocabulary.
add_symbol(word, n=1)[source]

Adds a word to the dictionary.

Parameters:
  • word (str) – The new token/word to be added to the vocabulary.
  • n (int, optional) – The frequency of the word.
Returns:

The id of the new word.
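
Example (a hedged sketch; the token and word below are arbitrary illustrations, not part of the released vocabulary):

from DeBERTa.deberta import GPT2Tokenizer

tokenizer = GPT2Tokenizer()
special_id = tokenizer.add_special_token('[XSEP]')  # hypothetical special token
word_id = tokenizer.add_symbol('deberta', n=1)      # add a regular word with frequency 1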

convert_ids_to_tokens(ids)[source]

Convert list of ids to tokens.

Parameters:ids (list) – list of ids
Returns:List of tokens
convert_tokens_to_ids(tokens)[source]

Convert list of tokens to ids.

Parameters:tokens (list) – list of tokens
Returns:List of ids
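
Example (a round-trip sketch built on the tokenize example below):

from DeBERTa.deberta import GPT2Tokenizer

tokenizer = GPT2Tokenizer()
tokens = tokenizer.tokenize("Hello world!")
ids = tokenizer.convert_tokens_to_ids(tokens)
back = tokenizer.convert_ids_to_tokens(ids)   # expected to match `tokens`
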
decode(tokens)[source]

Decode list of tokens to text strings.

Parameters:tokens (list) – list of tokens.
Returns:The text string corresponding to the input tokens.

Example:

>>> tokenizer = GPT2Tokenizer()
>>> text = "Hello world!"
>>> tokens = tokenizer.tokenize(text)
>>> print(tokens)
['15496', '995', '0']

>>> tokenizer.decode(tokens)
'Hello world!'
tokenize(text)[source]

Convert an input text to tokens.

Parameters:text (str) – input text to be tokenized.
Returns:A list of byte tokens, where each token represents a byte id in the GPT2 byte dictionary

Example:

>>> tokenizer = GPT2Tokenizer()
>>> text = "Hello world!"
>>> tokens = tokenizer.tokenize(text)
>>> print(tokens)
['15496', '995', '0']

ModelConfig

class DeBERTa.deberta.ModelConfig[source]

Configuration class to store the configuration of a DeBERTa model.

hidden_size

Size of the encoder layers and the pooler layer, default: 768.

Type:int
num_hidden_layers

Number of hidden layers in the Transformer encoder, default: 12.

Type:int
num_attention_heads

Number of attention heads for each attention layer in the Transformer encoder, default: 12.

Type:int
intermediate_size

The size of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder, default: 3072.

Type:int
hidden_act

The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu” and “swish” are supported, default: gelu.

Type:str
hidden_dropout_prob

The dropout probability for all fully connected layers in the embeddings, encoder, and pooler, default: 0.1.

Type:float
attention_probs_dropout_prob

The dropout ratio for the attention probabilities, default: 0.1.

Type:float
max_position_embeddings

The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048), default: 512.

Type:int
type_vocab_size

The vocabulary size of the token_type_ids passed into DeBERTa model, default: -1.

Type:int
initializer_range

The stddev of the _normal_initializer for initializing all weight matrices, default: 0.02.

Type:float
relative_attention

Whether to use relative position encoding, default: False.

Type:bool
max_relative_positions

The range of relative positions, [-max_position_embeddings, max_position_embeddings], default: -1, i.e. use the same value as max_position_embeddings.

Type:int
padding_idx

The value used to pad input_ids, default: 0.

Type:int
position_biased_input

Whether to add absolute position embeddings to the content embeddings, default: True.

Type:bool
pos_att_type

The type of relative position attention; it can be a combination of [p2c, c2p, p2p], e.g. “p2c”, “p2c|c2p”, “p2c|c2p|p2p”, default: “None”.

Type:str
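
Example (a hedged sketch, assuming the fields listed above can be set as plain attributes on a ModelConfig instance; the resulting config can then be passed to DeBERTa(config=config)):

from DeBERTa.deberta import ModelConfig

config = ModelConfig()
config.hidden_size = 768
config.num_hidden_layers = 12
config.num_attention_heads = 12
config.intermediate_size = 3072
config.max_position_embeddings = 512
config.relative_attention = True
config.pos_att_type = 'p2c|c2p'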

PoolConfig

class DeBERTa.deberta.PoolConfig(config=None)[source]

Configuration class to store the configuration of pool layer.

Parameters:config (ModelConfig) – The model config. The fields of the pool config will be initialized from the pooling field in the model config.
hidden_size

Size of the encoder layers and the pooler layer, default: 768.

Type:int
dropout

The dropout rate applied to the output of the [CLS] token.

Type:float
hidden_act

The activation function of the projection layer; it can be one of [‘gelu’, ‘tanh’].

Type:str

Example:

# Here is the content of an example model config file in JSON format

    {
      "hidden_size": 768,
      "num_hidden_layers" 12,
      "num_attention_heads": 12,
      "intermediate_size": 3072,
      ...
      "pooling": {
        "hidden_size":  768,
        "hidden_act": "gelu",
        "dropout": 0.1
      }
    }