В параллельном режиме данных Pytorch, как использовать глобальный тензор? - PullRequest
0 голосов
/ 22 апреля 2019

В этом примере хотелось бы, чтобы z_proto мог быть глобальным для разных графических процессоров.Однако в режиме параллельной передачи данных он также разделяется на разные графические процессоры.Как решить такую ​​проблему?Спасибо.

class SequencePrototypeTokenClassification(nn.Module):
    def __init__(self,seq_model, label_num):
        super(SequencePrototypeTokenClassification, self).__init__()
        self.seq_model = seq_model
        self.label_num = label_num

    def forward(self, input_ids, token_type_ids, attention_mask, labels, z_proto, n_query, target_inds):
        z, _ = self.seq_model(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        z_dim = z.size(-1)
        zq = z.squeeze().view(-1, z_dim)
        dists = euclidean_dist(zq, z_proto)
        log_p_y = F.log_softmax(-dists, dim=1).view(-1, self.label_num)
        loss_val = -log_p_y.gather(1, self.target_inds).squeeze().view(-1).mean()
        _, y_hat = log_p_y.max(1)

        return loss_val, y_hat

Ответы [ 2 ]

1 голос
/ 22 апреля 2019

Исходя из приведенного выше кода, z_proto представляется одним из аргументов функции forward, а не частью модели. Поэтому простое сохранение его в tensor на основном графическом процессоре позволило бы ему иметь одинаковое значение для всех графических процессоров.

Редактировать

Исходя из документации , кажется, что DataParallel разделяет все входы для функции прямого прохода между графическими процессорами. Метод, с помощью которого вы можете обойти это, храня его как переменную класса внутри самого объекта модели. Вы можете обновить значение перед вызовом функции forward, если это не статическая переменная.

class SequencePrototypeTokenClassification(nn.Module):
    def __init__(self,seq_model, label_num):
        ...
        self.z_proto = None
        ...
        ...


#Training loop
    ...
    model.z_proto = value
    model.forward()
    ...


0 голосов
/ 22 апреля 2019

Оказывается, DataParallel будет копировать только nn.Parameter из nn.Module. Поэтому я случайно инициализировал nn.Parameter с именем z_proto в модуле и скопировал значение параметра tenors z_proto в параметр. Затем параметр реплицируется в 4 графических процессора.

...