список обратных вызовов keras генерирует ошибку: объект 'tuple' не имеет атрибута 'set_model' - PullRequest
2 голосов
/ 20 сентября 2019

Я пишу модель keras, в которой я хочу использовать несколько встроенных обратных вызовов keras, однако я, вероятно, совершаю грамматическую ошибку где-то, что не могу обнаружить.Часть кода, вызывающая у меня проблемы, выглядит следующим образом:

from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
...
...
es = EarlyStopping(monitor='val_loss', min_delta=0.01, verbose=1, patience=5)
tb = TensorBoard(log_dir=logdir, write_graph=True, write_images=True, histogram_freq=0)
mc = ModelCheckpoint(filepath=filepath, save_best_only=True, monitor='val_loss', mode='min')

history = model.fit(X_train, y_train,
                    batch_size=batch_size,
                    epochs=n_epochs,
                    verbose=1,
                    validation_split=0.3,
                    callbacks=[es, tb, mc])

, однако при этом я получаю ошибку 'tuple' object has no attribute 'set_model'.Ссылаясь на этот другой вопрос, кажется, что проблема вызвана тем фактом, что es, tb уже являются кортежами на се и поэтому их размещение в списке (в вызове callbacks=[es, tb, mc]) вызывает ошибку.На самом деле

print(type(es))
print(type(tb))
print(type(mc))

<class 'tuple'>
<class 'tuple'>
<class 'keras.callbacks.ModelCheckpoint'>

При этом сказано, я не понимаю, как обойти это.EarlyStopping и TensorBoard возвращают кортежи, как они должны вызываться в списке обратных вызовов keras?

Ответы [ 2 ]

1 голос
/ 20 сентября 2019

Распакуйте ваши кортежи - в этом случае все просто: (object,)[0] == object - но в целом у вас может быть (object1, object2) и т. Д., С которыми вы можете справиться через callbacks=[*es, *tb, *mc].

* распаковывает итерируемый - как демонстрацию:

def print_unpacked(*positional_args):
    print(positional_args)
    print(*positional_args)
a = 1
b = ('dog',5)
print_unpacked(a,b)
# >> (1, ('dog',5))
# >> 1 ('dog',5)
print(a,b)
# >> 1 ('dog',5)
print(a,*b)
# >> 1 'dog' 5
0 голосов
/ 20 сентября 2019

Вы должны удалить запятую в конце следующих строк в вашем коде, размещенном выше

es = EarlyStopping(monitor='val_loss', min_delta=0.01, verbose=1, patience=5),
tb = TensorBoard(log_dir=logdir, write_graph=True, write_images=True, histogram_freq=0),
...