Как настроить ячейку RNN - PullRequest
0 голосов
/ 25 июня 2018

Я хотел бы реализовать пользовательскую ячейку LSTM или GRU в TensorFlow (Python 3).Например, я хочу масштабировать сигнал состояния ячейки от ячейки на шаге времени T, прежде чем войти в ячейку на шаге времени T + 1.Я пытался искать в документации TensorFlow без успеха.Не могли бы вы дать мне подсказку?Спасибо.

РЕДАКТИРОВАТЬ
Проверяя ответ, заданный @ vijay m , я создаю свою модель следующим образом:

def dynamic_scale_RNN(x, timescale, seqlen, weights, biases, keep_prop):
    batch_size = tf.shape(x)[0]

    # Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)
    x = tf.unstack(x, max_seq_len, 1)
    timescale_unstack = tf.unstack(timescale, max_seq_len, 1)

    gru_cell = tf.contrib.rnn.GRUCell(n_hidden)

    #init_state has to be set to zero
    init_state = gru_cell.zero_state(batch_size, dtype=tf.float32)

    outputs = []
    # Create a loop of N LSTM cells, N = time_steps.
    for i in range(len(x)):
        output, state= tf.nn.static_rnn(gru_cell, [x[i]], dtype=tf.float32, initial_state= init_state)
        # copy the init_state with the new state
        mask = tf.tile(tf.expand_dims(timescale_unstack[i],axis=1),[1,state[0].get_shape()[-1]])
        init_state = tf.multiply(state,mask)
        # init_state = state
        outputs.append(output)

    # Transform the output to [batch_size, time_steps, vector_size]        
    outputs = tf.transpose(tf.squeeze(tf.stack(outputs)), [1, 0, 2])

В приведенном выше коде шкала времени представляет собой тензор формы [batch_size, sequence_length, 1], и я хочу масштабировать состояние ячейки, используя этот тензор.Несмотря на то, что код может выполняться, он возвращает nan для функции стоимости.Если я раскомментирую строку init_state = state, она будет работать, но не масштабирует состояние ячейки.

На данный момент у меня возникает вопрос: почему я получаю nan значения для функции стоимости?

1 Ответ

0 голосов
/ 27 июня 2018

Я оставляю здесь свой ответ на случай, если он кому-нибудь поможет.Причиной стоимости стоимости «nan» является то, что init_state установлен слишком высоко.Хотя я не знаю подходящий диапазон для этого значения, я могу заметить, что если я масштабирую его с небольшим коэффициентом, например, 0,1, я больше не вижу 'nan'.

...