Я новичок в tenorflow, я пытаюсь создать классификатор изображений с resnet50 для классификации набора данных о породах собак, но я не могу иметь дело с tensorflow.dataset. вот код
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Flatten, GlobalAveragePooling2D, BatchNormalization
from tensorflow.keras.applications import ResNet50
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics
import tensorflow_datasets as tfds
train= tfds.load('stanford_dogs', split= 'train')
test= tfds.load('stanford_dogs', split= 'test')
model = keras.Sequential()
model.add(ResNet50(include_top=False, weights='imagenet', pooling='avg', ))
model.add(BatchNormalization())
model.add(Dense(1024, activation = 'relu'))
model.add(BatchNormalization())
model.add(Dense(120, activation='softmax'))
model.layers[0].trainable = False
model.compile(optimizer = 'adam', loss = keras.losses.sparse_categorical_crossentropy, metrics = ['accuracy'])
model.summary()
model.fit(
train,
steps_per_epoch = 100,
epochs = 30,
verbose =2,
validation_data = test
)
, и он дал мне эту ошибку,
KeyError Traceback (most recent call last)
<ipython-input-16-da6b57b304b5> in <module>()
4 epochs = 30,
5 verbose =2,
----> 6 validation_data = test
7
8 )
10 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
966 except Exception as e: # pylint:disable=broad-except
967 if hasattr(e, "ag_error_metadata"):
--> 968 raise e.ag_error_metadata.to_exception(e)
969 else:
970 raise
KeyError: 'resnet50_input'
переменные test и train имеют следующий тип
type(train)
tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter
type(test)
tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter
набор данных хранится как tfrecord
!ls /root/tensorflow_datasets/stanford_dogs/0.2.0
dataset_info.json
image.image.json
label.labels.txt
stanford_dogs-test.tfrecord-00000-of-00004
stanford_dogs-test.tfrecord-00001-of-00004
stanford_dogs-test.tfrecord-00002-of-00004
stanford_dogs-test.tfrecord-00003-of-00004
stanford_dogs-train.tfrecord-00000-of-00004
stanford_dogs-train.tfrecord-00001-of-00004
stanford_dogs-train.tfrecord-00002-of-00004
stanford_dogs-train.tfrecord-00003-of-00004
Я искал решение в Google, все, что я нашел, это статьи о том, как преобразовать набор данных в tfrecord, а затем прочитать tfrecord и построить с ним входной конвейер, но документация tenorflow говорит, что tenorflow_datasets (tfds) определяет набор наборов данных, готовых к использованию с TensorFlow