def make_generator_model(n_classes=62):
# Create class embedding channel
input_label = tf.keras.layers.Input(shape=(1,))
label_embedding = tf.keras.layers.Embedding(n_classes, 50)(input_label)
upscaling = tf.keras.layers.Dense(7*7*1)(label_embedding)
upscaling = tf.keras.layers.Reshape((7, 7, 1))(upscaling)
# create seed encoding network
seed_input = tf.keras.layers.Input(shape=(NOISE_DIM,))
seed_fc = tf.keras.layers.Dense(7*7*256, use_bias=False)(seed_input)
seed_fc = tf.keras.layers.BatchNormalization()(seed_fc)
seed_fc = tf.keras.layers.LeakyReLU()(seed_fc)
seed_fc = tf.keras.layers.Reshape((7, 7, 256))(seed_fc)
# merge embedding with seed encoder
merge = tf.keras.layers.Concatenate()([seed_fc, upscaling])
assert tuple(merge.shape) == (None, 7, 7, 256+1)
x = layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)(merge)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)
assert tuple(x.shape) == (None, 7, 7, 128)
x = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)
assert tuple(x.shape) == (None, 14, 14, 64)
# TODO tanh function with output between [-1, 1]
output = layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')(x)
model = tf.keras.Model([seed_input, input_label], output)
assert model.output_shape == (None, 28, 28, 3)
return model
При попытке преобразовать приведенный выше код в PyTorch я столкнулся с проблемой размерности. Не могу понять, как добавить этот код кортежа утверждения. Я пробовал использовать функцию view (), но она не принимает "None" в качестве измерения. Вот что я попробовал позже. Заменено 1 на none in view (), но затем это дало ошибку в дискриминаторе.
class Generator(nn.Module):
def __init__(self):
super(Generator,self).__init__()
#claas = label
#noise == seed == data
self.linear_data = nn.Linear(100, 256, bias = False)
self.BN_data = nn.BatchNorm1d(256)
self.label_embedding = nn.Embedding(62, 100)
self.linear_label = nn.Linear(100, 1, bias = False)
self.DCV1 = nn.ConvTranspose2d(257, 128, kernel_size =(4,4), stride = (2,2), padding = (1,1))
self.BN1 = nn.BatchNorm2d(128)
self.DCV2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
self.BN2 = nn.BatchNorm2d(64)
self.DCV3 = nn.ConvTranspose2d(64, 3, 4, 2, 1)
def forward(self, inputs, labels):
'''
input size: (batch_size, 100, 1, 1)
outputs size: (batch_size, 1, 32, 32)
'''
# linear 1 for data
x = self.linear_data(inputs)
x = F.leaky_relu(x, negative_slope=0.2)
x = self.BN_data(x)
x = x.view(1,256,4,4)
#x = x.view(None, 4,4,256)
# linear 1 for label
y = self.label_embedding(labels)
#print(y.shape)
y = self.linear_label(y)
#print(y.shape)
y = y.view(1,1,4,4)
#x = x.view(None, 4,4,1)
#print(y.shape)
#concat data and label
x = torch.cat((x, y), dim=1)
x = x.view(1,257,4,4)
#print(x.shape)
#assert tuple(x.shape) == (1, 4, 4, 256+1)
# Deconv 1
x = self.DCV1(x)
x = self.BN1(x)
x = F.leaky_relu(x, 0.2)
# Deconv 2
x = self.DCV2(x)
x = self.BN2(x)
x = F.leaky_relu(x, 0.2)
# Deconv 3 + output
x = self.DCV3(x)
outputs = torch.tanh(x)
return outputs
Размер пакета составляет 16, а размер входного шума: 16,100 Размер входной метки: 16,1