ValueError: Получить аргумент не может быть истолкован как Тензор - PullRequest
2 голосов
/ 04 мая 2020

Здесь я обучаю свою сеть, используя набор данных, который хранится в файле .tfrecord. Этот набор данных включает в себя изображения и позы объектов. Но всякий раз, когда я запускаю этот код, я получаю следующую ошибку, которая упоминается внизу.

def _parse_image_function(example_proto):
    image_feature_description = {
    'height': tf.FixedLenFeature([], tf.int64),
    'width': tf.FixedLenFeature([], tf.int64),
    'depth': tf.FixedLenFeature([], tf.int64),
    'label': tf.FixedLenFeature([], tf.string),
    'image_raw': tf.FixedLenFeature([], tf.string),
     }

     # Parse the input tf.Example proto using the dictionary above.
     example = tf.parse_single_example(example_proto, image_feature_description)

     height_feature = example['height'] # get byte string
     width_feature = example['width'] # get byte string
     depth_feature = example['depth'] # get byte string
     image_raw_feature = example['image_raw'] # get byte string    
     label_feature = example['label'] # get byte string

     images = tf.parse_tensor(example['image_raw'], out_type=tf.int32) # restore 2D array from byte string        
     images = tf.cast(images, dtype=tf.float32)/255.0    
     images = tf.reshape(images, [120, 120, 3])    

     image_label = tf.parse_tensor(label_feature, out_type=tf.float64) # restore 2D array from byte string

     return images, image_label


def get_batched_dataset(filenames):
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(_parse_image_function)
    dataset = dataset.cache() # This dataset fits in RAM
    dataset = dataset.repeat()
    dataset = dataset.shuffle(20000)
    dataset = dataset.batch(BATCH_SIZE) # drop_remainder will be needed on TPU
    #dataset = dataset.prefetch(AUTO) #

    return dataset


def get_training_dataset():
    return get_batched_dataset(training_filenames)


def get_validation_dataset():
    return get_batched_dataset(validation_filenames)


model = tf.keras.Sequential([              
tf.keras.layers.Conv2D(kernel_size=(3, 3), filters=64, dtype='float32', input_shape=(120, 120, 3)),
tf.keras.layers.Activation('relu'),

tf.keras.layers.Conv2D(kernel_size=(3, 3), filters=64, use_bias=True),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Activation('relu'),
tf.keras.layers.MaxPooling2D(pool_size=2),

tf.keras.layers.Conv2D(kernel_size=(3, 3), filters=32, use_bias=True),
#tf.keras.layers.BatchNormalization(),
tf.keras.layers.Activation('relu'),
tf.keras.layers.MaxPooling2D(pool_size=2),

tf.keras.layers.Conv2D(kernel_size=(3, 3), filters=32, use_bias=True),
#tf.keras.layers.BatchNormalization(),
tf.keras.layers.Activation('relu'),
tf.keras.layers.MaxPooling2D(pool_size=2),

tf.keras.layers.Conv2D(kernel_size=(3, 3), filters=16, use_bias=True),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Activation('relu'),
tf.keras.layers.MaxPooling2D(pool_size=2),


tf.keras.layers.Conv2D(kernel_size=(3, 3), filters=8, use_bias=True),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Activation('relu'),
tf.keras.layers.MaxPooling2D(pool_size=2),



tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, use_bias = True),
tf.keras.layers.Activation('relu'),

#tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(12, activation='linear', name='fc')])


model.compile(optimizer='Adam', loss='mse', metrics=['mae', 'mse'])  # mean absolute error        
#model.summary()

logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)


history = model.fit_generator(get_training_dataset(), steps_per_epoch=steps_per_epoch, epochs=100, validation_data=get_validation_dataset(), validation_steps=validation_steps, callbacks=[tensorboard_callback])

Всякий раз, когда я запускаю этот код, я получаю эту ошибку:

ValueError: Fetch argument <tf.Tensor 'IteratorGetNext:0' shape=(?, 120, 120, 3) dtype=float32> cannot be interpreted as a Tensor. (Tensor Tensor("IteratorGetNext:0", shape=(?, 120, 120, 3), dtype=float32) is not an element of this graph.)

Важно отметить, что я использую tenorflow 1.8.

...