Использование встраивания поиска и мой градиент Нет - PullRequest
0 голосов
/ 24 апреля 2020
import os
import sys
import numpy as np
import tensorflow as tf

from config import config
from model.TransE import * 
from data.DataLoader import *
def score(h, t, r, score_type):

if score_type == 'l1':
    score = tf.reduce_sum(tf.abs(h+r-t), axis=1)
elif score_type == 'l2':
    score = tf.sqrt(tf.reduce_sum(tf.square(h+r-t), axis=1))

return score

if __name__ == '__main__':

# set config
params = config()

# import data
d = DataLoader()
it = d.get_batch('data/FB15k/train2id.txt')

model = TransE()
optimizer = tf.keras.optimizers.SGD(learning_rate = params.learning_rate)

print ('-----------Start Training----------')
epoch_loss = 0.0
step = 0

bound = 6 / math.sqrt(params.entity_dim)

entity_embedding = tf.Variable(name = 'entity_embedding',
                initial_value=tf.random.uniform(shape = [params.entity_size, params.entity_dim], minval=-bound, maxval=bound), trainable=True)
entity_embedding = tf.nn.l2_normalize(entity_embedding, axis=1)

relation_embedding = tf.Variable(name = 'relation_embedding',
                    initial_value=tf.random.uniform(shape = [params.relation_size, params.relation_dim], minval=-bound, maxval=bound), trainable=True)
relation_embedding = tf.nn.l2_normalize(relation_embedding, axis=1)

for i in range(5):
    with tf.GradientTape() as tape:
        data = it.next()
        h = tf.nn.embedding_lookup(entity_embedding, data[:,0])
        t = tf.nn.embedding_lookup(entity_embedding, data[:,1])
        r = tf.nn.embedding_lookup(relation_embedding, data[:,2])
        h_neg = tf.nn.embedding_lookup(entity_embedding, data[:,3])
        t_neg = tf.nn.embedding_lookup(entity_embedding, data[:,4])
        batch_loss = tf.reduce_sum(tf.maximum(0.0, params.margin + score(h, t, r, 'l2') - score(h_neg, t_neg, r, 'l2')))
        epoch_loss = epoch_loss + batch_loss
        step = step + 1

    grads = tape.gradient(batch_loss,[entity_embedding, relation_embedding])
    print (grads)

Проблема, с которой я столкнулся, заключается в том, что когда я печатаю грады, получается результат [Нет, нет]. Таким образом, я не могу обновить свои переменные. Однако после того, как я удалил

entity_embedding = tf.nn.l2_normalize(entity_embedding, axis=1)
relation_embedding = tf.nn.l2_normalize(relation_embedding, axis=1)

, появятся грады, и я смогу обновить свои переменные. Я действительно хочу знать, что вызвало это, и могу ли я использовать функцию l2_normalize и все же обновлять свои переменные.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...