При попытке соответствовать модели Keras, написанной в tensorflow.keras
API с tf.Dataset
индуцированным итератором, модель жалуется на аргумент steps_per_epoch
, хотя я установил для него конкретное значение.
Вот мой класс модели
import tensorflow as tf
import numpy as np
from typing import Union, List
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras import layers
from tftools import TFTools
class TestServe():
def __init__(self, tfrecords: Union[List[tf.train.Example], tf.train.Example], batch_size: int = 10, input_shape: tuple = (64, 23)) -> None:
self.tfrecords = tfrecords
self.batch_size = batch_size
self.input_shape = input_shape
def get_model(self):
ins = layers.Input(shape=(64, 23))
l = layers.Reshape((*self.input_shape, 1))(ins)
l = layers.Conv2D(8, (30, 23), padding='same', activation='relu')(l)
l = layers.MaxPool2D((4, 5), strides=(4, 5))(l)
l = layers.Conv2D(16, (3, 3), padding='same', activation='relu')(l)
l = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(l)
l = layers.MaxPool2D((2, 2), strides=(2, 2))(l)
l = layers.Flatten()(l)
out = layers.Dense(1, activation='softmax')(l)
return tf.keras.models.Model(ins, out)
def train(self):
# Create Dataset
dataset = TFTools.create_dataset(self.tfrecords)
dataset = dataset.repeat(6).batch(self.batch_size)
val_iterator = dataset.take(300).make_one_shot_iterator()
train_iterator = dataset.skip(300).make_one_shot_iterator()
model = self.get_model()
model.summary()
model.compile(optimizer='rmsprop',
loss='binary_crossentropy', metrics=['accuracy'])
model.fit(train_iterator, validation_data=val_iterator,
epochs=10, verbose=1, steps_per_epoch=20)
def predict(self, X: np.array) -> np.array:
pass
ts = TestServe(['./ok.tfrecord', './nok.tfrecord'])
ts.train()
Но как только я начинаю тренировку, до того, как закончится первая партия, я получаю исключение от tenorflow
2019-06-13 14:22:25.393398: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 1995445000 Hz
2019-06-13 14:22:25.393681: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x2f7d120 executing computations on platform Host. Devices:
2019-06-13 14:22:25.393708: I tensorflow/compiler/xla/service/service.cc:158] StreamExecutor device (0): <undefined>, <undefined>
Epoch 1/2
19/20 [===========================>..] - ETA: 0s - loss: 1.1921e-07 - acc: 1.0000Traceback (most recent call last):
File "TestServe.py", line 62, in <module>
ts.train()
File "TestServe.py", line 56, in train
epochs=2, verbose=1, callbacks=callbacks, steps_per_epoch=20) #The steps_per_epoch is typically samples_per_epoch / batch_size
File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 880, in fit
validation_steps=validation_steps)
File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 364, in model_iteration
validation_in_fit=True)
File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 202, in model_iteration
steps_per_epoch)
File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 76, in _get_num_samples_or_steps
'steps_per_epoch')
File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_utils.py", line 230, in check_num_samples
if check_steps_argument(ins, steps, steps_name):
File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_utils.py", line 960, in check_steps_argument
input_type=input_type_str, steps_name=steps_name))
ValueError: When using data tensors as input to a model, you should specify the `steps_per_epoch` argument.
Исходный набор данных содержит около 1500 сэмплов, но я хочу присоединить несколько файлов tfrecord к TFRecordDataset, поэтому у меня не будет информации о длине.
Кто-нибудь видел что-то подобное раньше? Я не знаю, куда обратиться за помощью, поскольку API tf.keras
является относительно новым. Функция create_dataset
просто возвращает набор данных, сопоставленный с правой функцией анализа.