KMNIST: RuntimeError: вывод с формой [1, 28, 28] не соответствует форме трансляции [3, 28, 28] - PullRequest
0 голосов
/ 08 мая 2019

Во время выполнения кода Wide Resnet я получил ошибку во время выполнения. RuntimeError: вывод с формой [1, 28, 28] не соответствует форме трансляции [3, 28, 28] Я испробовал несколько решений, доступных онлайн, но ни одно из них не исправлено, у всех были другие проблемы. Я не знаю, как это исправить. Все связанные ошибки времени выполнения отображаются в коде.

  elif args.data == 'kmnist':
      normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                       std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

      if args.data_augmentation:
          transform_train = transforms.Compose([
              transforms.RandomCrop(32, padding=4),
              transforms.RandomHorizontalFlip(),
              transforms.ToTensor(),
              normalize,
          ])
      else:
          transform_train = transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize([0.5], [0.5])
          ])

      # If following portion, it would be another runtime error
      # RuntimeError: Given groups=1, weight of size 16 3 3 3, expected 
      #input[128, 1, 28, 28] to have 3 channels, but got 1 channels instead
      transform_test = transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize([0.5], [0.5])
      ])


      # If I tried following portion, it would be 
      # RuntimeError: output with shape [1, 28, 28] doesn't match the 
      # broadcast shape [3, 28, 28]
      # transform_test = transforms.Compose([
      #     transforms.ToTensor(),
      #     normalize
      # ])

      # If I tried following portion of the code, I received
      # AttributeError: Can't pickle local object 'get_data_loaders.<locals>.<lambda>'
      # transform_test = transforms.Compose([
      #     transforms.ToTensor(),
      #     transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
      #     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
      # ])

      kwargs = {'num_workers': 1, 'pin_memory': True}
      train_loader = torch.utils.data.DataLoader(
          KMNISTRandomLabels(root='./kmnistdata', train=True, download=True,
                            transform=transform_train, num_classes=args.num_classes,
                            corrupt_prob=args.label_corrupt_prob),
          batch_size=args.batch_size, shuffle=shuffle_train, **kwargs)
      val_loader = torch.utils.data.DataLoader(
          KMNISTRandomLabels(root='./kmnistdata', train=False,
                            transform=transform_test, num_classes=args.num_classes,
                            corrupt_prob=args.label_corrupt_prob),
          batch_size=args.batch_size, shuffle=False, **kwargs)

      return train_loader, val_loader
"""
Fashion-MNIST dataset, with support for random labels
"""
import numpy as np

import torch
import torchvision.datasets as datasets


class FashionMNISTRandomLabels(datasets.FashionMNIST):
  """Fashion-MNIST dataset, with support for randomly corrupt labels.

  Params
  ------
  corrupt_prob: float
    Default 0.0. The probability of a label being replaced with
    random label.
  num_classes: int
    Default 10. The number of classes in the dataset.
  """
  def __init__(self, corrupt_prob=0.0, num_classes=10, **kwargs):
    super(FashionMNISTRandomLabels, self).__init__(**kwargs)
    self.n_classes = num_classes
    if corrupt_prob > 0:
      self.corrupt_labels(corrupt_prob)

  def corrupt_labels(self, corrupt_prob):
    labels = np.array(self.train_labels if self.train else self.test_labels)
    np.random.seed(12345)
    mask = np.random.rand(len(labels)) <= corrupt_prob
    rnd_labels = np.random.choice(self.n_classes, mask.sum())
    labels[mask] = rnd_labels
    # we need to explicitly cast the labels from npy.int64 to
    # builtin int type, otherwise pytorch will fail...
    labels = [int(x) for x in labels]

    if self.train:
      self.train_labels = labels
    else:
      self.test_labels = labels
...