Тензор индексного питора с массивом индексов другой размерности - PullRequest
0 голосов
/ 05 января 2019

У меня есть следующая функция, которая делает то, что я хочу, используя numpy.array, но прерывается при подаче torch.Tensor из-за ошибок индексации.

import torch
import numpy as np


def combination_matrix(arr):
    idxs = np.arange(len(arr))
    idx = np.ix_(idxs, idxs)
    mesh = np.stack(np.meshgrid(idxs, idxs))

    def np_combination_matrix():
        output = np.zeros((len(arr), len(arr), 2, *arr.shape[1:]), dtype=arr.dtype)
        num_dims = len(output.shape)
        output[idx] = arr[mesh].transpose((2, 1, 0, *np.arange(3, num_dims)))
        return output

    def torch_combination_matrix():
        output = torch.zeros(len(arr), len(arr), 2, *arr.shape[1:], dtype=arr.dtype)
        num_dims = len(output.shape)
        print(arr[mesh].shape)  # <-- This is wrong/different to numpy!
        output[idx] = arr[mesh].permute(2, 1, 0, *np.arange(3, num_dims))
        return output

    if isinstance(arr, np.ndarray):
        return np_combination_matrix()
    elif isinstance(arr, torch.Tensor):
        return torch_combination_matrix()

Проблема в том, что arr[mesh] приводит к различным размерам, в зависимости от куска и резака. По-видимому, pytorch не поддерживает индексирование с помощью массивов индекса, отличных от индексируемого массива. В идеале должно работать следующее:

features = np.arange(9).reshape(3, 3)
np_combs = combination_matrix(features)
features = torch.from_numpy(features)
torch_combs = combination_matrix(features)
assert np.array_equal(np_combs, torch_combs.numpy())

Но размеры бывают разные:

(2, 3, 3, 3)
torch.Size([3, 3])

Что приводит к ошибке (логически):

Traceback (most recent call last):
  File "/home/XXX/util.py", line 226, in <module>
    torch_combs = combination_matrix(features)
  File "/home/XXX/util.py", line 218, in combination_matrix
    return torch_combination_matrix()
  File "/home/XXX/util.py", line 212, in torch_combination_matrix
    output[idx] = arr[mesh].permute(2, 1, 0, *np.arange(3, num_dims))
RuntimeError: number of dims don't match in permute

Как мне сопоставить поведение факела с numpy? Я читал различные вопросы на форумах по факелам (например, этот только с одним измерением ), но мог найти здесь, как применить это. Точно так же index_select работает только для одного измерения, но мне нужно, чтобы оно работало как минимум для 2 измерений.

1 Ответ

0 голосов
/ 06 января 2019

Это на самом деле неловко легко. Вам просто нужно сгладить индексы, затем изменить форму и изменить размеры. Это полная рабочая версия:

import torch
import numpy as np


def combination_matrix(arr):
    idxs = np.arange(len(arr))
    idx = np.ix_(idxs, idxs)
    mesh = np.stack(np.meshgrid(idxs, idxs))

    def np_combination_matrix():
        output = np.zeros((len(arr), len(arr), 2, *arr.shape[1:]), dtype=arr.dtype)
        num_dims = len(output.shape)
        output[idx] = arr[mesh].transpose((2, 1, 0, *np.arange(3, num_dims)))
        return output

    def torch_combination_matrix():
        output_shape = (2, len(arr), len(arr), *arr.shape[1:])  # Note that this is different to numpy!
        return arr[mesh.flatten()].reshape(output_shape).permute(2, 1, 0, *range(3, len(output_shape)))

    if isinstance(arr, np.ndarray):
        return np_combination_matrix()
    elif isinstance(arr, torch.Tensor):
        return torch_combination_matrix()

Я использовал pytest для запуска этого на случайных массивах разных измерений, и, похоже, он работает во всех случаях:

import pytest

@pytest.mark.parametrize('random_dims', range(1, 5))
def test_combination_matrix(random_dims):
    dim_size = np.random.randint(1, 40, size=random_dims)
    elements = np.random.random(size=dim_size)
    np_combs = combination_matrix(elements)
    features = torch.from_numpy(elements)
    torch_combs = combination_matrix(features)

    assert np.array_equal(np_combs, torch_combs.numpy())

if __name__ == '__main__':
    pytest.main(['-x', __file__])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...