дифференциация aten :: mse_loss не поддерживается или отсутствует необходимая информация о типе - PullRequest
0 голосов
/ 14 мая 2019

Я использую torch_xla в Google Colab

Моя сеть представляет собой простую сеть, подобную этой:

class EmbeddingNet(nn.Module):


 def __init__(self, n_users, n_movies,
             n_factors=50, embedding_dropout=0.02, 
            dropouts=0.2):

    super().__init__()


    self.u = nn.Embedding(n_users, n_factors)
    self.m = nn.Embedding(n_movies, n_factors)
    self.drop = nn.Dropout(embedding_dropout)

    self.fc = nn.Linear(n_factors, 1)


 def forward(self, x):
    x = torch.cat([self.u(x[:,0])*self.m(x[:,1])],1)
    x = self.drop(x)
    x = torch.relu(self.fc(x))
    return x

Инициализируйте сеть:

net = EmbeddingNet(
n_users=n, n_movies=m, 
n_factors=50)

print(net)

EmbeddingNet(
  (u): Embedding(138493, 50)
  (m): Embedding(26744, 50)
  (drop): Dropout(p=0.02)
  (fc): Linear(in_features=50, out_features=1, bias=True)
)

Затем я попыталсячтобы преобразовать его в модель xla:

devices = [':{}'.format(n) for n in range(0, 8)]
inputs = torch.zeros(2000,2).long()
target = torch.zeros(2000, dtype=torch.float)
import torch_xla_py.xla_model as xm
xla_model = xm.XlaModel(
      net, [inputs],
      loss_fn=F.mse_loss,
      target=target,
      num_cores=8,
      devices=devices)

Тогда я получаю ошибку:

    /usr/local/lib/python3.6/dist-packages/torch_xla_py/xla_model.py in __init__(self, model, inputs, target, loss_fn, num_cores, devices, loader_prefetch, full_conv_precision)
    496           devices=devices,
    497           input_gradients=loss_output_grads,
--> 498           full_conv_precision=full_conv_precision)
    499     else:
    500       self._xla_model, self._traced_model = create_xla_model(

/usr/local/lib/python3.6/dist-packages/torch_xla_py/xla_model.py in create_xla_model(model, inputs, num_cores, devices, input_gradients, full_conv_precision)
    235   if input_gradients is not None:
    236     xla_model.set_input_gradients(input_gradients)
--> 237   xla_model(*inputs_xla)
    238   return xla_model, traced_model
    239 

RuntimeError: differentiation of aten::mse_loss is not supported, or it is missing necessary type information

Я попытался l1_loss, который дает мне тот же результат.

Я такжепопробовал альтернативный способ его инициализации:

import torch_xla
traced_model = torch.jit.trace(net, (inputs, target))
xla_model = torch_xla._XLAC.XlaModule(traced_model)
devices = [':{}'.format(n) for n in range(0, 8)]
inputs = torch.zeros(2000,2).long()
target = torch.zeros(2000, dtype=torch.float)
import torch_xla
traced_model = torch.jit.trace(net, (inputs, target))
xla_model = torch_xla._XLAC.XlaModule(traced_model)
output_xla = xla_model((torch_xla._XLAC.XLATensor(inputs), torch_xla._XLAC.XLATensor(target)))

Возвращает ту же ошибку: дифференциация aten :: mse_loss не поддерживается или отсутствует необходимая информация о типе.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...