делать прогнозы, используя контрольную точку cntk - PullRequest
0 голосов
/ 23 октября 2018

В эти дни я попробовал модель, реализованную cntk.Но я не могу найти способ предсказать новую картинку с обученной моделью.Обученная модель сохранена в качестве контрольной точки:

trainer.save_checkpoint(os.path.join(output_model_folder, "model_{}".format(best_epoch)))

Затем я получил несколько файлов, таких как:

enter image description here

Итак, я попыталсячтобы загрузить эту модель контрольной точки, например:

model = ct.load_model('../data/models/VGG13_majority/model_94')

приведенный выше код может успешно выполняться.Затем я попытался

model.eval(image_data)

, но получил ошибку: enter image description here

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ update ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

на этот раз я попробовал метод ниже:

model = ct.load_model('../data/models/VGG13_majority/model_94')
model.eval({model.arguments[0]: [final_image]})

затем возникла новая ошибка:

enter image description here

1 Ответ

0 голосов
/ 23 октября 2018

Для любого C.Function.eval () вам нужно передать словарь в качестве аргумента.

Таким образом, это будет выглядеть примерно так, при условии, что в модели есть только одна переменная input_variable:

model = C.load_model()
model.eval({model.arguments[0]: image_data})

Во всяком случае, я заметил, что вы сохранили модель с контрольной точки.Таким образом, вы действительно сохранили input_variable "ground_truth" и в функции потерь.

Я бы порекомендовал в следующий раз напрямую сохранить модель.Обычно файлы из save_checkpoint предназначены для использования в restore_from_checkpoint ()

import cntk as C
from cntk.layers import Dense

model = Dense(10)(C.input_variable(1))
loss = C.binary_cross_entropy(model, C.input_variable(10))

trainer = C.Trainer(model, (loss,), [C.adam(model.parameters, 0.9, 0.9)])
trainer.save_checkpoint("hello")
model.save()  # used this to save the model directly

# to recover model from checkpoint use below
trainer.restore_from_checkpoint("hello")
original_model = trainer.model
print(trainer)
for i in trainer.model.arguments:
    print(i)
...