От модели Keras до обучаемого файла и контрольной точки - PullRequest
0 голосов
/ 16 мая 2019

Я хочу тренировать ПБ и контрольно-пропускной пункт из Кераса в Тензорлоу.

Мне уже удалось преобразовать модель Keras в Tensorflow pb и контрольную точку. И мне уже удалось сделать вывод. Но проблема в том, что я понятия не имею, что делать для тренировок. Эта модель Keras, похоже, не имеет обучающей операции, или я просто не знаю, какой вклад я должен подавать при обучении.

Этот код преобразует модели Keras в Tensorflow pb и контрольную точку.

from keras import backend as K
from keras.models import load_model
import tensorflow as tf

model = load_model('model/my_model.h5')

K.set_learning_phase(0) #0 : test, 1 : train

sess = K.get_session()

saver = tf.train.Saver()
saver.save(sess, 'keras/keras.ckpt')

sess.graph.as_default()
graph = sess.graph

with open('keras/keras.pb', 'wb') as f:
    f.write(graph.as_graph_def().SerializeToString())

Это код чтения pb и контрольная точка

def keras_model():
    sess = tf.Session()
    saver = tf.train.import_meta_graph('keras/keras.ckpt.meta')
    saver.restore(sess, "keras/keras.ckpt")

    sess.graph.as_default()
    graph = tf.get_default_graph()

    a = [x for x in tf.get_default_graph().get_operations() if x.type == "Placeholder"]
    #print(a)

    img = cv2.imread("data/wqds_backbead_0_3.png", cv2.IMREAD_COLOR)
    img = img[...,::-1] # bgr to rgb
    img = img.astype('float32')
    img = np.expand_dims(img, axis=0)

    INPUT1 = graph.get_tensor_by_name("input_1:0")
    OUTPUT1 = graph.get_tensor_by_name("softmax/Softmax:0")
    TARGET1 = graph.get_tensor_by_name("softmax_target:0")

    print(TARGET1)

    pred = sess.run(OUTPUT1, feed_dict={INPUT1: img})
    print(pred, pred.shape, pred.dtype)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...