import os
import sys
import numpy as np
import tensorflow as tf
from config import config
from model.TransE import *
from data.DataLoader import *
def score(h, t, r, score_type):
if score_type == 'l1':
score = tf.reduce_sum(tf.abs(h+r-t), axis=1)
elif score_type == 'l2':
score = tf.sqrt(tf.reduce_sum(tf.square(h+r-t), axis=1))
return score
if __name__ == '__main__':
# set config
params = config()
# import data
d = DataLoader()
it = d.get_batch('data/FB15k/train2id.txt')
model = TransE()
optimizer = tf.keras.optimizers.SGD(learning_rate = params.learning_rate)
print ('-----------Start Training----------')
epoch_loss = 0.0
step = 0
bound = 6 / math.sqrt(params.entity_dim)
entity_embedding = tf.Variable(name = 'entity_embedding',
initial_value=tf.random.uniform(shape = [params.entity_size, params.entity_dim], minval=-bound, maxval=bound), trainable=True)
entity_embedding = tf.nn.l2_normalize(entity_embedding, axis=1)
relation_embedding = tf.Variable(name = 'relation_embedding',
initial_value=tf.random.uniform(shape = [params.relation_size, params.relation_dim], minval=-bound, maxval=bound), trainable=True)
relation_embedding = tf.nn.l2_normalize(relation_embedding, axis=1)
for i in range(5):
with tf.GradientTape() as tape:
data = it.next()
h = tf.nn.embedding_lookup(entity_embedding, data[:,0])
t = tf.nn.embedding_lookup(entity_embedding, data[:,1])
r = tf.nn.embedding_lookup(relation_embedding, data[:,2])
h_neg = tf.nn.embedding_lookup(entity_embedding, data[:,3])
t_neg = tf.nn.embedding_lookup(entity_embedding, data[:,4])
batch_loss = tf.reduce_sum(tf.maximum(0.0, params.margin + score(h, t, r, 'l2') - score(h_neg, t_neg, r, 'l2')))
epoch_loss = epoch_loss + batch_loss
step = step + 1
grads = tape.gradient(batch_loss,[entity_embedding, relation_embedding])
print (grads)
Проблема, с которой я столкнулся, заключается в том, что когда я печатаю грады, получается результат [Нет, нет]. Таким образом, я не могу обновить свои переменные. Однако после того, как я удалил
entity_embedding = tf.nn.l2_normalize(entity_embedding, axis=1)
relation_embedding = tf.nn.l2_normalize(relation_embedding, axis=1)
, появятся грады, и я смогу обновить свои переменные. Я действительно хочу знать, что вызвало это, и могу ли я использовать функцию l2_normalize и все же обновлять свои переменные.