Здесь я обучаю свою сеть, используя набор данных, который хранится в файле .tfrecord. Этот набор данных включает в себя изображения и позы объектов. Но всякий раз, когда я запускаю этот код, я получаю следующую ошибку, которая упоминается внизу.
def _parse_image_function(example_proto):
image_feature_description = {
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'depth': tf.FixedLenFeature([], tf.int64),
'label': tf.FixedLenFeature([], tf.string),
'image_raw': tf.FixedLenFeature([], tf.string),
}
# Parse the input tf.Example proto using the dictionary above.
example = tf.parse_single_example(example_proto, image_feature_description)
height_feature = example['height'] # get byte string
width_feature = example['width'] # get byte string
depth_feature = example['depth'] # get byte string
image_raw_feature = example['image_raw'] # get byte string
label_feature = example['label'] # get byte string
images = tf.parse_tensor(example['image_raw'], out_type=tf.int32) # restore 2D array from byte string
images = tf.cast(images, dtype=tf.float64)/255.0
images = tf.reshape(images, [120, 120, 3])
image_label = tf.parse_tensor(label_feature, out_type=tf.float64) # restore 2D array from byte string
return images, image_label
def get_batched_dataset(filenames):
#option_no_order = tf.data.Options()
#option_no_order.experimental_deterministic = False
path = '/media/local/username/dataset/image_label_datatset/'
filenames1 = get_dataset_from_directory(path)
#dataset = tf.data.Dataset.list_files(filenames)
dataset = tf.data.TFRecordDataset(filenames1)
#dataset = dataset.with_options(option_no_order)
#dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=8, num_parallel_calls=AUTO)
dataset = dataset.map(_parse_image_function)
dataset = dataset.cache() # This dataset fits in RAM
dataset = dataset.repeat()
dataset = dataset.shuffle(20000)
dataset = dataset.batch(BATCH_SIZE) # drop_remainder will be needed on TPU
#dataset = dataset.prefetch(AUTO) #
return dataset
def get_training_dataset():
return get_batched_dataset(training_filenames)
def get_validation_dataset():
return get_batched_dataset(validation_filenames)
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(kernel_size=(3, 3), filters=64, dtype='float32', input_shape=(120, 120, 3)),
tf.keras.layers.Activation('relu'),
tf.keras.layers.Conv2D(kernel_size=(3, 3), filters=64, use_bias=True),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Activation('relu'),
tf.keras.layers.MaxPooling2D(pool_size=2),
tf.keras.layers.Conv2D(kernel_size=(3, 3), filters=32, use_bias=True),
#tf.keras.layers.BatchNormalization(),
tf.keras.layers.Activation('relu'),
tf.keras.layers.MaxPooling2D(pool_size=2),
tf.keras.layers.Conv2D(kernel_size=(3, 3), filters=32, use_bias=True),
#tf.keras.layers.BatchNormalization(),
tf.keras.layers.Activation('relu'),
tf.keras.layers.MaxPooling2D(pool_size=2),
tf.keras.layers.Conv2D(kernel_size=(3, 3), filters=16, use_bias=True),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Activation('relu'),
tf.keras.layers.MaxPooling2D(pool_size=2),
tf.keras.layers.Conv2D(kernel_size=(3, 3), filters=8, use_bias=True),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Activation('relu'),
tf.keras.layers.MaxPooling2D(pool_size=2),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, use_bias = True),
tf.keras.layers.Activation('relu'),
#tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(12, activation='linear', name='fc')])
#model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
#model.compile(optimizer=tf.keras.optimizers.Adam(0.01), loss='mean_squared_logarithmic_error', metrics=['accuracy'])
#sgd = SGD(lr=0.001, decay=1e-6, momentum=0.1, nesterov=True)
#model.compile(loss='mean_squared_logarithmic_error', optimizer= sgd, metrics=["mse"])
model.compile(optimizer='Adam', loss='mse', metrics=['mae', 'mse']) # mean absolute error
#model.summary()
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
history = model.fit_generator(get_training_dataset(), steps_per_epoch=steps_per_epoch, epochs=100, validation_data=get_validation_dataset(), validation_steps=validation_steps, callbacks=[tensorboard_callback])
Всякий раз, когда я запускаю этот код, я получаю эту ошибку:
if len(validation_data) == 2:
TypeError: object of type 'BatchDataset' has no len()
Важно отметить, что я использую tenorflow 1.8.