Обучение и тестирование CNN с pytorch.С и без model.eval () - PullRequest
0 голосов
/ 02 мая 2019

У меня есть два вопроса: -

  1. Я пытаюсь обучить сверточную нейронную сеть, инициализированную с некоторыми предварительно обученными весами (Netwrok также содержит слои пакетной нормализации) (с учетом ссылки из здесь ).Перед тренировкой я хочу вычислить ошибку проверки, используя loss_fn = torch.nn.MSELoss().cuda().И в ссылке, автор использует model.eval() перед вычислением ошибки валидации.Но с этим результатом модель CNN отличается от того, какой она должна быть, однако, когда я закомментирую model.eval(), результат будет хорошим (что должно быть с предварительно обученными весами).Что может быть причиной этого, поскольку я прочитал во многих постах, что model.eval следует использовать перед тестированием модели и model.train() перед ее обучением.

  2. При вычислении ошибки проверки с помощьюпредварительно обученные веса и вышеупомянутая функция потерь должны соответствовать размеру партии.Разве это не должно быть 1, так как я хочу выводить на каждом из моих входных данных, вычислять ошибку с истинной землей и в итоге брать среднее значение всех результатов.Если я использую более высокий размер пакета, ошибка увеличивается.Так что вопрос в том, могу ли я использовать более высокий размер партии, если да, что должно быть правильным способом.В данном коде я дал err = float(loss_local) / num_samples, но наблюдал без усреднения, т.е. err = float(loss_local).Ошибка отличается для разных размеров партии.Я делаю это без model.eval прямо сейчас.

    batch_size = 1
    data_path = 'path_to_data'
    dtype = torch.FloatTensor
    weight_file = 'path_to_weight_file'
    val_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, val_lists),batch_size=batch_size, shuffle=True, drop_last=True)
    model = Model(batch_size)
    model.load_state_dict(load_weights(model, weight_file, dtype))
    loss_fn = torch.nn.MSELoss().cuda()
    # model.eval()

    with torch.no_grad():
        for input, depth in val_loader:
            input_var = Variable(input.type(dtype))
            depth_var = Variable(depth.type(dtype))

            output = model(input_var)

            input_rgb_image = input_var[0].data.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
            input_gt_depth_image = depth_var[0][0].data.cpu().numpy().astype(np.float32)
            pred_depth_image = output[0].data.squeeze().cpu().numpy().astype(np.float32)
            print (format(type(depth_var)))
            pred_depth_image_resize = cv2.resize(pred_depth_image, dsize=(608, 456), interpolation=cv2.INTER_LINEAR)
            target_depth_transform = transforms.Compose([flow_transforms.ArrayToTensor()])
            pred_depth_image_tensor = target_depth_transform(pred_depth_image_resize)
            #both inputs to loss_fn are 'torch.Tensor'
            loss_local += loss_fn(pred_depth_image_tensor, depth_var)

            num_samples += 1
            print ('num_samples {}'.format(num_samples))

    err = float(loss_local) / num_samples
    print('val_error before train:', err)

1 Ответ

2 голосов
/ 02 мая 2019

Что может быть причиной этого, поскольку я прочитал во многих постах, что model.eval должен использоваться перед тестированием model и model.train () перед его обучением.

Примечание: тестирование модели называется выводом.

Как объяснено в официальной документации :

Помните, что вы должны вызвать model.eval(), чтобы установить выпадающий и пакетный уровни нормализации в режим оценки перед запуском вывода. Невыполнение этого приведет к противоречивым результатам вывода.

Таким образом, этот код должен присутствовать, как только вы загрузите модель из файла и сделаете вывод.

# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

Это потому, что отсев работает как регуляризация для предотвращения переоснащения во время тренировки, он не нужен для вывода. То же самое для норм партии. При использовании eval() это просто устанавливает метку последовательности модулей на False и влияет только на определенные типы модулей, в частности Dropout и BatchNorm.

...