Как получить строки меток в тонко настроенной сети, используя файлы контрольных точек или tf-запись? - PullRequest
0 голосов
/ 28 июня 2018

Например, я настроил сеть VGG, используя свой собственный набор данных, только с 2 метками foo и bar. Я преобразовал изображения в tf.record, используя пример через эту ссылку :

labels_to_class_names = dict(zip(range(len(class_names)), class_names))
dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

Я собираюсь создать API для прогнозирования изображений на основе этой новой модели, мой вопрос: есть ли какой-нибудь формальный способ получить строку метки из файлов контрольных точек или из набора данных (например, predict_image("abc.png") возвращает foo строку)? Поскольку я понятия не имею, какой узел в слое logits представляет метку foo, а какой представляет bar

Я пробовал искать, но без помощи, и я все еще новичок в тензорном потоке.

1 Ответ

0 голосов
/ 28 июня 2018

Модель (и, кстати, файлы контрольных точек) не имеют названия каждого класса. Все, что у него есть, - это определенное количество выходных нейронов, первый из которых соответствует первому классу, второй - второму классу и т. Д.

Если вы хотите узнать, какой из них какой, посмотрите на файл меток, созданный этой строкой (скорее всего, с названием label.txt):

dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

В качестве альтернативы вы можете проверить содержимое labels_to_class_names dict:

In [1]: class_names=['aaa', 'bbb', 'ccc']

In [2]: labels_to_class_names = dict(zip(range(len(class_names)), class_names))

In [3]: labels_to_class_names
Out[3]: {0: 'aaa', 1: 'bbb', 2: 'ccc'}

-> значение по индексу 0 на выходе модели = класс 'aaa' и т. Д.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...