Классификация патчей с большого изображения с использованием pytorch - PullRequest
0 голосов
/ 29 марта 2019

Я использую следующую функцию для классификации изображения с использованием обученной модели pytorch. Он отлично работает для небольшого входного изображения.

def test(net, STRIDE-16, BATCH_SIZE=20, WINDOW_SIZE= (256,256)):

# Use the network on the test image
img = (1 / 255 * np.asarray(io.imread("C:/bd/R1C1.tif"), dtype='float32'))


all_preds = []

# Switch the network to inference mode
net.eval()

pred = np.zeros(img.shape[:2] + (N_CLASSES,))


for i, coords in enumerate(tqdm(grouper(batch_size, sliding_window(img, 
STRIDE, WINDOW_SIZE)))):

# Build the tensor
image_patches = [np.copy(img[x:x + w, y:y + h]).transpose((2, 0, 1)) for 
x, y, w, h in coords]

image_patches = np.asarray(image_patches)
image_patches = Variable(torch.from_numpy(image_patches).cpu(), 
volatile=True)

 # Do the inference
 outs = net(image_patches)
 outs = outs.data.cpu().numpy()

 # Fill in the results array
 for out, (x, y, w, h) in zip(outs, coords):
     out = out.transpose((1, 2, 0))
     pred[x:x + w, y:y + h] += out
  del (outs)

pred = np.argmax(pred, axis=-1)

all_preds.append(pred)


return all_preds

Когда я загружаю большое изображение, скажем (40k x 40k), я получаю MemoryError. Как я могу избежать этой ошибки памяти. Может быть, с помощью серии небольших изображений? Как эффективно это реализовать?

...