from keras.models import Sequential, load_model
CONFIG.input_layers = 2
CONFIG.output_layers = 2
CONFIG.amount_of_dropout = 0.2
CONFIG.batch_size = 100
CONFIG.epochs = 500
def create_model(output_len, chars=None):
print('Build model...')
chars = chars or CHARS
model = Sequential()
for layer_number in range(CONFIG.input_layers):
model.add(recurrent.LSTM(CONFIG.hidden_size, input_shape=(None, len(chars)), kernel_initializer=CONFIG.initialization,
return_sequences=layer_number + 1 < CONFIG.input_layers))
model.add(Dropout(CONFIG.amount_of_dropout))
model.add(RepeatVector(output_len))
for _ in range(CONFIG.output_layers):
model.add(recurrent.LSTM(CONFIG.hidden_size, return_sequences=True, kernel_initializer=CONFIG.initialization))
model.add(Dropout(CONFIG.amount_of_dropout))
model.add(TimeDistributed(Dense(len(chars), kernel_initializer=CONFIG.initialization)))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
return model
def iterate_training(model, X_train, y_train, X_val, y_val, ctable):
for iteration in range(1, CONFIG.number_of_iterations):
model.fit(X_train, y_train, batch_size=CONFIG.batch_size, epochs=CONFIG.epochs,
validation_data=(X_val, y_val), shuffle=True, verbose=2)
Я тренируюсь со странами:
United States of America
United Arab Emirates
India
China
Australia
Canada
Germany
Poland
Brazil
Afghanistan
В основном у меня проблемы с Индией и Китаем. Я создал шумные данные для каждой страны, используя скрипт.
например, это несколько записей, я тренируюсь с большим количеством.
unitei arab emirates,united arab emirates
dnited arab emirates,united arab emirates
united amrab emirates,united arab emirates
indea,india
igndia,india
chinab,china
chpina,china
chrna,china
cmina,china
lchina,china
Скрипт для создания зашумленных данных:
def create_noisy_data(word):
letters = 'abcdefghijklmnopqrstuvwxyz'
splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
deletes = [L + R[1:] for L, R in splits if R]
transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R)>1]
replaces = [L + c + R[1:] for L, R in splits if R for c in letters]
inserts = [L + c + R for L, R in splits for c in letters]
return set(deletes + transposes + replaces + inserts)
Вышеупомянутая функция создает зашумленные данные для каждой страны, а в соответствии с длиной названия страны функция create_noisy_data создает зашумленные данные.
Она отлично работает, когда я тренируюсь только для:
India
China
Australia
Canada
Germany
Poland
Brazil
Afghanistan
или
United States of America
United Arab Emirates
Afghanistan