Как я могу улучшить предсказания чисел? - PullRequest
0 голосов
/ 10 июня 2019

У меня есть некоторая модель классификации чисел, по тестовым данным она работает нормально, но когда я хочу классифицировать другие изображения, я столкнулся с проблемами, которые моя модель не может точно предсказать, какое это число. Пожалуйста, помогите мне улучшить производительность model.predict ().

Я пытался обучить мою модель многими способами, в приведенном ниже коде есть функция, которая создает модель классификации, я обучил эту модель фактически многими способами, [1K

def load_data():
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

    train_images = tf.keras.utils.normalize(train_images, axis = 1)
    test_images = tf.keras.utils.normalize(test_images, axis = 1)

    return (train_images, train_labels), (test_images, test_labels)

def create_model():
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(128, activation = tf.nn.relu))
    model.add(tf.keras.layers.Dense(128, activation = tf.nn.relu))
    model.add(tf.keras.layers.Dense(10, activation = tf.nn.softmax))

    data = load_data(n=60000, k=5)
    model.compile(optimizer ='adam',
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])
    model.fit(data[0][0][:n], data[0][1][:n], epochs = e)# ive tried from 3-50 epochs
    model.save(config.model_name)

def load_model():
    return tf.keras.models.load_model(config.model_name)def predict(images):
    try:
        model = load_model()
    except:
        create_model()
        model = load_model()
    images = tf.keras.utils.normalize(images, axis = 0)
    d = load_data()

    plot_many_images([d[0][0][0].reshape((28,28)), images[0]],['data', 'image'])

    predictions = model.predict(images)
    return predictions

Я думаю, что мои входные данные не выглядят так, как будто это модель предсказания, но я пытался сделать ее максимально похожей. На этой картинке (https://imgur.com/FfLGMEK) на левом изображении - изображение данных поезда, а на правом - мое проанализированное изображение, оба имеют размер 28x28 пикселей, оба cv2.noramalized

для предсказаний тестовых изображений, которые я использовал (https://imgur.com/RMfKtag) sudoku, он уже отформатирован, чтобы быть похожим на номера тестовых данных, но когда я тестирую это изображение с предсказанием модели, результат не так хорош (https://imgur.com/RQFvLNE) Как видите, прогнозируемые данные оставляют желать лучшего.

P.S. ('') элементы в предсказанных данных получаются моими руками (я заменил числа в этих позициях на ''), потому что после предсказаний все они имеют какое-то значение (1-9), теперь это не нужно.

1 Ответ

0 голосов
/ 13 июня 2019

что вы имеете в виду "на тестовых данных все работает нормально"?если вы имеете в виду, что он работает хорошо для данных о поездах, но не имеет хорошего прогноза для тестовых данных, возможно, ваша модель была перегружена на этапе обучения.я предлагаю использовать метод обучения / проверки / тестирования для обучения вашей сети.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...