Получена ошибка UnboundLocalError при измерении времени для torch.where. - PullRequest
0 голосов
/ 25 мая 2020

При попытке %timeit в Jupyter Notebook возникла ошибка; Без него работает нормально.

UnboundLocalError: local variable 'a' referenced before assignment

import torch
a = torch.rand(10)
b = torch.rand(10)
%timeit a = torch.where(b > 0.5, torch.tensor(0.), a)

Что здесь происходит?

1 Ответ

1 голос
/ 25 мая 2020

Сначала я подумал, что %timeit оценивает только время, выполняемое функциями. Но спасибо @Shiva, который сказал мне, что он может вычислять время выполнения других вещей. И я проверил документацию здесь и выяснил, что это правда.

Итак, согласно этому ответ , %timeit имеет проблему с повторным назначением поскольку повторное присвоение a заставляет функцию иметь локальную переменную a, скрывая глобальную. Другими словами, вы можете использовать любую другую переменную, кроме a, чтобы присвоить ее torch.where:

#this works
%timeit c = torch.where(b > 0.5, torch.tensor(0.), a) #c instead of a

# this works
%timeit torch.where(b > 0.5, torch.tensor(0.), a)

# this doesn't work
%timeit a = torch.where(b > 0.5, torch.tensor(0.), a)
...