Я строю вариационный автоэнкодер (подробнее об этом здесь ) в tenorflow-gpu 2.0, используя высокоуровневый API Keras. Это моя архитектура до сих пор. Для простоты я опустил некоторые подробности о входных параметрах (...).
inputs = Input(...)
c1 = Conv2D(...)(inputs)
c1 = ReLU()(c1)
c2 = Conv2D(...)(c1)
c2 = ReLU()(c2)
c3 = Conv2D(...)(c2)
c3 = ReLU()(c3)
c4 = Conv2D(...)(c3)
c4 = ReLU()(c4)
z_mean = Conv2D(...)(c4)
z_log_var = Conv2D(...)(c4)
lat_var = Lambda(...)([z_mean, z_log_var]) # Sampling
dec1 = Conv2DTranspose(...)(lat_var)
dec1 = ReLU()(dec1)
dec2 = Conv2DTranspose(...)(dec1)
dec2 = ReLU()(dec2)
dec3 = Conv2DTranspose(...)(dec2)
dec3 = ReLU()(dec3)
dec4 = Conv2DTranspose(...)(dec3)
dec4 = ReLU()(dec4)
outputs = Conv2DTranspose(...)(dec4)
Мои потери определяются как среднее значение потерь на восстановление (MSE) между входами и выходами и расхождение Кульбака-Лейблера между z_mean
и z_log_var
:
def my_vae_loss(y_pred, y_gt):
mse_loss = mse(y_pred, y_gt)
# mse_loss *= original_dim
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
vae_loss = K.mean(mse_loss + GLOBAL.beta*kl_loss)
return vae_loss
В настоящее время я планирую свою потерю после обучения сети с model.fit()
следующим образом:
model.compile(..., loss = vae_loss, ...)
history = model.fit(...)
fig1 = plt.figure()
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
Проблема в том, что я хотел бы построить график потерь MSE и KL-дивергенции отдельно для каждой эпохи , тогда как в настоящее время я строю среднее значение этих компонентов, и любая конкретная информация c потерял. Однако потери на обучение должны оставаться такими же.
Полагаю, мне следует использовать лямбда-колбэк и кешировать потери MSE и значения KL-div во время обучения и в конечном итоге построить их. Один из способов сделать это - взять inputs
, outputs
, z_mean
, z_var
после каждой эпохи, вычислить две потери и затем кэшировать значения. Проблема с этим подходом состоит в том, что я не нашел элегантного способа получить значения тензоров во время обучения в Керасе. Обычно, если мне нужно значение в промежуточном слое X, я загружаю уменьшенную версию модели, в которой X является последним слоем, и в основном делаю прогноз.
Любой, у кого есть элегантный подход для построения обоих компонентов потерь отдельно