ValueError: Ошибка при проверке ввода: ожидалось, что input_5 будет иметь 4 измерения, но получил массив с shape () - PullRequest
0 голосов
/ 27 сентября 2019

Я пытаюсь загрузить модель densenet-121, предварительно обученную на весах imagenet, и тренироваться на моем наборе данных.У меня есть два файла, а именно train.csv и validation.csv, и я разделил файл train.csv на 80-20 наборов поездов и наборов проверки и использую файл validation.csv для тестирования.Однако, когда я пытаюсь подгонять мою модель к набору данных, я получаю вышеуказанную ошибку.Вот мой фрагмент кода:

import os
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras.applications.densenet import DenseNet121
import collections
from sklearn.model_selection import train_test_split
import pandas as pd
import tensorflow_datasets as tfds
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.data import TFRecordDataset
model = DenseNet121(include_top=False, weights='imagenet',  input_shape=(input_shp_img),  
        classes=class_names)
_DATA_DIR = "C:/Users/1024/Documents/my_file"
_TRAIN_DIR = os.path.join(_DATA_DIR, "train")
TEST_DIR = os.path.join(_DATA_DIR, "valid")
_TRAIN_LABELS_FNAME = os.path.join(_DATA_DIR, "train.csv")
TEST_LABELS_FNAME = os.path.join(_DATA_DIR, "valid.csv")
_LABELS = collections.OrderedDict({
"-1.0": "uncertain",
"1.0": "positive",
"0.0": "negative",
"": "unmentioned",
})
len(data_csv)
train_data_size= int(0.8 * len(data_csv))
val_data_size = int(0.2 * len(data_csv))
full_dataset = tf.data.TFRecordDataset(_TRAIN_LABELS_FNAME)
full_dataset = full_dataset.shuffle(buffer_size=800)
train_dataset = full_dataset.take(train_data_size)
val_dataset = full_dataset.skip(train_data_size)
val_dataset = val_dataset.skip(val_data_size)
print(train_dataset)
model.compile(optimizer='sgd',
      loss='mean_squared_error',
      metrics=['accuracy'])
history = model.fit(train_dataset, batch_size=1000, validation_data=val_dataset, 
          steps_per_epoch=100, epochs=100, use_multiprocessing=True)

Ниже приведена ошибка, которую я получаю:

    ValueError                                Traceback (most recent call last)
    <ipython-input-117-bf00f479a1f4> in <module>
      2           loss='mean_squared_error',
      3           metrics=['accuracy'])
----> 4 history = model.fit(train_dataset, batch_size=1000, validation_data=val_dataset, 
    steps_per_epoch=100, epochs=100, use_multiprocessing=True)
      5 

    ~\.conda\envs\manpreet_env\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, 
     x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, 
     class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, max_queue_size, 
     workers, use_multiprocessing, **kwargs)
    774         steps=steps_per_epoch,
    775         validation_split=validation_split,
    --> 776         shuffle=shuffle)
    777 
    778     # Prepare validation data.

    ~\.conda\envs\manpreet_env\lib\site-packages\tensorflow\python\keras\engine\training.py in 
    _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, 
    steps, validation_split, shuffle)
    2380         feed_input_shapes,
    2381         check_batch_axis=False,  # Don't enforce the batch size.
    -> 2382         exception_prefix='input')
    2383 
    2384     if y is not None:

    ~\.conda\envs\manpreet_env\lib\site-packages\tensorflow\python\keras\engine\training_utils.py in 
    standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    351                            ': expected ' + names[i] + ' to have ' +
    352                            str(len(shape)) + ' dimensions, but got array '
    --> 353                            'with shape ' + str(data_shape))
    354         if not check_batch_axis:
    355           data_shape = data_shape[1:]

    ValueError: Error when checking input: expected input_5 to have 4 dimensions, but got array with shape ()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...