У меня есть эти обновления для Backprop, пожалуйста, дайте мне знать, где часть dx неверна. На графике вычислений я использую X, sample_mean и sample_var. Спасибо за вашу помощь
(x, norm, sample_mean, sample_var, gamma, eps) = cache
dbeta = np.sum(dout, axis = 0)
dgamma = np.sum(dout * norm, axis = 0)
dxminus = dout * gamma / np.sqrt(sample_var + eps)
dmean = - np.sum(dxminus, axis = 0)
dxmean = np.full(x.shape, 1.0/x.shape[0]) * dmean
dvar = np.sum(dout * gamma * (x - sample_mean), axis = 0)
dxvar = dvar * (-1 / x.shape[0]) * np.power(x, -1.5) * (x - sample_mean)
dx = dxminus + dxmean + dxvar
Вычислительный график BatchNorm, который я использовал для получения