Форма модели изменяется после обучения и не может тренироваться снова сохраненной модели - PullRequest
0 голосов
/ 21 января 2020

У меня эта проблема, которую я продолжаю получать последние 2 недели, и я не нашел решения для нее. Итак, я тренирую последовательную модель keras для генерации текста, она работает нормально, без проблем, но после того, как она закончила тренировку и сохранила модель, когда я пытаюсь загрузить ее и обучить ее дальше, я продолжаю получать ValueErrors из-за форм слои.

Вот код:

token = clean_tweets(text)
print("Generating sequences of text...")
sequence_len = 21
seq = []
for i in range(0, len(token) - sequence_len):
    seq.append(token[i:i + sequence_len])
print("Done generating the sequences.")

print("Tokenizing...")
tokenizer = Tokenizer()
tokenizer.fit_on_texts(seq)
sequence = tokenizer.texts_to_sequences(seq)
print("Done tokenizing. Printing first sequence.")

vocab_size = len(tokenizer.word_index) + 1
print("vocab_size = " + str(vocab_size))

print("Dividing our sequence list into input and corresponding output.")
arr = np.array(sequence)
X, Y = arr[:, :-1], arr[:, -1]
print("Y shape before to_categorical" + str(Y.shape))
Y = to_categorical(Y, num_classes=vocab_size)
print("Y shape after to _categorical" + str(Y.shape))
seq_length = X.shape[1]
print(seq_length)

if os.path.isfile(f'model_gen/{f_name}.h5'):
    model = load_saved_model(f_name)
else:
    model = Sequential()
    model.add(Embedding(vocab_size, 150, input_length=seq_length))
    model.add(LSTM(256, return_sequences=True))
    model.add(LSTM(256))
    model.add(Dense(256, activation='relu'))
    model.add(Dense(vocab_size, activation='softmax'))
    print(model.summary())

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(X, Y, batch_size=80, epochs=30)

save_trained_model(model, f_name)

метод save_trained_model:

def save_trained_model(trained_model, file_name):
    with open(f'model_gen/{file_name}.json', 'w') as json:
        try:
            json.write(trained_model.to_json())
        except:
            print("Model could not be saved in a .json format.")
        try:
            trained_model.save_weights(f'model_gen/{file_name}.h5')
        except:
            print("Models weights could not be saved in a .h5 format.")

и load_saved_model:

def load_saved_model(FILE_NAME):
    with open(f'model_gen/{FILE_NAME}.json') as json:
        try:
            model = model_from_json(json.read())
            print("Loaded json file for the model.")
            try:
                model.load_weights(f'model_gen/{FILE_NAME}.h5')
                print("Loaded weights for the model.")
            except:
                print("Models weights could not be loaded from the .h5 format.")
        except:
            print("Model could not be loaded from the .json format.")
    return model

выход после попытки тренировать его снова:

Traceback (most recent call last):
  File ".../src/tweet_generator.py", line 201, in <module>
    train_gen()
  File ".../src/tweet_generator.py", line 70, in train_gen
    model.fit(X, Y, batch_size=80, epochs=30)
  File "...\Anaconda\envs\gputest\lib\site-packages\keras\engine\training.py", line 1154, in fit
    batch_size=batch_size)
  File "...\Anaconda\envs\gputest\lib\site-packages\keras\engine\training.py", line 621, in _standardize_user_data
    exception_prefix='target')
  File "...\Anaconda\envs\gputest\lib\site-packages\keras\engine\training_utils.py", line 145, in standardize_input_data
    str(data_shape))
ValueError: Error when checking target: expected dense_2 to have shape (6620,) but got array with shape (6609,)

6609 был vocab_size до первой тренировки, и кажется, что он увеличился после одной сессии тренировки. Я могу изменить его вручную на 6620, и он работает, но не правильно, я имею в виду, что это не похоже на возобновление. Например, последняя потеря первой тренировки была 2.7, и после ручного изменения vocab_size и повторного обучения на загруженной модели она начинается с потери 25 или чего-то в этом роде. Так что это не возобновление.

Кто-нибудь знает, как я могу решить эту проблему?

Спасибо!

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...