Как и в приведенном ниже коде, я пытаюсь получить среднее значение для 47 повторных выходов модели.Но это всегда из памяти.Если я уберу z_proto_class_list.append(z_proto_class)
, то это нормально.Я думаю, это потому, что память освобождается, если я не добавляю тензор.Я всегда пытаюсь сгенерировать выход 47 одновременно, но он явно потребляет больше памяти, чем мой текущий выбор.Есть ли способ решить мою текущую проблему?Спасибо.
z_proto_class_list = []
for support_input_ids, support_input_mask, support_segment_ids in dataloader:
s_z, s_pooled_output = model(support_input_ids, support_input_mask, support_segment_ids, output_all_encoded_layers=False)
sz_dim = s_z.size(-1)
index = torch.LongTensor(support_idx_list).unsqueeze(1).unsqueeze(2).expand(len(support_idx_list),1,sz_dim).cuda()
z_proto_raw = torch.gather(s_z,1,index)
z_proto_class = z_proto_raw.view(1,n_support, sz_dim).mean(1)
z_proto_class_list.append(z_proto_class)
torch.cuda.empty_cache()
z_proto = torch.cat(z_proto_class_list, 0)