Неверный индекс для скалярной переменной с использованием imageio - PullRequest
0 голосов
/ 19 июня 2019

Я импортирую изображения кошек и лошадей в нейронную сеть vgg-19 с помощью imageio, а затем хочу импортировать только одно изображение, чтобы предсказать класс, но в скалярной переменной есть неверный индекс ошибки.

ЭтоВот как я импортирую изображения перед запуском модели vgg19.

def generator(batch_size, datapath):
    from random import shuffle
    target = glob.glob(datapath + '*.png')  
    n_samples = len(target)  
    n_batches = n_samples // batch_size 
    b = n_batches

    while True:

        if b == n_batches:  
            shuffle(target)
            b = 0
            print("epoch finished - " + datapath)

        # initialize current batch
        batch_features = np.zeros((batch_size, LEFT, RIGHT, 3))  
        batch_labels = np.zeros((batch_size, 2))
        target_b = target[b * batch_size:(b + 1) * batch_size]  

        # populate current batch
        for i, t in enumerate(target_b):
            batch_features[i, :, :, :] = imageio.imread(t)[:, :, :3]  

            batch_labels[i, :] = np.array([1, 0]) if "cat" in t else np.array([0, 1])

        b += 1
        yield batch_features, batch_labels

Чем работает модель vgg-19.

И вот как я импортирую одно изображение после отслеживания модели:

im = imageio.imread('test_image.png')[:, :, :3]
im = np.expand_dims(im, axis=0)
print(im.shape)
from keras.preprocessing.image import img_to_array

yhat = model.predict(im)
print(yhat)
class_labels = ['cat', 'horse'] 
pred = np.argmax(class_labels)
print(class_labels[pred[0]])

В последней строке кода указан неверный индекс скалярной переменной.Он печатает «yhat», но не дает ярлыков классов для предсказаний.Как я должен восстановить код, чтобы иметь тот же формат?

...