CUDA не хватает памяти во время loss.backward (), torch.bmm отлично работает при выполнении прямого вычисления - PullRequest
1 голос
/ 18 апреля 2020

Я написал следующие коды для обеспечения парного внимания на c уровне. `

def bi_linear_attn(self, Q, K):
    """
    :param Q: 1*b*d
    :param K: n*b*d
    :return:  (for relative part and irrelative part)
    """
    v = K.permute(1, 2, 0)
    # b*1*n && b*s*n
    # here!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    interaction = torch.bmm(Q.permute(1, 0, 2), torch.bmm(
        self.bi_linear_para.unsqueeze(0).expand(v.shape[0], *self.bi_linear_para.shape), v))

    print(interaction.shape)
    direct_score = F.softmax(interaction, dim=2)
    print(direct_score.shape)
    inverse_score = 1 - direct_score
    h_dir = torch.bmm(direct_score, K.permute(1, 0, 2))
    h_inv = torch.bmm(inverse_score, K.permute(1, 0, 2))
    # b * 1 * d
    return h_dir, h_inv

def forward(self, bid2emb):
    blk_num = len(bid2emb)
    tuple_lst = sorted([(bid, emb_tuple[0], emb_tuple[1]) for bid, emb_tuple in bid2emb.items()])
    id_zip, h_zip, c_zip = zip(*tuple_lst)
    h_stack = torch.cat([h.repeat(1, blk_num, 1) for h in list(h_zip)], dim=1)
    # -- pad context
    c_lst_padded = []
    stc_len = [c.shape[0] for c in list(c_zip)]
    max_stc_len = max(stc_len)
    for c in list(c_zip):
        if c.shape[0] < max_stc_len:
            c_pad = F.pad(c, (0, 0, 0, 0, 0, max_stc_len - c.shape[0]), 'constant', 0.)
        else:
            c_pad = c
        c_lst_padded.append(c_pad)
    c_stack = torch.cat(c_lst_padded * blk_num, dim=1)
    h_dir, h_inv = self.bi_linear_attn(h_stack, c_stack)
    # -- get ind
    AB_lst, BA_lst = [], []
    for bid_A, emb_tuple_A in bid2emb.items():
        for bid_B, emb_tuple_B in bid2emb.items():
            if bid_A == bid_B:
                continue
            AB_lst.append(bid_A * blk_num + bid_B)
            BA_lst.append(bid_B * blk_num + bid_A)
    AB_idx = torch.tensor(AB_lst).to(device)
    BA_idx = torch.tensor(BA_lst).to(device)
    # n * 1 * d
    h_dir = torch.mean(
        torch.cat((torch.index_select(h_dir, 0, AB_idx), torch.index_select(h_dir, 0, BA_idx)), dim=1), dim=1,
        keepdim=True)
    h_inv = torch.mean(
        torch.cat((torch.index_select(h_inv, 0, AB_idx), torch.index_select(h_inv, 0, BA_idx)), dim=1), dim=1,
        keepdim=True)
    print('hdir', h_dir.shape)
    return torch.cat((h_dir, h_inv), dim=2)` 

В строке 10 я попытался вычислить оценку взаимодействия с x1 A x2, где x1 и x2 имеют форму 36000 * 1 * 512 и 36000 * max_len * 512. Из-за большого размера пакета matmul не удалось из-за ошибки памяти, после изменения matmul на bmm с расширенным параметром matri c, прямое вычисление работает отлично. Однако ошибка памяти появляется при выполнении обратного распространения. Что меня смущает.

Я был бы очень признателен за любую помощь по этому вопросу, поскольку я полностью озадачен

...