ValueError: Ожидаемый ввод batch_size (633) для соответствия целевому batch_size (1024) - PullRequest
0 голосов
/ 04 апреля 2020
ValueErrorTraceback(most recent call last) <ipython-input-48-6571c3f05123> in <module>
     37         _, domain_pred = model(X_t, grl_lambda)
     38         #Calculating the domain loss for target data
---> 39         loss_t_domain = loss_fn_domain(domain_pred, y_t_domain)
     40         # Calculating total loss
     41         loss = loss_t_domain + loss_s_domain + loss_s_label

ValueError: Expected input batch_size (633) to match target batch_size
(1024).

Я получаю эту ошибку для следующего кода:

class DACNN(nn.Module):
   def __init__(self):
       super().__init__()
       self.feature_extractor = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=5),
        nn.BatchNorm2d(64), nn.MaxPool2d(2),
        nn.ReLU(True),
        nn.Conv2d(64, 50, kernel_size=5),
        nn.BatchNorm2d(50), nn.Dropout2d(), nn.MaxPool2d(2),
        nn.ReLU(True),
    )
    self.class_classifier = nn.Sequential(
        nn.Linear(50 * 4 * 4, 100), nn.BatchNorm1d(100), nn.Dropout2d(),
        nn.ReLU(True),
        nn.Linear(100, 100), nn.BatchNorm1d(100),
        nn.ReLU(True),
        nn.Linear(100, 10),
        nn.LogSoftmax(dim=1),
    )
    self.domain_classifier = nn.Sequential(
        nn.Linear(50 * 4 * 4, 100), nn.BatchNorm1d(100),
        nn.ReLU(True),
        nn.Linear(100, 2),
        nn.LogSoftmax(dim=1),
    )

    def forward(self, x, grl_lambda=1.0):
        x = x.expand(x.data.shape[0], 3, image_size, image_size)
        features = self.feature_extractor(x)
        features = features.view(-1,50 * 4 * 4)
        reverse_features = GradientReversalFn.apply(features, grl_lambda)
        class_pred = self.class_classifier(features)
        # Giving features that are passed through GRL to domain classifier
        domain_pred = self.domain_classifier(reverse_features)
        return class_pred, domain_pred

Пожалуйста, помогите мне в решении этой проблемы.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...