I am new to PyTorch and wrote the code below. I use one neural network for the encoding part, then process the encoded data, and then use another network to decode that data.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import numpy as np
import itertools
import datetime
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(4, 32, bias=False)
        self.fc2 = nn.Linear(32, 16, bias=False)
        self.fc3 = nn.Linear(16, 7, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.tanh(x)
        x = self.fc2(x)
        x = torch.tanh(x)
        x = self.fc3(x)
        output = torch.tanh(x)
        return output
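
# (For reference: the encoder maps each length-4 +/-1 word to a length-7
# real-valued codeword; the final tanh keeps every coordinate in (-1, 1).)
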
def channel(codeword, snr_db, device):
    snr_value = 10 ** (snr_db / 10)
    h_real = torch.normal(mean=0, std=1, size=(codeword.shape[0], 1)) * torch.sqrt(torch.as_tensor(1 / 2))
    h_imag = torch.normal(mean=0, std=1, size=(codeword.shape[0], 1)) * torch.sqrt(torch.as_tensor(1 / 2))
    h_real_t = h_real.repeat(1, codeword.shape[1]).to(device)
    h_imag_t = h_imag.repeat(1, codeword.shape[1]).to(device)
    noise_real = torch.normal(mean=0, std=1, size=codeword.shape) * torch.sqrt(torch.as_tensor(1 / (2 * snr_value)))
    noise_imag = torch.normal(mean=0, std=1, size=codeword.shape) * torch.sqrt(torch.as_tensor(1 / (2 * snr_value)))
    noise_real = noise_real.to(device)
    noise_imag = noise_imag.to(device)
    faded_cw_real = torch.mul(h_real_t, codeword) + noise_real
    faded_cw_imag = torch.mul(h_imag_t, codeword) + noise_imag
    return torch.cat([faded_cw_real[:, :, None], faded_cw_imag[:, :, None]], dim=2), h_real, h_imag
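
# (For reference: with a (16, 7) codeword, channel() returns a (16, 7, 2)
# tensor holding the real and imaginary parts of h*x + n, plus the (16, 1)
# fading components h_real and h_imag so the decoder can also see them.)
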
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.bigru = nn.GRU(input_size=7, hidden_size=200, num_layers=20, bidirectional=True)
        self.fc0 = nn.Linear(4, 1)
        self.fc1 = nn.Linear(400, 4)

    def forward(self, x):
        x, states = self.bigru(x)
        output = torch.squeeze(self.fc0(torch.transpose(x, 2, 1)))
        output = torch.tanh(output)
        output = self.fc1(output)
        output = torch.tanh(output)
        # output = torch.softmax(output, dim=0)
        return output
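
# (Note: nn.GRU defaults to batch_first=False, so the (16, 4, 7) tensor built
# in train() below is interpreted as seq_len=16, batch=4, input_size=7.)
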
def train(args, model1, model2, device, optimizer, epoch, snr):
    model1.train()
    model2.train()
    count = 1000
    for i in range(count):
        data = np.array([list(i) for i in itertools.product([-1, 1], repeat=4)])
        p = np.random.permutation(16)
        # p = np.random.randint(low=0, high=16, size=(16,))
        train_data = data[p]
        data_one_hot = np.eye(16)
        truth = data_one_hot[p]
        # truth = torch.as_tensor(truth).to(device).float()  # Uncomment this for BCE loss
        train_data = torch.as_tensor(train_data).float()
        train_data = train_data.to(device)
        # optimizer1.zero_grad()
        optimizer.zero_grad()
        output = model1(train_data)
        output = output.to(device)
        ch_out, h_r, h_i = channel(output, snr, device)
        h_r = torch.as_tensor(h_r[:, :, None].repeat(1, 7, 1)).to(device)
        h_i = torch.as_tensor(h_i[:, :, None].repeat(1, 7, 1)).to(device)
        dec_ip = torch.cat([ch_out, h_r, h_i], 2)
        dec_ip = torch.transpose(dec_ip, 2, 1)
        hat = model2(torch.as_tensor(dec_ip).float())
        loss_d = F.mse_loss(hat, train_data)
        # loss_d = F.binary_cross_entropy(hat, truth)
        loss_d.backward()
        optimizer.step()
        if i % 10 == 0:
            # print(f"Train epoch: {epoch}, Batch: {i}, Encoder Loss: {loss_e.item()}, SNR: {snr}")
            print(f"Train epoch: {epoch}, Batch: {i}, Decoder Loss: {loss_d.item()}, SNR: {snr}")
def main():
    epochs = 14
    learning_rate = 1
    learning_rate_step = 0.7
    no_cuda = False
    log_interval = 10
    use_cuda = not no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    enc_model = Encoder().to(device)
    dec_model = Decoder().to(device)
    optimizer = optim.Adam(list(dec_model.parameters()) + list(enc_model.parameters()), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=1, gamma=learning_rate_step)

    for epoch in range(1, epochs + 1):
        snr = 20 - 20 * epoch / epochs
        train(log_interval, enc_model, dec_model, device, optimizer, epoch, snr)
        scheduler.step()


if __name__ == "__main__":
    main()
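
For context, this is what I intend channel() to do statistically: each real/imaginary fading component has variance 1/2 (so the complex tap h has unit average power), and each noise component has variance 1/(2*snr_value). A quick standalone check of those two numbers (the sample size is just illustrative):

import torch

snr_db = 10.0
snr_linear = 10 ** (snr_db / 10)  # 10 dB -> 10.0 in linear scale
n = 100000  # illustrative sample size

h_real = torch.normal(mean=0.0, std=1.0, size=(n, 1)) * torch.sqrt(torch.as_tensor(1 / 2))
noise_real = torch.normal(mean=0.0, std=1.0, size=(n, 1)) * torch.sqrt(torch.as_tensor(1 / (2 * snr_linear)))

print(h_real.var().item())      # should print roughly 0.5
print(noise_real.var().item())  # should print roughly 1 / (2 * 10) = 0.05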
However, when I run this, the output is 1004 lines of the form:

Train epoch: x, Batch: y, Decoder Loss: 2.0, SNR: z

where x, y, and z vary with the iteration. The decoder loss is stuck at 2.0: it starts out at about 2 and then never moves from 2.0.
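
For what it's worth, here is the kind of diagnostic I plan to add inside train(), right after loss_d.backward(), to check whether the gradients are vanishing (for example because the tanh units saturate) or whether hat is pinned at +/-1. It is only a sketch that reuses the names from the code above:

# Sketch only: goes right after loss_d.backward() inside the training loop.
total_sq = 0.0
for p in list(model1.parameters()) + list(model2.parameters()):
    if p.grad is not None:
        total_sq += p.grad.detach().norm(2).item() ** 2
grad_norm = total_sq ** 0.5
print(f"grad norm: {grad_norm:.4e}, mean |hat|: {hat.abs().mean().item():.4f}")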
Is something wrong with the code?