Как мне отладить модель keras - PullRequest
0 голосов
/ 08 ноября 2019

Я прохожу учебник по распознаванию рукописного текста. И для распознавания рукописных цифр автор построил модель Keras следующим образом:

# # Creating CNN model

input_shape = (28,28,1)
number_of_classes = 10

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),activation='relu',input_shape=input_shape))

model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPool2D(pool_size=(2, 2)))

model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))

model.add(Dropout(0.5))
model.add(Dense(number_of_classes, activation='softmax'))

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),metrics=['accuracy'])

model.summary()

history = model.fit(X_train, y_train,epochs=5, shuffle=True,
                    batch_size = 200,validation_data= (X_test, y_test))


model.save('digit_classifier2.h5')

Источник ( здесь )

Я очень смущен тем, какАвтор выбирает эти слои. Я знаю, как Conv2D работает, применяя фильтры к изображению, я знаю, что такое activation function. Короче говоря, у меня есть грубое понимание того, что означает каждый термин.

Что мне трудно, так это как узнать, что происходит на каждом этапе этого кода? Например, давайте возьмем этот код Python:

values_List=[11,34,43]
for index, num in enumerate(values_List):
    print(index,num)
  1. Я знаю, что строка 1 инициализирует список с именем values_List
  2. Строка 2 перебирает этот список
  3. Строка 3выводит вывод как (индекс числа, числа)

Этот код Python прост для понимания и отладки. Но меня смущает, что если есть какие-то ошибки внутри слоев keras. Как мне перейти к отладке этого кода Keras? Как я вижу вывод на каждом шаге внутри кода Keras?

1 Ответ

0 голосов
/ 08 ноября 2019

Короче говоря, вы не можете легко отлаживать в Keras, потому что это высокоуровневый API, созданный для более быстрой и простой реализации архитектуры нейронной сети с использованием предопределенных слоев и функций, поэтому вероятность возникновения ошибок внутри этих слоев меньше илифункция причина это хорошо проверено.

Если вы хотите более детальный контроль над собой, вам необходимо реализовать его в низкоуровневом API, таком как Tensorflow v1, или использовать tf.GradientTape с tf-keras в TensorFlow v2, чтобы видеть градиенты на каждом шаге.

Вы также можете попробовать Tensorwatch от Microsoft для более глубокого понимания вашей модели - https://github.com/microsoft/tensorwatch

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...