Мой слой похож на этот (я делаю слой LSTM с выпадением, применяемым на каждом временном шаге, вход пропускается 10 раз и возвращается среднее значение выходных данных)
import torch
from torch import nn
class StochasticLSTM(nn.Module):
def __init__(self, input_size: int, hidden_size: int, dropout_rate: float):
"""
Args:
- dropout_rate: should be between 0 and 1
"""
super(StochasticLSTM, self).__init__()
self.iter = 10
self.input_size = input_size
self.hidden_size = hidden_size
if not 0 <= dropout_rate <= 1:
raise Exception("Dropout rate should be between 0 and 1")
self.dropout = dropout_rate
self.bernoulli_x = torch.distributions.Bernoulli(
torch.full((self.input_size,), 1 - self.dropout)
)
self.bernoulli_h = torch.distributions.Bernoulli(
torch.full((hidden_size,), 1 - self.dropout)
)
self.Wi = nn.Linear(self.input_size, self.hidden_size)
self.Ui = nn.Linear(self.hidden_size, self.hidden_size)
self.Wf = nn.Linear(self.input_size, self.hidden_size)
self.Uf = nn.Linear(self.hidden_size, self.hidden_size)
self.Wo = nn.Linear(self.input_size, self.hidden_size)
self.Uo = nn.Linear(self.hidden_size, self.hidden_size)
self.Wg = nn.Linear(self.input_size, self.hidden_size)
self.Ug = nn.Linear(self.hidden_size, self.hidden_size)
def forward(self, input, hx=None):
"""
input shape (sequence, batch, input dimension)
output shape (sequence, batch, output dimension)
return output, (hidden_state, cell_state)
"""
T, B, _ = input.shape
if hx is None:
hx = torch.zeros((self.iter, T + 1, B, self.hidden_size), dtype=input.dtype)
else:
hx = hx.unsqueeze(0).repeat(self.iter, T + 1, B, self.hidden_size)
c = torch.zeros((self.iter, T + 1, B, self.hidden_size), dtype=input.dtype)
o = torch.zeros((self.iter, T, B, self.hidden_size), dtype=input.dtype)
for it in range(self.iter):
# Dropout
zx = self.bernoulli_x.sample()
zh = self.bernoulli_h.sample()
for t in range(T):
x = input[t] * zx
h = hx[it, t] * zh
i = torch.sigmoid(self.Ui(h) + self.Wi(x))
f = torch.sigmoid(self.Uf(h) + self.Wf(x))
o[it, t] = torch.sigmoid(self.Uo(h) + self.Wo(x))
g = torch.tanh(self.Ug(h) + self.Wg(x))
c[it, t + 1] = f * c[it, t] + i * g
hx[it, t + 1] = o[it, t] * torch.tanh(c[it, t + 1])
o = torch.mean(o, axis=0)
c = torch.mean(c[:, 1:], axis=0)
hx = torch.mean(hx[:, 1:], axis=0)
return o, (hx, c)
Когда я оптимизирую сеть, у меня появляется ошибка one of the variables needed for gradient computation has been modified by an inplace operation
. Мы можем определить несколько операций на месте, таких как o[it, t] = torch.sigmoid(self.Uo(h) + self.Wo(x))
.
Как мне избежать этой операции на месте, когда я хочу найти среднее значение?
Спасибо