Прочтите изображение, чтобы получить функции, а затем используйте модель кластеризации, чтобы установить кластеры для похожих изображений. - PullRequest
0 голосов
/ 30 января 2020

Я пытаюсь использовать модель ResNet50 для извлечения характеристик изображения. Затем используйте эти возможности для создания кластеров с использованием кластеризации иерархии scipy. (Я не знаю, возможно ли это). Что ж, я делаю, создавая модель.

  1. Прочитайте изображение
  2. Извлеките элементы
  3. Затем измените его, чтобы он стал 1D массивом
  4. Добавьте это к набору данных.

Однако я могу использовать кластер, но когда я пытаюсь построить кластеры, он выдает ошибку.

TypeError: scatter() got multiple values for argument 'c'

Это означает, что когда я использую * np.transpose (data), он распаковывает слишком много материала, поскольку в методе построения должно быть только 4 элемента. Так что я немного растерялся, что делать здесь. Что-то не так с кластеризацией или мне нужно изменить метод построения?

Вот мой код.

import matplotlib.pyplot as plt
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import numpy as np
import scipy.cluster.hierarchy as hcluster
import os
data_dir = '../Dataset/Train/'
extractor = ResNet50(include_top=False, weights='imagenet')
data = []
number_images = 0
i = 0
MAX_IMAGES = 2
for images in os.listdir(data_dir):
    if i == MAX_IMAGES:
        break
    img_path = os.path.join(data_dir, images)
    # Read in image
    img = load_img(img_path)
    img = img_to_array(img)
    img = np.expand_dims(img, axis=0)
    img = preprocess_input(img)
    feat = extractor.predict(img)
    feat = np.squeeze(feat) # Remove the 1
    feat = feat.reshape(-1)
    print(feat.shape)
    data.append(feat)
    i +=1
    number_images+=1

print("Total images=", number_images)
print(len(data))

THRESHOLD = 1.5
clusters = hcluster.fclusterdata(data, THRESHOLD, criterion='distance')

# plotting
plt.scatter(*np.transpose(data), c=clusters)
plt.axis("equal")
title = "threshold: %f, number of clusters: %d" % (THRESHOLD, len(set(clusters)))
plt.title(title)
plt.show()
...