Оценить модель в другом измерении данных - PullRequest
1 голос
/ 02 марта 2020

У меня есть данные формы (37520, 32, 9) с пятью классами формы (37520, 5), я использую Conv1D для обучения модели, и до сих пор я могу тренировать данные. Но проблема в том, что мне нужно оценить его в другом измерении - (37520, 32, 4) (классы одинаковы), я получаю следующую ошибку:

Traceback (most recent call last):
  File "data_maker_cnn_multiuser_folds_correct.py", line 870, in <module>
    f = run_cnn(a)
  File "data_maker_cnn_multiuser_folds_correct.py", line 155, in run_cnn
    _, accuracy = model.evaluate(x_test, y_test, batch_size=batch_size, verbose=verbose)
  File "/Users/akshayrajgollahalli/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 930, in evaluate
    use_multiprocessing=use_multiprocessing)
  File "/Users/akshayrajgollahalli/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 490, in evaluate
    use_multiprocessing=use_multiprocessing, **kwargs)
  File "/Users/akshayrajgollahalli/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 426, in _model_iteration
    use_multiprocessing=use_multiprocessing)
  File "/Users/akshayrajgollahalli/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 646, in _process_inputs
    x, y, sample_weight=sample_weights)
  File "/Users/akshayrajgollahalli/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 2383, in _standardize_user_data
    batch_size=batch_size)
  File "/Users/akshayrajgollahalli/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 2410, in _standardize_tensors
    exception_prefix='input')
  File "/Users/akshayrajgollahalli/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py", line 582, in standardize_input_data
    str(data_shape))
ValueError: Error when checking input: expected conv1d_input to have shape (32, 9) but got array with shape (32, 4)

Со следующим кодом:

def run_cnn(data_dict: Dict[str, np.ndarray]):
    """
    Runs a 1D CNN.

    :param data_dict: A dictionary of training, testing and their targets (Ndarray).
    :return:
    """

    x_train = data_dict['training_data']
    x_test = data_dict['testing_data']

    y_train = np.expand_dims(data_dict['training_target'], axis=1)
    y_test = np.expand_dims(data_dict['testing_target'], axis=1)

    y_train = y_train - 1
    y_test = y_test - 1
    y_train = tf.keras.utils.to_categorical(y_train)
    y_test = tf.keras.utils.to_categorical(y_test)
    print(y_test.shape)

    print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

    verbose = 1
    epochs = 2
    batch_size = 32

    n_timesteps, n_features, n_outputs = x_train.shape[1], x_train.shape[2], y_train.shape[1]
    print("Number of time steps: ", n_features)
    print("Number of features: ", n_features)
    print("Number of outputs: ", n_outputs)

    model = tf.keras.Sequential()
    model.add(
        tf.keras.layers.Conv1D(filters=64, kernel_size=3, activation='relu',
                               input_shape=(n_timesteps, n_features)))
    model.add(tf.keras.layers.Conv1D(filters=64, kernel_size=3, activation='relu'))
    model.add(tf.keras.layers.Dropout(0.5))
    model.add(tf.keras.layers.MaxPooling1D(pool_size=2))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(100, activation='relu'))
    model.add(tf.keras.layers.Dense(n_outputs, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    # fit network
    history = model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, verbose=verbose)
    # evaluate model
    _, accuracy = model.evaluate(x_test, y_test, batch_size=batch_size, verbose=verbose)

    return accuracy

Возможно ли это сделать? Я тоже пытался использовать прогноз, но все равно получаю ошибку.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...