У меня есть следующая функция, которая делает то, что я хочу, используя numpy.array
, но прерывается при подаче torch.Tensor
из-за ошибок индексации.
import torch
import numpy as np
def combination_matrix(arr):
idxs = np.arange(len(arr))
idx = np.ix_(idxs, idxs)
mesh = np.stack(np.meshgrid(idxs, idxs))
def np_combination_matrix():
output = np.zeros((len(arr), len(arr), 2, *arr.shape[1:]), dtype=arr.dtype)
num_dims = len(output.shape)
output[idx] = arr[mesh].transpose((2, 1, 0, *np.arange(3, num_dims)))
return output
def torch_combination_matrix():
output = torch.zeros(len(arr), len(arr), 2, *arr.shape[1:], dtype=arr.dtype)
num_dims = len(output.shape)
print(arr[mesh].shape) # <-- This is wrong/different to numpy!
output[idx] = arr[mesh].permute(2, 1, 0, *np.arange(3, num_dims))
return output
if isinstance(arr, np.ndarray):
return np_combination_matrix()
elif isinstance(arr, torch.Tensor):
return torch_combination_matrix()
Проблема в том, что arr[mesh]
приводит к различным размерам, в зависимости от куска и резака. По-видимому, pytorch не поддерживает индексирование с помощью массивов индекса, отличных от индексируемого массива. В идеале должно работать следующее:
features = np.arange(9).reshape(3, 3)
np_combs = combination_matrix(features)
features = torch.from_numpy(features)
torch_combs = combination_matrix(features)
assert np.array_equal(np_combs, torch_combs.numpy())
Но размеры бывают разные:
(2, 3, 3, 3)
torch.Size([3, 3])
Что приводит к ошибке (логически):
Traceback (most recent call last):
File "/home/XXX/util.py", line 226, in <module>
torch_combs = combination_matrix(features)
File "/home/XXX/util.py", line 218, in combination_matrix
return torch_combination_matrix()
File "/home/XXX/util.py", line 212, in torch_combination_matrix
output[idx] = arr[mesh].permute(2, 1, 0, *np.arange(3, num_dims))
RuntimeError: number of dims don't match in permute
Как мне сопоставить поведение факела с numpy?
Я читал различные вопросы на форумах по факелам (например, этот только с одним измерением ), но мог найти здесь, как применить это. Точно так же index_select работает только для одного измерения, но мне нужно, чтобы оно работало как минимум для 2 измерений.