Проблема набора данных Tensorflow на этапе вывода - PullRequest
0 голосов
/ 28 августа 2018

Я создал генерацию языка уровня символов с Tensorflow здесь . Я использовал tf.placeholder API, который согласно google docs :

Подача - наименее эффективный способ подачи данных в программу TensorFlow.

Я решил изменить свой код и заменить его новым TensroFlow API набора данных .

Я использовал from_generator для создания набора данных:

dataset = tf.data.Dataset.from_generator(gen, (tf.int32, tf.int32),
                                             (tf.TensorShape([None, None]),
                                              tf.TensorShape([None, None])))
self.iterator = dataset.make_initializable_iterator()
self.inp, self.target = self.iterator.get_next()

Как видно из приведенного выше кода, я использовал [None, None] для Tensorshape, чтобы придать модели большую универсальность. Во время тренировок все отлично. Но при выводе возникает некоторая проблема . В tf.placeholder API я использовал следующий код для генерации символов:

def inference(self):
    converter = utils.TextReader(filename=FLAGS.CONVERTER_PATH)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        samples = []
        new_state = sess.run(self.init_state)
        c = 12 # random starting token
        samples.append(c)

        for i in range(1000):
            x = np.zeros((1, 1))
            x[0, 0] = c
            feed_dict = {
                self.inp: x,
                self.init_state: new_state
            }
            preds, new_state = sess.run([self.prediction, self.final_state], feed_dict=feed_dict)
            c = utils.pick_top_n(preds, converter.vocab_size)
            samples.append(c)

        samples = np.array(samples)
        print(converter.arr_to_text(samples))

В API набора данных у меня нет tf.placeholder, чтобы кормить моего предыдущего персонажа. И когда я использую приведенный выше код, как и ожидалось, произошла следующая ошибка:

InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [1,50] vs. shape[1] = [32,50]

На самом деле, модель использует ту же форму ввода ([32,50]), которую я использовал для обучения. Это не то, что я хочу (на самом деле я определяю TensorShape ([None, None]), чтобы справиться с этим, но не работает).

Как я могу исправить проблему с новым API набора данных?

Полный код .

...