Сохранить и загрузить модель как .pb Используя TFLearn? - PullRequest
0 голосов
/ 10 сентября 2018

Я обучил модель DNN с помощью tflearn и сохранил ее как файл контрольных точек, и модель показала хорошие результаты. Мне нужно заморозить модель, чтобы потом использовать ее на устройстве Android.

Я искал способ заморозить мою модель напрямую из TFLearn, но, очевидно, tflearn не предоставил такой функции на данный момент, поэтому я должен сделать это с тензорным потоком, используя этот код , чтобы заморозить мою модель в файл protobuf. И я также получил свое имя выходного узла из тензорной доски, используя this (тот же пользователь), потому что использование имени выходного узла, которое я явно дал в моей модели, по-видимому, не работает в TFLearn и просто замерзло 0 переменные из него.

Казалось, что он работает нормально, когда я успешно заморозил его в файл .pb, но когда я загружаю этот замороженный график в другое время выполнения и делаю прогноз из него, используя model.predict(), он дал мне очень неправильное предсказание, как будто моя модель еще даже не тренировался.

Это моя архитектура NN.

convnet = input_data(shape=[None, IMG_SIZE, IMG_SIZE, 3], name='masuk')

convnet = conv_2d(convnet, 40, 2, activation='relu')
convnet = max_pool_2d(convnet, 4)

convnet = conv_2d(convnet, 40, 2, activation='relu')
convnet = max_pool_2d(convnet, 4)

convnet = conv_2d(convnet, 80, 2, activation='relu')
convnet = max_pool_2d(convnet, 4)

convnet = conv_2d(convnet, 160, 2, activation='relu')
convnet = max_pool_2d(convnet, 4)

convnet = fully_connected(convnet, 500, activation='relu')
convnet = dropout(convnet, 0.75)

convnet = fully_connected(convnet, 32, activation='softmax')
convnet = regression(convnet, optimizer='adam', learning_rate=LR, loss='categorical_crossentropy', name='keluar')

model = tflearn.DNN(convnet, tensorboard_dir='log')

Это код, который я использовал, чтобы попытаться загрузить мой файл .pb.

with tf.Session() as sess:
    with gfile.FastGFile(MODEL_NAME + '.frozen_graph.pb','rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')

Это для моего задания, и я в некотором роде новичок в tenorflow (и в машинном обучении в целом), так что мне здесь чего-то не хватает, чтобы заставить его работать?

...