import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head dimension

        # Learned projections for queries, keys, values, and the final output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)

        # Project, then split into heads: (batch, num_heads, seq_len, d_k)
        q = self.W_q(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = self.W_k(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = self.W_v(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention, with optional masking of disallowed positions
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, v)

        # Re-concatenate heads to (batch, seq_len, d_model) and apply the output projection
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(output)
        return output
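

# Minimal usage sketch. The hyperparameters and tensor shapes below are
# illustrative assumptions, not part of the original listing: self-attention
# over a batch of 2 sequences of length 10, with d_model=512 split across
# 8 heads, plus an optional causal mask (1 = attend, 0 = blocked).
if __name__ == "__main__":
    mha = MultiHeadAttention(d_model=512, num_heads=8)
    x = torch.randn(2, 10, 512)                    # (batch, seq_len, d_model)
    causal_mask = torch.tril(torch.ones(10, 10))   # broadcasts over batch and heads
    out = mha(x, x, x, mask=causal_mask)
    print(out.shape)  # torch.Size([2, 10, 512])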