CNN, bidirectional LSTM implementation with CTC loss in TensorFlow for text recognition

I am trying to implement the idea from the research paper https://arxiv.org/pdf/1507.05717.pdf, which uses a CNN, a bidirectional LSTM, and CTC loss to predict the text in images. I found some resources on GitHub that include TensorFlow code, but since I am new to TensorFlow, I found their code hard to understand. Could someone please provide a very minimalistic and simple example of how to implement this idea, using some set of images and their corresponding labels, such as "assets", "finance", etc.? The labels have varying lengths, but the image size is fixed at width = 200, height = 20, depth = 1.

This is the code that generates the text images:

import numpy as np
from PIL import Image, ImageDraw, ImageFont

data = []
labels = []
length = []

def GenerateCharacters():
    k = 1
    for filename in fonts:          # fonts: list of .ttf font paths
        font_resource_file = filename
        for word in words:          # words: the label vocabulary
            if len(word) > 12:      # skip labels longer than 12 characters
                continue
            l = function(word)      # user-defined encoder: word -> list of ints, -1 on failure
            if l == -1:
                continue
            for font_size in font_sizes:
                font = ImageFont.truetype(font_resource_file, font_size)
                txt = word
                (width, height) = font.getsize(txt)

                # Render white text on a black grayscale canvas
                FOREGROUND = 255
                background = Image.new('L', (width, height), color=0)
                draw = ImageDraw.Draw(background)
                draw.text((0, 0), txt, font=font, fill=FOREGROUND)

                # Rescale to a fixed height of 22 px, preserving aspect ratio
                w, h = background.size
                W = int(round(w / float(h) * 22))
                req_im = np.asarray(background.resize((W, 22)))
                if W < 200:
                    # Pad narrower images with black pixels up to width 200
                    req_im = np.hstack([req_im, np.zeros((22, 200 - W), dtype='uint8')])
                else:
                    # Squash wider images down to exactly 200 px
                    req_im = Image.fromarray(req_im)
                    req_im = req_im.resize((200, 22))
                    req_im = np.asarray(req_im)
                word_image = req_im
#                 outpath = out_dir + word
#                 if not os.path.exists(outpath):
#                     os.makedirs(outpath)
#                 file_name = os.path.join(outpath, str(k) + '.jpg')
#                 cv2.imwrite(file_name, word_image)
                data.append(word_image)
                # Pad the encoded label with -1 up to the maximum length of 12
                l = np.array(l)
                pad = np.full((1, 12 - l.shape[0]), -1)
                l = np.append(l, pad)
                labels.append(l)
                length.append(len(word))
                k = k + 1
    return
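
The training loop further down relies on two helpers that are not shown in the post: sparse_tuples, which turns a batch of -1-padded label arrays into the sparse form that tf.nn.ctc_loss expects, and convert, which maps integer labels back to strings. A minimal sketch of what sparse_tuples might look like, assuming -1 marks padding exactly as in the generator above:

import numpy as np
import tensorflow as tf

# Hypothetical sketch of the unshown sparse_tuples helper: it gathers the
# (row, col) index and value of every non-padding entry
def sparse_tuples(padded_labels):
    indices, values = [], []
    for row, seq in enumerate(padded_labels):
        for col, v in enumerate(seq):
            if v == -1:  # -1 marks the start of padding
                break
            indices.append([row, col])
            values.append(int(v))
    dense_shape = [len(padded_labels), max(idx[1] for idx in indices) + 1]
    # Returning a tf.SparseTensor matches the sess.run(sparse_tuples(...))
    # usage below: running it yields a SparseTensorValue that can be fed
    # into the sparse placeholder
    return tf.SparseTensor(indices=np.asarray(indices, dtype=np.int64),
                           values=np.asarray(values, dtype=np.int32),
                           dense_shape=np.asarray(dense_shape, dtype=np.int64))

Note that building a new SparseTensor on every batch keeps growing the graph; feeding the (indices, values, dense_shape) triple straight into the sparse placeholder would avoid the extra sess.run call.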

Code for the model and the graph:

import tensorflow as tf

def run_ctc():
  graph = tf.Graph()
  with graph.as_default():
    # Inputs: fixed-size grayscale images, sparse labels for CTC, and the
    # number of valid RNN time steps per example
    input_data = tf.placeholder(tf.float32, [None, height, width, depth])
    labelss = tf.sparse_placeholder(tf.int32)
    sequence_length = tf.placeholder(tf.int32, [None])

    # Four conv -> batch norm -> 2x2 max pool stages; every pool halves
    # both spatial dimensions
    conv1 = tf.layers.conv2d(inputs=input_data,
        filters=64,
        kernel_size=[3, 3],
        strides=[1, 1],
        padding="same",
        activation=tf.nn.relu)
    print("Conv1", conv1.shape)
    bn1 = tf.layers.batch_normalization(inputs=conv1, axis=-1)
    print("BN1", bn1.shape)
    pool1 = tf.layers.max_pooling2d(inputs=bn1, pool_size=[2, 2], strides=[2, 2])
    print("Pool1", pool1.shape)
    conv2 = tf.layers.conv2d(inputs=pool1,
        filters=128,
        kernel_size=[3, 3],
        strides=[1, 1],
        padding="same",
        activation=tf.nn.relu)
    print("Conv2", conv2.shape)
    bn2 = tf.layers.batch_normalization(inputs=conv2, axis=-1)
    print("BN2", bn2.shape)
    pool2 = tf.layers.max_pooling2d(inputs=bn2, pool_size=[2, 2], strides=[2, 2])
    print("Pool2", pool2.shape)
    conv3 = tf.layers.conv2d(inputs=pool2,
        filters=256,
        kernel_size=[3, 3],
        strides=[1, 1],
        padding="same",
        activation=tf.nn.relu)
    print("Conv3", conv3.shape)
    bn3 = tf.layers.batch_normalization(inputs=conv3, axis=-1)
    print("BN3", bn3.shape)
    pool3 = tf.layers.max_pooling2d(inputs=bn3, pool_size=[2, 2], strides=[2, 2])
    print("Pool3", pool3.shape)
    conv4 = tf.layers.conv2d(inputs=pool3,
        filters=512,
        kernel_size=[3, 3],
        strides=[1, 1],
        padding="same",
        activation=tf.nn.relu)
    print("Conv4", conv4.shape)
    bn4 = tf.layers.batch_normalization(inputs=conv4, axis=-1)
    print("BN4", bn4.shape)
    pool4 = tf.layers.max_pooling2d(inputs=bn4, pool_size=[2, 2], strides=[2, 2])
    print("Pool4", pool4.shape)
    # Height is 1 after the four poolings, so drop that axis to get
    # [batch, width, channels] = [batch, time, features]
    features = tf.squeeze(pool4, axis=1, name='features')
    print("CNN features", features.shape)

    # Reorder to time-major [time, batch, features] for the RNN
    rnn_ = tf.transpose(features, perm=[1, 0, 2], name='time_major')
    print("RNN_seq", rnn_.shape)

    # Bidirectional LSTM over the width/time axis
    weight_initializer = tf.truncated_normal_initializer(stddev=0.01)
    with tf.variable_scope('forward'):
        cell_fw = tf.contrib.rnn.LSTMCell(512, initializer=weight_initializer)
    with tf.variable_scope('backward'):
        cell_bw = tf.contrib.rnn.LSTMCell(512, initializer=weight_initializer)
    with tf.variable_scope('Bdrnn', reuse=tf.AUTO_REUSE):
      rnn_output, _ = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, rnn_,
          sequence_length=sequence_length,
          time_major=True,
          dtype=tf.float32)

    # Concatenate forward and backward outputs, then project each time
    # step to class scores (classes + 1 to reserve a CTC blank label)
    rnn_output_stack = tf.concat(rnn_output, 2, name='output_stack')
    logit_activation = tf.nn.relu
    weight_initializer = tf.contrib.layers.variance_scaling_initializer()
    bias_initializer = tf.constant_initializer(value=0.0)
    logit_output = tf.layers.dense(rnn_output_stack, classes + 1,
                                   activation=logit_activation,
                                   kernel_initializer=weight_initializer,
                                   bias_initializer=bias_initializer)

    # CTC loss on the time-major logits, momentum SGD, greedy decoding,
    # and label error rate as a normalized edit distance
    loss = tf.nn.ctc_loss(labelss, logit_output, sequence_length, time_major=True)
    total_loss = tf.reduce_mean(loss)
    optimizer = tf.train.MomentumOptimizer(learning_rate=0.005, momentum=0.9).minimize(total_loss)
    decoded, log_prob = tf.nn.ctc_greedy_decoder(logit_output, sequence_length)
    ler = tf.reduce_mean(tf.edit_distance(tf.cast(decoded[0], tf.int32), labelss))


    with tf.Session(graph=graph) as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      for i in range(epochs):
        train_cost = train_ler = 0
        print("Epoch: {}".format(i))
        for batch_no in range(total_train_batches):

          feed = {input_data: batch_data(train_data, batch_no, total_train_batches, batch_size),
                  labelss: sess.run(sparse_tuples(batch_data(train_labels, batch_no, total_train_batches, batch_size))),
                  sequence_length: batch_data(train_seq_len, batch_no, total_train_batches, batch_size)}

          # Debug fetch of the fed values
          da, l, seq = sess.run([input_data, labelss, sequence_length], feed_dict=feed)
#           print(da.shape)
#           print(seq.shape)

      #     prediction = tf.argmax(tf.nn.softmax(logit_output), axis=2)
          original = convert(train_labels[batch_no])  # convert: user-defined ints -> string mapping
          print("Original Label", original)
          batch_cost, _ = sess.run([total_loss, optimizer], feed_dict=feed)
          train_cost += batch_cost * batch_size
          train_ler += sess.run(ler, feed_dict=feed) * batch_size
          print("Batch_loss", batch_cost)
          d = sess.run(decoded[0], feed_dict=feed)
          str_decoded = convert(d[1])
          print("Decoded Label", str_decoded)
        train_cost /= num_train_examples
        train_ler /= num_train_examples

        # Validation pass
        for batch_no in range(total_val_batches):
          val_feed = {input_data: batch_data(val_data, batch_no, total_val_batches, batch_size),
                      labelss: sess.run(sparse_tuples(batch_data(val_labels, batch_no, total_val_batches, batch_size))),
                      sequence_length: batch_data(val_seq_length, batch_no, total_val_batches, batch_size)}

          val_cost, val_ler = sess.run([total_loss, ler], feed_dict=val_feed)

          log = "Epoch {}/{}, train_cost = {:.3f}, train_ler = {:.3f}, val_cost = {:.3f}, val_ler = {:.3f}"
          print(log.format(i + 1, epochs, train_cost, train_ler, val_cost, val_ler))

if __name__ == '__main__':
  run_ctc()
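
For reference, the shapes that the squeeze and the RNN depend on follow directly from the four 2x2 poolings; a quick check of the arithmetic for the 200x22 images produced by the generator:

# Each 2x2, stride-2 max pool floor-halves height and width
h, w = 22, 200
for _ in range(4):
    h, w = h // 2, w // 2
print(h, w)  # 1 12: the unit height is squeezed away, leaving 12 time steps

So every example reaches the RNN with exactly 12 time steps, which is the value that sequence_length needs to be fed for each image.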

The problem I am facing is that I get a batch loss of inf, and the decoded labels are incorrect compared with the original labels.

The output I get (link to image)
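
One property of CTC is worth noting in connection with the inf loss: tf.nn.ctc_loss returns inf for any example whose label cannot be aligned within the available time steps, since CTC needs one frame per character plus a mandatory blank frame between adjacent repeated characters. With only 12 time steps and labels up to 12 characters, some words cannot be aligned at all. A quick feasibility check (min_ctc_frames is a hypothetical helper; l is the integer-encoded word from the generator):

def min_ctc_frames(label):
    # One frame per character, plus one blank between adjacent repeats
    # (e.g. the double "s" in "assets")
    repeats = sum(1 for a, b in zip(label, label[1:]) if a == b)
    return len(label) + repeats

# A word is only alignable here if min_ctc_frames(l) <= 12; tf.nn.ctc_loss
# also accepts ignore_longer_outputs_than_inputs=True to skip such examples
# instead of producing inf.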

...