требования к учащимся fasttai и прогнозирование партий - PullRequest
0 голосов
/ 12 ноября 2018

Ранее я обучал модель resnet34 с помощью библиотеки fastai и сохранил файл weights.h5. В последней версии fastai мне все еще нужно иметь непустой поезд и действительные папки, чтобы импортировать моего ученика и прогнозировать на тестовом наборе?

Кроме того, я в настоящее время перебираю каждое тестовое изображение и использую learn.predict_array, но есть ли способ прогнозирования в пакетах для тестовой папки?

Пример того, что я сейчас делаю только для загрузки / прогнозирования:

PATH = '/path/to/model/'
sz = 224
arch=resnet34
tfms = tfms_from_model(resnet34, sz, aug_tfms=transforms_side_on, max_zoom=1.1)
data = ImageClassifierData.from_paths(PATH, tfms=tfms, bs=64)
learn = ConvLearner.pretrained(arch, data, precompute=False)
learn.unfreeze()
learn.load('224_all')

imgs = sorted(glob(os.path.join(test_path, '*.jpg')))
preds = []
_,val_tfms = tfms_from_model(resnet34, 224)
for n, i in enumerate(imgs):
        im = val_tfms(open_image(i))[None]
        preds.append(1-np.argmax(learn.predict_array(im)[0]))

Должен быть более чистый способ сделать это сейчас, нет?

Ответы [ 2 ]

0 голосов
/ 04 марта 2019

В fastai теперь вы можете экспортировать и загружать ученика для прогнозирования в наборе тестов, не загружая не пустой набор обучения и проверки. Для этого следует использовать метод export и функцию load_learner (оба определены в basic_train).

В вашей текущей ситуации вам, возможно, придется загрузить своего ученика старым способом (с поездом / действительным набором данных), затем экспортировать его, и вы сможете использовать load_learner, чтобы делать свои прогнозы на вашем тестовом наборе.

Я оставлю ссылку на документацию:

- https://docs.fast.ai/basic_train.html#Deploying-your-model

Это должно прояснить любые последующие вопросы.

0 голосов
/ 20 февраля 2019
data = ImageClassifierData.from_paths(PATH, tfms=tfms, bs=64)
learn = ConvLearner.pretrained(arch, data, precompute=False)
learn.unfreeze()
learn.load('224_all')

preds = learn.predict(is_test=True)
...