У меня есть кодировщик и сеть прокси, которые помогают кодировщику максимизировать информацию между его входом (изображение) и выходом (вектор признаков изображения). чтобы сделать это, я использовал функцию потерь, которая оценивает MI, и оптимизатором веса обеих сетей обновляются с вычисленными потерями, но я не уверен, что это сделано правильно или нет. Я использовал следующий код (в pytorch):
# Clear the previous gradients
discriminator_net_optim.zero_grad()
encoder_net_optim.zero_grad()
autograd.backward(loss)
torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 2)
torch.nn.utils.clip_grad_norm_(encoder.parameters(), 2)
# adjust weights in discriminator and encoder
discriminator_net_optim.step()
encoder_net_optim.step()
любая помощь или предложение приветствуется.