Я написал следующие коды для обеспечения парного внимания на c уровне. `
def bi_linear_attn(self, Q, K):
"""
:param Q: 1*b*d
:param K: n*b*d
:return: (for relative part and irrelative part)
"""
v = K.permute(1, 2, 0)
# b*1*n && b*s*n
# here!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
interaction = torch.bmm(Q.permute(1, 0, 2), torch.bmm(
self.bi_linear_para.unsqueeze(0).expand(v.shape[0], *self.bi_linear_para.shape), v))
print(interaction.shape)
direct_score = F.softmax(interaction, dim=2)
print(direct_score.shape)
inverse_score = 1 - direct_score
h_dir = torch.bmm(direct_score, K.permute(1, 0, 2))
h_inv = torch.bmm(inverse_score, K.permute(1, 0, 2))
# b * 1 * d
return h_dir, h_inv
def forward(self, bid2emb):
blk_num = len(bid2emb)
tuple_lst = sorted([(bid, emb_tuple[0], emb_tuple[1]) for bid, emb_tuple in bid2emb.items()])
id_zip, h_zip, c_zip = zip(*tuple_lst)
h_stack = torch.cat([h.repeat(1, blk_num, 1) for h in list(h_zip)], dim=1)
# -- pad context
c_lst_padded = []
stc_len = [c.shape[0] for c in list(c_zip)]
max_stc_len = max(stc_len)
for c in list(c_zip):
if c.shape[0] < max_stc_len:
c_pad = F.pad(c, (0, 0, 0, 0, 0, max_stc_len - c.shape[0]), 'constant', 0.)
else:
c_pad = c
c_lst_padded.append(c_pad)
c_stack = torch.cat(c_lst_padded * blk_num, dim=1)
h_dir, h_inv = self.bi_linear_attn(h_stack, c_stack)
# -- get ind
AB_lst, BA_lst = [], []
for bid_A, emb_tuple_A in bid2emb.items():
for bid_B, emb_tuple_B in bid2emb.items():
if bid_A == bid_B:
continue
AB_lst.append(bid_A * blk_num + bid_B)
BA_lst.append(bid_B * blk_num + bid_A)
AB_idx = torch.tensor(AB_lst).to(device)
BA_idx = torch.tensor(BA_lst).to(device)
# n * 1 * d
h_dir = torch.mean(
torch.cat((torch.index_select(h_dir, 0, AB_idx), torch.index_select(h_dir, 0, BA_idx)), dim=1), dim=1,
keepdim=True)
h_inv = torch.mean(
torch.cat((torch.index_select(h_inv, 0, AB_idx), torch.index_select(h_inv, 0, BA_idx)), dim=1), dim=1,
keepdim=True)
print('hdir', h_dir.shape)
return torch.cat((h_dir, h_inv), dim=2)`
В строке 10 я попытался вычислить оценку взаимодействия с x1 A x2
, где x1 и x2 имеют форму 36000 * 1 * 512 и 36000 * max_len * 512. Из-за большого размера пакета matmul не удалось из-за ошибки памяти, после изменения matmul на bmm с расширенным параметром matri c, прямое вычисление работает отлично. Однако ошибка памяти появляется при выполнении обратного распространения. Что меня смущает.
Я был бы очень признателен за любую помощь по этому вопросу, поскольку я полностью озадачен