Обновите субплоты Pyplot изображениями в цикле - PullRequest
0 голосов
/ 09 июня 2019

Я думаю, что делаю что-то действительно глупое, но я не могу понять это. Я хочу создать класс для отображения набора изображений в качестве вспомогательных участков; дисплей должен обновляться вручную из цикла. Вот класс, который я создал, чтобы попытаться сделать это:

import matplotlib.pyplot as plt
import numpy as np

class tensor_plot:
    def __init__(self, tensor_shape, nrows=1):
        self.img_height, self.img_width, self.num_imgs = tensor_shape
        self.nrows = nrows
        self.ncols = self.num_imgs // nrows
        assert(self.ncols*self.nrows == self.num_imgs)
        self.fig, self.a = plt.subplots(self.nrows, self.ncols, sharex='col', sharey='row')
        for (row, col) in zip(range(self.nrows), range(self.ncols)):
            self.a[row, col] = plt.imshow(np.zeros([self.img_height, self.img_width]))

    def update(self, tensor):
        n=0
        for row in range(self.nrows):
            for col in range(self.ncols):
                self.a[row,col].set_data(tensor[:,:,n].squeeze())
                n += 1
        plt.show()

Когда я пытаюсь передать тензор для обновления, он говорит, что атрибут set_data отсутствует. Но при использовании dir есть такой атрибут.

In [322]: tp = tensor_plot(l10.shape, 4)

In [323]: tp.update(l10)
AttributeError: 'AxesSubplot' object has no attribute 'set_data'


In [324]: dir(tp.a[0,0])
Out[324]: 
['_A',
...
 'set_data',
...
 'update_from',
 'write_png',
 'zorder']

Если я добавлю строку print(dir(self.a[row,col])) в цикл, то правда, что set_data там нет! Тот же комментарий относится к imshow.

Есть идеи?

1 Ответ

0 голосов
/ 10 июня 2019

Большое спасибо @ImportanceOfBeingEarnest, вот последний код, который работает для меня (в случае, если он полезен для других).

class tensor_plot:
    def __init__(self, tensor_shape, nrows=1):
        self.img_height, self.img_width, self.num_imgs = tensor_shape
        self.nrows = nrows
        self.ncols = self.num_imgs // nrows
        assert(self.ncols*self.nrows == self.num_imgs)
        self.fig, self.a = plt.subplots(self.nrows, self.ncols, sharex='col', sharey='row')

        self.imgs = np.array( [   [ self.a[row, col].imshow(np.zeros([self.img_height, self.img_width])) for col in range(self.ncols)    ] for row in range(self.nrows)])
        plt.pause(0.1)


    def update(self, tensor):
        n=0
        for row in range(self.nrows):
            for col in range(self.ncols):
                self.imgs[row,col].set_data(tensor[:,:,n].squeeze())
                self.imgs[row,col].set_clim(vmin=0, vmax=255)
                n += 1
        self.fig.canvas.draw_idle()
        plt.pause(0.01)
        plt.draw_all()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...