обучение модели keras с набором данных tensorflow - PullRequest
0 голосов
/ 12 июля 2020

Я новичок в tenorflow, я пытаюсь создать классификатор изображений с resnet50 для классификации набора данных о породах собак, но я не могу иметь дело с tensorflow.dataset. вот код

import numpy as np
import pandas as pd 
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Flatten, GlobalAveragePooling2D, BatchNormalization
from tensorflow.keras.applications import ResNet50

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn import metrics
import tensorflow_datasets as tfds




train= tfds.load('stanford_dogs', split= 'train')
test= tfds.load('stanford_dogs', split= 'test')

model = keras.Sequential()
model.add(ResNet50(include_top=False, weights='imagenet', pooling='avg', ))
model.add(BatchNormalization())

model.add(Dense(1024, activation = 'relu'))
model.add(BatchNormalization())

model.add(Dense(120, activation='softmax'))

model.layers[0].trainable = False

model.compile(optimizer = 'adam', loss = keras.losses.sparse_categorical_crossentropy, metrics = ['accuracy'])

model.summary()



model.fit(
    train,
    steps_per_epoch = 100,
    epochs = 30,
    verbose =2,
    validation_data = test
       
)

, и он дал мне эту ошибку,

KeyError                                  Traceback (most recent call last)
<ipython-input-16-da6b57b304b5> in <module>()
      4     epochs = 30,
      5     verbose =2,
----> 6     validation_data = test
      7 
      8 )

10 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise


    KeyError: 'resnet50_input'

переменные test и train имеют следующий тип

type(train)
tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter

type(test)
tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter

набор данных хранится как tfrecord

!ls /root/tensorflow_datasets/stanford_dogs/0.2.0



dataset_info.json
image.image.json
label.labels.txt
stanford_dogs-test.tfrecord-00000-of-00004
stanford_dogs-test.tfrecord-00001-of-00004
stanford_dogs-test.tfrecord-00002-of-00004
stanford_dogs-test.tfrecord-00003-of-00004
stanford_dogs-train.tfrecord-00000-of-00004
stanford_dogs-train.tfrecord-00001-of-00004
stanford_dogs-train.tfrecord-00002-of-00004
stanford_dogs-train.tfrecord-00003-of-00004

Я искал решение в Google, все, что я нашел, это статьи о том, как преобразовать набор данных в tfrecord, а затем прочитать tfrecord и построить с ним входной конвейер, но документация tenorflow говорит, что tenorflow_datasets (tfds) определяет набор наборов данных, готовых к использованию с TensorFlow

...