# Copyright (c) Microsoft, Inc. 2020
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: penhe@microsoft.com
# Date: 01/15/2020
#
"""
Disentangled SelfAttention module
"""
import torch
import math
from .ops import *
# Names exported by `from <this module> import *`; the jit-scripted *_dynamic_expand
# helpers below are deliberately not listed here.
__all__=['build_relative_position', 'DisentangledSelfAttention']
def build_relative_position(query_size, key_size, device=None):
    """ Build relative position according to the query and key

    We assume the absolute position of query :math:`P_q` ranges over [0, query_size) and the
    absolute position of key :math:`P_k` ranges over [0, key_size).
    The relative position from query to key is

    :math:`R_{q \\rightarrow k} = P_q - P_k`

    Args:
        query_size (int): the length of query
        key_size (int): the length of key
        device (:obj:`torch.device` or str, optional): device on which the result is allocated;
            ``None`` uses torch's current default device.

    Return:
        :obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size] whose element
        ``[0, i, j]`` equals ``i - j``.
    """
    q_ids = torch.arange(query_size, dtype=torch.long, device=device)
    k_ids = torch.arange(key_size, dtype=torch.long, device=device)
    # (q, 1) - (1, k) broadcasts to the full (q, k) matrix; no explicit repeat/copy is needed.
    rel_pos_ids = q_ids[:, None] - k_ids[None, :]
    # Leading singleton dimension so the result broadcasts over the batch.
    return rel_pos_ids.unsqueeze(0)
@torch.jit.script
def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
    """Broadcast the content->position gather index to the score shape.

    The target shape is the first three dimensions of ``query_layer`` followed by the
    last dimension of ``relative_pos``. Returns an expanded view of ``c2p_pos``; no data is copied.
    """
    batch = query_layer.size(0)
    heads = query_layer.size(1)
    q_len = query_layer.size(2)
    k_len = relative_pos.size(-1)
    target_shape = [batch, heads, q_len, k_len]
    return c2p_pos.expand(target_shape)
@torch.jit.script
def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
    """Broadcast the position->content gather index to a square key-by-key score shape.

    The target shape is ``[query_layer.size(0), query_layer.size(1), Lk, Lk]`` with
    ``Lk = key_layer.size(-2)``. Despite its name, the ``c2p_pos`` argument is the
    position->content index (the attention module passes ``p2c_pos`` here); the name is
    kept so existing callers continue to work.
    """
    k_len = key_layer.size(-2)
    target_shape = [query_layer.size(0), query_layer.size(1), k_len, k_len]
    return c2p_pos.expand(target_shape)
@torch.jit.script
def pos_dynamic_expand(pos_index, p2c_att, key_layer):
    """Broadcast a per-query row index for gathering rows out of the p2c scores.

    Keeps the leading two dimensions of ``p2c_att``, the second-to-last dimension of
    ``pos_index``, and uses ``key_layer.size(-2)`` as the last dimension.
    """
    batch_and_heads = p2c_att.size()[:2]
    target_shape = batch_and_heads + (pos_index.size(-2), key_layer.size(-2))
    return pos_index.expand(target_shape)
class DisentangledSelfAttention(torch.nn.Module):
    """ Disentangled self-attention module

    Parameters:
        config (:obj:`DeBERTa.deberta.ModelConfig`):
            A model config class instance with the configuration to build a new model. The schema is
            similar to `BertConfig`; for more details, please refer to :class:`~DeBERTa.deberta.ModelConfig`.

            Attributes read here: ``hidden_size``, ``num_attention_heads`` and
            ``attention_probs_dropout_prob`` (required); ``pos_att_type`` (a ``'|'``-separated string such
            as ``'c2p|p2c'`` or a list of names, default ``'none'``), ``relative_attention`` (default
            ``False``) and ``talking_head`` (default ``False``). When ``relative_attention`` is enabled,
            ``max_relative_positions`` (values < 1 fall back to ``max_position_embeddings``) and
            ``hidden_dropout_prob`` are read as well.
    """
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # Fused Q/K/V projection without bias. Its output features are laid out per head as
        # [q | k | v]; _project_qkv relies on this layout in both of its code paths.
        self.in_proj = torch.nn.Linear(config.hidden_size, self.all_head_size*3, bias=False)
        # Separate biases for query and value only (keys get no bias).
        self.q_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
        self.v_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
        self.pos_att_type = self._parse_pos_att_type(getattr(config, 'pos_att_type', 'none'))  # c2p|p2c

        self.relative_attention = getattr(config, 'relative_attention', False)
        self.talking_head = getattr(config, 'talking_head', False)

        if self.talking_head:
            # Linear maps over the head dimension, applied to the logits before the softmax
            # and to the probabilities after dropout (see forward).
            self.head_logits_proj = torch.nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
            self.head_weights_proj = torch.nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)

        if self.relative_attention:
            self.max_relative_positions = getattr(config, 'max_relative_positions', -1)
            if self.max_relative_positions < 1:
                self.max_relative_positions = config.max_position_embeddings
            self.pos_dropout = StableDropout(config.hidden_dropout_prob)

            if 'c2p' in self.pos_att_type or 'p2p' in self.pos_att_type:
                self.pos_proj = torch.nn.Linear(config.hidden_size, self.all_head_size, bias=False)
            if 'p2c' in self.pos_att_type or 'p2p' in self.pos_att_type:
                self.pos_q_proj = torch.nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = StableDropout(config.attention_probs_dropout_prob)

    @staticmethod
    def _parse_pos_att_type(pos_att_type):
        """Normalise ``config.pos_att_type`` into a list of stripped, lower-case names.

        Accepts a ``'|'``-separated string (e.g. ``'c2p|p2c'``), any iterable of names, or
        ``None``, which is treated as ``'none'``.
        """
        if pos_att_type is None:
            pos_att_type = 'none'
        if isinstance(pos_att_type, str):
            pos_att_type = pos_att_type.split('|')
        return [str(x).strip().lower() for x in pos_att_type]

    def _scale_factor(self):
        """Return 1 (content->content) plus one for each of c2p/p2c/p2p listed in ``pos_att_type``.

        This is the count used to scale the attention scores in forward and the position queries
        in disentangled_att_bias.
        """
        return 1 + sum(1 for term in ('c2p', 'p2c', 'p2p') if term in self.pos_att_type)

    def transpose_for_scores(self, x):
        """Split the last dimension into ``num_attention_heads`` heads and move heads to dim 1.

        A ``[d0, d1, F]`` input becomes ``[d0, num_attention_heads, d1, F // num_attention_heads]``.
        The permute indices require a 3-D input.
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def _project_qkv(self, hidden_states, query_states=None):
        """Project the inputs into per-head query, key and value tensors, with biases applied.

        Keys and values always come from ``hidden_states``. Queries come from ``query_states``
        when it is given, otherwise from ``hidden_states``. Both paths use the same ``in_proj``
        weights, so they agree when ``query_states`` equals ``hidden_states``.

        Returns:
            tuple: ``(query_layer, key_layer, value_layer)``, each laid out as produced by
            ``transpose_for_scores`` (presumably ``[B, H, N, head_size]`` for a ``[B, N, hidden]``
            input; TODO confirm with callers).
        """
        if query_states is None:
            qp = self.in_proj(hidden_states)
            # After the head split, each head's features are [q | k | v]; chunk(3) separates them.
            query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1)
        else:
            # Regroup the fused weight rows (per head: q, k, v blocks of head_size rows each)
            # into three standalone projection matrices so Q can use different inputs than K/V.
            ws = self.in_proj.weight.chunk(self.num_attention_heads*3, dim=0)
            q_w, k_w, v_w = [torch.cat(ws[kind::3], dim=0) for kind in range(3)]
            query_layer = self.transpose_for_scores(torch.matmul(query_states, q_w.t()))
            key_layer = self.transpose_for_scores(torch.matmul(hidden_states, k_w.t()))
            value_layer = self.transpose_for_scores(torch.matmul(hidden_states, v_w.t()))

        # Out-of-place on purpose: chunk() returns multi-output views, and autograd raises if such
        # views are modified in place while gradients are being tracked.
        query_layer = query_layer + self.transpose_for_scores(self.q_bias.unsqueeze(0).unsqueeze(0))
        value_layer = value_layer + self.transpose_for_scores(self.v_bias.unsqueeze(0).unsqueeze(0))
        return query_layer, key_layer, value_layer

    def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
        """ Call the module

        Args:
            hidden_states (:obj:`torch.FloatTensor`):
                Input states to the module, usually the output from the previous layer. They provide
                the K and V (and Q, unless ``query_states`` is given) in `Attention(Q,K,V)`.

            attention_mask (:obj:`torch.ByteTensor`):
                An attention mask matrix of shape [`B`, `N`, `N`] where `B` is the batch size and `N`
                is the maximum sequence length, in which element [i,j] = `1` means the `i` th token in
                the input can attend to the `j` th token. It is passed unchanged to ``XSoftmax``.

            return_att (:obj:`bool`, optional):
                Whether to also return the attention probability matrix.

            query_states (:obj:`torch.FloatTensor`, optional):
                The `Q` state in `Attention(Q,K,V)`.

            relative_pos (:obj:`torch.LongTensor`, optional):
                The relative position ids between the tokens in the sequence, as a 2-D [`N`, `N`],
                3-D [`B`, `N`, `N`] or 4-D tensor, with values ranging in
                [`-max_relative_positions`, `max_relative_positions`]. If ``None`` and relative
                attention is enabled, it is built from the query/key lengths.

            rel_embeddings (:obj:`torch.FloatTensor`):
                The embedding of relative distances. It's a tensor of shape
                [:math:`2 \\times \\text{max_relative_positions}`, `hidden_size`]. It is only used (and
                then required) when ``config.relative_attention`` is enabled.

        Returns:
            The context layer (presumably [`B`, `N`, `all_head_size`]), or the tuple
            ``(context_layer, attention_probs)`` when ``return_att`` is true.
        """
        query_layer, key_layer, value_layer = self._project_qkv(hidden_states, query_states)

        # Every additional attention term (c2p/p2c/p2p) enlarges the softmax temperature.
        scale_factor = self._scale_factor()
        scale = math.sqrt(query_layer.size(-1)*scale_factor)
        query_layer = query_layer/scale
        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        if self.relative_attention:
            rel_embeddings = self.pos_dropout(rel_embeddings)
            rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
            attention_scores = attention_scores + rel_att

        # The talking-head projections act on the head axis, so move heads last, project, and move back.
        if self.talking_head:
            attention_scores = self.head_logits_proj(attention_scores.permute(0,2,3,1)).permute(0,3,1,2)

        attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
        attention_probs = self.dropout(attention_probs)
        if self.talking_head:
            attention_probs = self.head_weights_proj(attention_probs.permute(0,2,3,1)).permute(0,3,1,2)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Undo transpose_for_scores: move heads back next to the feature axis and merge them.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (-1,)
        context_layer = context_layer.view(*new_context_layer_shape)
        if return_att:
            return (context_layer, attention_probs)
        else:
            return context_layer

    def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
        """Compute the relative-position bias that forward adds to the content attention scores.

        Args:
            query_layer: per-head queries, already divided by the attention scale in forward
                (presumably [B, H, Lq, head_size]; TODO confirm).
            key_layer: per-head keys (presumably [B, H, Lk, head_size]; TODO confirm).
            relative_pos: relative position ids as a 2-D, 3-D or 4-D tensor, or ``None`` to build
                them with ``build_relative_position`` from the query/key lengths.
            rel_embeddings: relative distance embeddings; rows are taken around index
                ``max_relative_positions``.
            scale_factor: the term count from forward, used to scale the position queries.

        Returns:
            The sum of the enabled c2p and p2c terms, or the int ``0`` when neither is enabled.
            Note that ``'p2p'`` only affects the projections and scale factor; no p2p term is added.

        Raises:
            ValueError: if ``relative_pos`` is not 2-, 3- or 4-dimensional.
        """
        if relative_pos is None:
            q = query_layer.size(-2)
            relative_pos = build_relative_position(q, key_layer.size(-2), query_layer.device)
        # Normalise relative_pos to 4-D so it broadcasts against the 4-D score tensors.
        if relative_pos.dim()==2:
            relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
        elif relative_pos.dim()==3:
            relative_pos = relative_pos.unsqueeze(1)
        # bxhxqxk
        elif relative_pos.dim()!=4:
            raise ValueError(f'Relative postion ids must be of dim 2 or 3 or 4. {relative_pos.dim()}')

        att_span = min(max(query_layer.size(-2), key_layer.size(-2)), self.max_relative_positions)
        relative_pos = relative_pos.long().to(query_layer.device)
        # Keep only the 2*att_span embedding rows centred on index max_relative_positions.
        rel_embeddings = rel_embeddings[self.max_relative_positions - att_span:self.max_relative_positions + att_span, :].unsqueeze(0)
        if 'c2p' in self.pos_att_type or 'p2p' in self.pos_att_type:
            pos_key_layer = self.pos_proj(rel_embeddings)
            pos_key_layer = self.transpose_for_scores(pos_key_layer)

        if 'p2c' in self.pos_att_type or 'p2p' in self.pos_att_type:
            pos_query_layer = self.pos_q_proj(rel_embeddings)
            pos_query_layer = self.transpose_for_scores(pos_query_layer)

        score = 0
        # content->position
        if 'c2p' in self.pos_att_type:
            c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
            # Shift relative distances into [0, 2*att_span) to index the sliced embedding rows.
            c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span*2-1)
            c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_dynamic_expand(c2p_pos, query_layer, relative_pos))
            score += c2p_att

        # position->content
        if 'p2c' in self.pos_att_type or 'p2p' in self.pos_att_type:
            # Out-of-place so the projection output is not mutated through a view.
            pos_query_layer = pos_query_layer / math.sqrt(pos_query_layer.size(-1)*scale_factor)
            if query_layer.size(-2) != key_layer.size(-2):
                # p2c is computed over a square key-by-key grid when query and key lengths differ.
                r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device)
            else:
                r_pos = relative_pos
            p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span*2-1)
            if query_layer.size(-2) != key_layer.size(-2):
                # Row index used below to pick the query rows out of the square p2c grid.
                pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)

        if 'p2c' in self.pos_att_type:
            p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
            p2c_att = torch.gather(p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)).transpose(-1,-2)
            if query_layer.size(-2) != key_layer.size(-2):
                p2c_att = torch.gather(p2c_att, dim=-2, index=pos_dynamic_expand(pos_index, p2c_att, key_layer))
            score += p2c_att

        return score