Я обучил модель DNN с помощью tflearn и сохранил ее как файл контрольных точек, и модель показала хорошие результаты. Мне нужно заморозить модель, чтобы потом использовать ее на устройстве Android.
Я искал способ заморозить мою модель напрямую из TFLearn, но, очевидно, tflearn не предоставил такой функции на данный момент, поэтому я должен сделать это с тензорным потоком, используя этот код , чтобы заморозить мою модель в файл protobuf. И я также получил свое имя выходного узла из тензорной доски, используя this (тот же пользователь), потому что использование имени выходного узла, которое я явно дал в моей модели, по-видимому, не работает в TFLearn и просто замерзло 0 переменные из него.
Казалось, что он работает нормально, когда я успешно заморозил его в файл .pb, но когда я загружаю этот замороженный график в другое время выполнения и делаю прогноз из него, используя model.predict()
, он дал мне очень неправильное предсказание, как будто моя модель еще даже не тренировался.
Это моя архитектура NN.
convnet = input_data(shape=[None, IMG_SIZE, IMG_SIZE, 3], name='masuk')
convnet = conv_2d(convnet, 40, 2, activation='relu')
convnet = max_pool_2d(convnet, 4)
convnet = conv_2d(convnet, 40, 2, activation='relu')
convnet = max_pool_2d(convnet, 4)
convnet = conv_2d(convnet, 80, 2, activation='relu')
convnet = max_pool_2d(convnet, 4)
convnet = conv_2d(convnet, 160, 2, activation='relu')
convnet = max_pool_2d(convnet, 4)
convnet = fully_connected(convnet, 500, activation='relu')
convnet = dropout(convnet, 0.75)
convnet = fully_connected(convnet, 32, activation='softmax')
convnet = regression(convnet, optimizer='adam', learning_rate=LR, loss='categorical_crossentropy', name='keluar')
model = tflearn.DNN(convnet, tensorboard_dir='log')
Это код, который я использовал, чтобы попытаться загрузить мой файл .pb.
with tf.Session() as sess:
with gfile.FastGFile(MODEL_NAME + '.frozen_graph.pb','rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
Это для моего задания, и я в некотором роде новичок в tenorflow (и в машинном обучении в целом), так что мне здесь чего-то не хватает, чтобы заставить его работать?