pytorch xla: типы элементов операндов для Pad не совпадают - PullRequest
1 голос
/ 02 марта 2020

(отредактировано для предоставления и объяснения минимального воспроизводимого примера)

Я вижу нижеприведенную ошибку при использовании обратного хука с pytorch xla.

  • ошибка не видна когда pytorch-xla заменяется обычным pytorch (он же pytorch cuda).
  • и ошибка не обнаруживается в pytorch-xla, когда строка, копирующая градиент, закомментирована в обратном хуке.
Traceback (most recent call last):
  File "test1.py", line 30, in <module>
    l.backward()
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/tensor.py", line 198, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Error while lowering: f32[1,2,16,16]{3,2,1,0} aten::constant_pad_nd, pad=[0, -1, 0, -1, 0, 0, 0, 0], value=0
XLA builder error: Invalid argument: The element types of the operands to Pad do not match.:
Python Frames:

Минимальный код, который создает эту ошибку:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

def loss(output, target):
    l = torch.sum(output - target)
    return l

model = torch.nn.Sequential(
          # minimal model to reproduce the error
          torch.nn.ConstantPad2d((0, 1, 0, 1), 0),
          torch.nn.Conv2d(1, 2, kernel_size=(3, 3), stride=(2, 2)),
          torch.nn.ConstantPad2d((0, 1, 0, 1), 0),
          torch.nn.Conv2d(2, 2, kernel_size=(3, 3), stride=(2, 2))
        )
model = model.to(xm.xla_device())

def dummyHook(module, gradIn, gradOut):
    # error is not seen if i comment out the below line
    g = gradOut[0].cpu()
    print(str(module))

x = torch.randn((1, 1, 32, 32), device=xm.xla_device(), dtype=torch.float)
target = torch.ones((1, 2, 8, 8), device=xm.xla_device(), dtype=torch.float)

for module in model.modules():
    module.register_backward_hook(dummyHook)

y = model(x)
l = loss(y, target)
l.backward()

что может быть не так?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...