Какова архитектура следующей модели в tenorflow? - PullRequest
0 голосов
/ 27 октября 2019

Код Follwing является частью гибридной архитектуры распознавания рукописного текста. Это модифицированная версия сообщения в блоге Создание системы распознавания рукописного текста с использованием TensorFlow . Эта модель имеет архитектуру, состоящую из CNN-BLSTM-CTC. setupRNN() настройки функций BLSTM part.

    def setupRNN(self, rnnIn4d):
        """ Create RNN layers and return output of these layers """
        rnnIn4d = tf.slice(rnnIn4d, [0, 0, 0, 0], [
                           self.batchSize, 100, 1, 512])
        rnnIn3d = tf.squeeze(rnnIn4d)

        # 2 layers of LSTM cell used to build RNN
        numHidden = 512
        cells = [tf.nn.rnn_cell.LSTMCell(
            numHidden, name='basic_lstm_cell') for _ in range(2)]
        stacked = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)

        # Bi-directional RNN
        # BxTxF -> BxTx2H
        ((forward, backward), _) = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=stacked, cell_bw=stacked, inputs=rnnIn3d, dtype=rnnIn3d.dtype)

        # BxTxH + BxTxH -> BxTx2H -> BxTx1X2H
        concat = tf.expand_dims(tf.concat([forward, backward], 2), 2)

        # Project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC
        kernel = tf.Variable(tf.truncated_normal(
            [1, 1, numHidden*2, len(self.charList)+1], stddev=0.1))
        return tf.squeeze(tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME'), axis=[2])

Рисование графика вычислений в тензорной доске дает нам следующую модель. Но я не могу понять, откуда этот блок bidirectional_rnn выходит за пределы блока RNN . enter image description here Внутри моего RNN блока. enter image description here

CNN-BLSTM-CTC, соединенный отсюда

    def __init__(self, charList, decoderType=DecoderType.BestPath, mustRestore=False):
        self.charList = charList
        self.decoderType = decoderType
        self.mustRestore = mustRestore
        self.snapID = 0

        # CNN
        with tf.name_scope('CNN'):
            with tf.name_scope('Input'):
                self.inputImgs = tf.placeholder(tf.float32, shape=(
                    Model.batchSize, Model.imgSize[0], Model.imgSize[1]))
            cnnOut4d = self.setupCNN(self.inputImgs)

        # RNN
        with tf.name_scope('RNN'):
            rnnOut3d = self.setupRNN(cnnOut4d)

        # # Debuging CTC
        # self.rnnOutput = tf.transpose(rnnOut3d, [1, 0, 2])

        # CTC
        with tf.name_scope('CTC'):
            (self.loss, self.decoder) = self.setupCTC(rnnOut3d)
            self.training_loss_summary = tf.summary.scalar(
                'loss', self.loss)  # Tensorboard: Track loss

        # Optimize NN parameters
        with tf.name_scope('Optimizer'):
            self.batchesTrained = 0
            self.learningRate = tf.placeholder(tf.float32, shape=[])
            self.optimizer = tf.train.RMSPropOptimizer(
                self.learningRate).minimize(self.loss)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...