попробуйте grid_sample:
torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros')
С учетом входных данных и сетки поля потока вычисляет выходные данные с использованием входных значений и местоположений в пикселях из сетки.
Для каждого выходного местоположения вывод [n,:, h, w] , вектор размера 2 сетка [n, h, w] определяет местоположения входных пикселей x и y, которые используются для интерполяции выходного выходного значения [n,:, h, w]. mode аргумент указывает метод ближайшей или билинейной интерполяции для выборки входных пикселей.
координата должна находиться в диапазоне [- 1, 1] . Это связано с тем, что положения пикселей нормализуются входными пространственными измерениями.
пример сэмплера git
документация по pytorch