Я использую следующую функцию tf.function () для обновления значения ранее назначенной переменной V0
V0 = tf.Variable(tf.ones((20, 20, 11)))
@tf.function(experimental_compile=True)
def iterate():
for count in range(11):
< Do a bunch of stuff to calculate V0_update >
V0[:,:,count].assign(V0_update)
Впоследствии я хочу перебирать указанную выше функцию до схождения V0.
V_ = V0.numpy()
for counter in range(MAX_ITER):
iterate()
V_new_ = V0.numpy()
dist = np.abs(V_ - V_new_).max()
V_ = V_new_
if dist < tol:
print('\n\nConverged!')
break
Кажется, что узким местом выше является V0 [:,:, count] .assign (V0_update), что значительно замедляет каждую итерацию. Более того, когда я запускаю функцию iterate () несколько раз (для l oop выше), она сначала работает очень быстро (0,001 с), но вскоре для выполнения требуется до 4-5 se c.
Есть идеи, как я могу улучшить этап назначения или структуру кода для более быстрого выполнения / сходимости? Я попытался добавить for-l oop в функцию iterate (), но это не помогло. Спасибо.
Дополнительные сведения: Я запускаю TF2.0 в Google Colab с использованием графического процессора и среды выполнения High-RAM. Я новичок в Tensor Flow и не слишком осведомлен о его структуре.