Понимание mxnet.image.ImageDetIter - PullRequest
1 голос
/ 05 марта 2019

Я изучаю инфраструктуру MXNet и пытаюсь запустить пример обнаружения объектов с помощью SSD: https://gluon.mxnet.io/chapter08_computer-vision/object-detection.html

Я использую графический процессор NVidia GTX 1050, 4 ГБ для обучения. Я работаю в тетради Jupyter. Версии: Python 3.6, MXNet 1.3.1.

В руководстве было сказано, что обучение с нуля занимает около 30 минут с одним графическим процессором. Я остановился через 3 часа. Модель обработала 24459 партий (размер партии 32), когда я прервал обучение. Размер всего набора данных составляет 87,7 МБ, что составляет менее 24459 * 32 * 256 * 256 (размер изображения 256х256). Я не могу понять, почему это может занять слишком много времени. Возможно, есть какие-то особенности image.ImageDetIter (например, тот, который никогда не останавливается сам по себе)?

1 Ответ

0 голосов
/ 06 марта 2019

Спасибо, что включили информацию о версии.Вы абсолютно правы - в MXNet 1.3.0 была ошибка, в которой ImageDetIter зацикливался бесконечно в вашем примере.Это было исправлено декабрь 2018 , и если вы обновитесь до MXNet 1.4.0, проблема не появится.Я подтвердил это, выполнив приведенный выше код.

Еще одно важное замечание, «Глубокое обучение - прямой допинг», устарело в пользу (Погружение в глубокое обучение] (d2l.ai). Содержаниеобновляется и используется для курса в MXNet. Вот соответствующая глава в книге.

Кроме того, видео с курса размещены здесь ,если вы хотите посмотреть их.

Что касается репро, я запустил и подтвердил, что это бесконечно зацикливалось в 1.3.x и исправлялось в 1.4.0.

train_iter = image.ImageDetIter(
        batch_size=1000, 
        data_shape=(3, data_shape, data_shape),
        path_imgrec='./data/pikachu_train.rec',
        path_imgidx='./data/pikachu_train.idx',
        #shuffle=True, 
        #mean=True,
        #rand_crop=1, 
        min_object_covered=0.95,
        last_batch_handle='pad',
        max_attempts=5)
train_iter.reset()
for i,data in enumerate(train_iter):    
    print((i+1)) # goes forever on 1.3.0 but not 1.4.0

Надеюсь, это поможет

Вишал

...