Повторная реализация функции TF 1.0 sampled_softmax_loss для модели seq2seq в модели TF 2 Keras - PullRequest
11 голосов
/ 19 июня 2020

У меня есть код TF 1.0.1 модели seq2seq. Я пытаюсь переписать его с помощью Tensorflow Keras.

Код TF 1.0.1 имеет следующую архитектуру декодера:

with tf.variable_scope("decoder_scope") as decoder_scope:

    # output projection
    # we need to specify output projection manually, because sampled softmax needs to have access to the the projection matrix 

    output_projection_w_t = tf.get_variable("output_projection_w", [vocabulary_size, state_size], dtype=DTYPE)
    output_projection_w = tf.transpose(output_projection_w_t)
    output_projection_b = tf.get_variable("output_projection_b", [vocabulary_size], dtype=DTYPE)
    
    decoder_cell = tf.contrib.rnn.LSTMCell(num_units=state_size)
    decoder_cell = DtypeDropoutWrapper(cell=decoder_cell, output_keep_prob=tf_keep_probabiltiy, dtype=DTYPE)
    decoder_cell = contrib_rnn.MultiRNNCell(cells=[decoder_cell] * num_lstm_layers, state_is_tuple=True)   
    
    # define decoder train netowrk
    decoder_outputs_tr, _ , _ = dynamic_rnn_decoder( 
        cell=decoder_cell, 
        decoder_fn= simple_decoder_fn_train(last_encoder_state, name=None),
        inputs=decoder_inputs, 
        sequence_length=decoder_sequence_lengths,
        parallel_iterations=None,
        swap_memory=False,
        time_major=False)
    
    # define decoder inference network
    decoder_scope.reuse_variables()    

Вот как выглядит sampled_softmax_loss вычислено:

decoder_forward_outputs = tf.reshape(decoder_outputs_tr,[-1, state_size])
decoder_target_labels  = tf.reshape(decoder_labels ,[-1, 1]) #decoder_labels is target sequnce of decoder

sampled_softmax_losses = tf.nn.sampled_softmax_loss(
    weights = output_projection_w_t,
    biases = output_projection_b,
    inputs = decoder_forward_outputs,
    labels = decoder_target_labels , 
    num_sampled = 500,
    num_classes=vocabulary_size,
    num_true = 1,
)    
total_loss_op = tf.reduce_mean(sampled_softmax_losses) 

И это мой декодер в Keras:

decoder_inputs = tf.keras.Input(shape=(None,), name='decoder_input')
emb_layer = tf.keras.layers.Embedding(vocabulary_size, state_size)
x_d = emb_layer(decoder_inputs)

d_lstm_layer = tf.keras.layers.LSTM(embed_dim, return_sequences=True)
d_lstm_out = d_lstm_layer(x_d, initial_state=encoder_states)

Это моя функция sampled_softmax_loss, которую я использую для Keras модель:

class SampledSoftmaxLoss(object):

  def __init__(self, model):
    self.model = model
    output_layer = model.layers[-1]
    self.input = output_layer.input
    self.weights = output_layer.weights

  def loss(self, y_true, y_pred, **kwargs):
    loss = tf.nn.sampled_softmax_loss(
        weights=self.weights[0],
        biases=self.weights[1],
        labels=tf.reshape(y_true ,[-1, 1]),
        inputs=tf.reshape(d_lstm_out,[-1, state_size]),
        num_sampled = 500,
        num_classes = vocabulary_size
    )

Но это не работает. Может ли кто-нибудь помочь мне правильно реализовать sampled_loss_funtion в Keras.

...