# Source code for DeBERTa.deberta.deberta

# Copyright (c) Microsoft, Inc. 2020
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Author:
# Date: 01/15/2020

import copy
import torch
import os

import json
from .ops import *
from .bert import *
from .config import ModelConfig
from .cache_utils import load_model_state

__all__ = ['DeBERTa']

class DeBERTa(torch.nn.Module):
    """DeBERTa encoder.

    This module is composed of the input embedding layer with stacked
    transformer layers with disentangled attention.

    Parameters:
        config:
            A model config class instance with the configuration to build a
            new model. The schema is similar to `BertConfig`; for more
            details, please refer to :class:`~DeBERTa.deberta.ModelConfig`.

        pre_trained:
            The pre-trained DeBERTa model. It can be a physical path of a
            pre-trained DeBERTa model or a released configuration,
            i.e. [**base, large, base_mnli, large_mnli**].
    """

    # Config attributes that always come from the pre-trained checkpoint's
    # own config: a user-supplied `config` never overrides these.
    _PRETRAINED_ONLY_KEYS = (
        'hidden_size',
        'intermediate_size',
        'num_attention_heads',
        'num_hidden_layers',
        'vocab_size',
        'max_position_embeddings',
    )

    def __init__(self, config=None, pre_trained=None):
        super().__init__()
        # Number of extra passes through the last encoder layer (see
        # `forward`); disabled (0) when no config is supplied.
        self.z_steps = getattr(config, 'z_steps', 0) if config else 0

        state = None
        if pre_trained is not None:
            state, model_config = load_model_state(pre_trained)
            if config is not None and model_config is not None:
                # Overlay the caller's settings onto the checkpoint config,
                # except for the protected keys listed above.
                for k in config.__dict__:
                    if k not in self._PRETRAINED_ONLY_KEYS:
                        model_config.__dict__[k] = config.__dict__[k]
            config = copy.copy(model_config)

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.config = config
        self.pre_trained = pre_trained
        self.apply_state(state)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                output_all_encoded_layers=True, position_ids=None,
                return_att=False):
        """Run the embedding layer and the stacked transformer layers.

        Args:
            input_ids:
                a torch.LongTensor of shape [batch_size, sequence_length]
                with the word token indices in the vocabulary.

            attention_mask:
                an optional parameter for input mask or attention mask.
                Defaults to all ones (shaped like `input_ids`) when omitted.

                - If it's an input mask, then it will be torch.LongTensor of
                  shape [batch_size, sequence_length] with indices selected
                  in [0, 1]. It's a mask to be used if the input sequence
                  length is smaller than the max input sequence length in the
                  current batch. It's the mask that we typically use for
                  attention when a batch has varying length sentences.

                - If it's an attention mask then it will be torch.LongTensor
                  of shape [batch_size, sequence_length, sequence_length].
                  In this case, it's a mask indicating which tokens in the
                  sequence should be attended by other tokens in the
                  sequence.

            token_type_ids:
                an optional torch.LongTensor of shape
                [batch_size, sequence_length] with the token type indices
                selected in [0, 1]. Type 0 corresponds to a `sentence A` and
                type 1 corresponds to a `sentence B` token (see BERT paper
                for more details). Defaults to all zeros (shaped like
                `input_ids`) when omitted.

            output_all_encoded_layers:
                whether to output results of all encoder layers,
                default, True.

            position_ids:
                optional value forwarded unchanged to the embedding layer,
                default, None.

            return_att:
                whether to also return the attention matrices produced by
                the encoder, default, False.

        Returns:

            - The output of the stacked transformer layers if
              `output_all_encoded_layers=True`, else the last layer of
              stacked transformer layers (as a one-element slice).

            - Attention matrix of self-attention layers if `return_att=True`.

        Example::

            # Batch of wordPiece token ids.
            # Each sample was padded with zero to the maximum length of the batch
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            # Mask of valid input ids
            attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

            # DeBERTa model initialized with pretrained base model
            bert = DeBERTa(pre_trained='base')

            encoder_layers = bert(input_ids, attention_mask=attention_mask)
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        embedding_output = self.embeddings(
            input_ids.to(torch.long),
            token_type_ids.to(torch.long),
            position_ids,
            attention_mask,
        )
        encoded_layers = self.encoder(
            embedding_output,
            attention_mask,
            output_all_encoded_layers=output_all_encoded_layers,
            return_att=return_att,
        )
        if return_att:
            encoded_layers, att_matrixs = encoded_layers

        if self.z_steps > 1:
            # Re-apply the last encoder layer: keys/values stay fixed at the
            # second-to-last layer's output while the query is refined.
            # NOTE(review): `encoded_layers[-2]` assumes the encoder returned
            # at least two layers -- confirm this holds when
            # `output_all_encoded_layers=False`.
            hidden_states = encoded_layers[-2]
            layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
            query_states = encoded_layers[-1]
            rel_embeddings = self.encoder.get_rel_embedding()
            attention_mask = self.encoder.get_attention_mask(attention_mask)
            rel_pos = self.encoder.get_rel_pos(embedding_output)
            # The first pass was already performed by the encoder itself,
            # hence only `z_steps - 1` additional passes here.
            for layer in layers[1:]:
                query_states = layer(
                    hidden_states,
                    attention_mask,
                    return_att=False,
                    query_states=query_states,
                    relative_pos=rel_pos,
                    rel_embeddings=rel_embeddings,
                )
                encoded_layers.append(query_states)

        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1:]

        if return_att:
            return encoded_layers, att_matrixs
        return encoded_layers

    def apply_state(self, state=None):
        """Load state from previous loaded model state dictionary.

        Args:
            state (:obj:`dict`, optional): State dictionary as the state
                returned by torch.module.state_dict(), default: `None`.
                If it's `None`, then will use the pre-trained state loaded
                via the constructor to re-initialize the `DeBERTa` model.

        Raises:
            AssertionError: if a parameter name of this module is not
                contained in exactly one key of `state`.
        """
        if self.pre_trained is None and state is None:
            return
        if state is None:
            state, config = load_model_state(self.pre_trained)
            self.config = config

        def key_match(key, s):
            # Substring match so that keys carrying an extra prefix in the
            # checkpoint still resolve; the match must be unambiguous.
            c = [k for k in s if key in k]
            assert len(c) == 1, c
            return c[0]

        # Materialize the key list once instead of once per parameter.
        state_keys = list(state.keys())
        current = self.state_dict()
        for c in current.keys():
            current[c] = state[key_match(c, state_keys)]
        self.load_state_dict(current)