Я использую следующую функцию для классификации изображения с использованием обученной модели pytorch. Он отлично работает для небольшого входного изображения.
def test(net, STRIDE-16, BATCH_SIZE=20, WINDOW_SIZE= (256,256)):
# Use the network on the test image
img = (1 / 255 * np.asarray(io.imread("C:/bd/R1C1.tif"), dtype='float32'))
all_preds = []
# Switch the network to inference mode
net.eval()
pred = np.zeros(img.shape[:2] + (N_CLASSES,))
for i, coords in enumerate(tqdm(grouper(batch_size, sliding_window(img,
STRIDE, WINDOW_SIZE)))):
# Build the tensor
image_patches = [np.copy(img[x:x + w, y:y + h]).transpose((2, 0, 1)) for
x, y, w, h in coords]
image_patches = np.asarray(image_patches)
image_patches = Variable(torch.from_numpy(image_patches).cpu(),
volatile=True)
# Do the inference
outs = net(image_patches)
outs = outs.data.cpu().numpy()
# Fill in the results array
for out, (x, y, w, h) in zip(outs, coords):
out = out.transpose((1, 2, 0))
pred[x:x + w, y:y + h] += out
del (outs)
pred = np.argmax(pred, axis=-1)
all_preds.append(pred)
return all_preds
Когда я загружаю большое изображение, скажем (40k x 40k), я получаю MemoryError. Как я могу избежать этой ошибки памяти. Может быть, с помощью серии небольших изображений? Как эффективно это реализовать?