Как обрабатывать преобразования. FiveCrop изменение размера тензора - PullRequest
0 голосов
/ 10 июля 2020

Я пытаюсь добавить transforms.FiveCrop в свою модель. Я понимаю, что этот метод увеличения данных добавляет измерения к моему тензору, но я не уверен, как с этим справиться. Примечания к документации:

Это преобразование возвращает кортеж изображений, и может быть несоответствие в количестве входных данных и целевых объектов, возвращаемых вашим набором данных. См. Ниже пример того, как с этим справиться.

Пример

>>> transform = Compose([
>>>    FiveCrop(size), # this is a list of PIL Images
>>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
>>> ])
>>> #In your test loop you can do the following:
>>> input, target = batch # input is a 5d tensor, target is 2d
>>> bs, ncrops, c, h, w = input.size()
>>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
>>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops

... Но я не уверен, как это реализовать.

Поезд l oop:

    for batch_idx, (data, target) in enumerate(train_loader):

        print(f'Data: {data.shape}, Target: {target.shape}')       
        # Before Fivecrop: Data: torch.Size([32, 3, 224, 224]), Target: torch.Size([32])
        # After Fivecrop: Data: torch.Size([32, 5, 3, 224, 224]), Target: torch.Size([32])

        indx_target = target.clone()
        data = data.to(train_config.device)
        target = target.to(train_config.device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

Может кто-нибудь помочь объяснить, как это реализовано в моем поезде l oop и что я не понимаю?

Спасибо

1 Ответ

0 голосов
/ 10 июля 2020

Теперь это имеет смысл. Здесь для всех, кому нужна помощь. Новый код выглядит так:

    for batch_idx, (data, target) in enumerate(train_loader):
        
        bs, ncrops, c, h, w = data.size()
        print(f'Data: {data.view(-1, c, h, w).shape}, Target: {target.shape}')

        indx_target = target.clone()
        data = data.to(train_config.device)
        target = target.to(train_config.device)
        
        optimizer.zero_grad()
        output = model(data.view(-1, c, h, w)) # fuse batch size and ncrops
        output_avg = output.view(bs, ncrops, -1).mean(1) # average the output over fivecrops
        loss = F.cross_entropy(output_avg, target)
        loss.backward()
        optimizer.step()
        
        batch_loss = np.append(batch_loss, [loss.item()])
        prob = F.softmax(output_avg, dim=1)
        pred = prob.data.max(dim=1)[1]  
        correct = pred.cpu().eq(indx_target).sum()
        accuracy = float(correct) / float(len(data))
        batch_acc = np.append(batch_acc, [accuracy])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...