Онлайновая модель классификации LSTM, дающая очень большое количество ошибочных прогнозов - PullRequest
0 голосов
/ 26 сентября 2018

Я пытаюсь реализовать модель онлайновой классификации, используя набор данных из 20 групп новостей, чтобы классифицировать сообщения по соответствующим группам.

предварительная обработка : я прохожу всепосты и создание словаря со словами. Затем я индексирую слова, начиная с 1. Затем я перебираю все посты и для каждого слова в посте я ищу словарь и помещаю соответствующий индексный номер в массив.Затем я дополнил все массивы, поставив 0 в конце, чтобы они были одинакового размера (6577).

Затем я создаю встроенный слой (размер встраивания = 300).и каждый вход будет проходить через этот встроенный слой перед подачей на слой LSTM (форма входа LSTM = (1,6577,300)).

В моей модели у меня есть слой LSTM (размер = 200) искрытый слой (размер = 25).Для этого я использую ячейку dynamic_rnn в тензорном потоке, и я устанавливаю параметр длины последовательности для фактической длины записи (длина без дополненных 0 с), чтобы избежать анализа дополненных 0 с.Затем из выходных данных уровня LSTM я подаю только соответствующий выходной сигнал скрытому слою.

С этого момента это похоже на обычную реализацию LSTM.Я сделал все, что знаю, чтобы улучшить точность модели, но число неправильных прогнозов очень велико:

Количество точек данных: 18,846
Ошибки: 17876
Частота ошибок: 0.9485301920832007

Примечание: во время обратного распространения я тренирую встроенный слой и скрытый слой.

Вопрос: Я хочу знать, что я делаю не так, или какие-либо мысли по улучшению модели.Заранее спасибо.

Мой полный код показан ниже:

from collections import Counter
import tensorflow as tf
from sklearn.datasets import fetch_20newsgroups
import matplotlib as mplt
mplt.use('agg') # Must be before importing matplotlib.pyplot or pylab!
import matplotlib.pyplot as plt
from string import punctuation
from sklearn.preprocessing import LabelBinarizer
import numpy as np
from nltk.corpus import stopwords
import nltk
nltk.download('stopwords')



def pre_process():
    newsgroups_data = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))

    words = []
    temp_post_text = []
    print(len(newsgroups_data.data))

    for post in newsgroups_data.data:

        all_text = ''.join([text for text in post if text not in punctuation])
        all_text = all_text.split('\n')
        all_text = ''.join(all_text)
        temp_text = all_text.split(" ")

        for word in temp_text:
            if word.isalpha():
                temp_text[temp_text.index(word)] = word.lower()

        # temp_text = [word for word in temp_text if word not in stopwords.words('english')]
        temp_text = list(filter(None, temp_text))
        temp_text = ' '.join([i for i in temp_text if not i.isdigit()])
        words += temp_text.split(" ")
        temp_post_text.append(temp_text)

    # temp_post_text = list(filter(None, temp_post_text))

    dictionary = Counter(words)
    # deleting spaces
    # del dictionary[""]
    sorted_split_words = sorted(dictionary, key=dictionary.get, reverse=True)
    vocab_to_int = {c: i for i, c in enumerate(sorted_split_words,1)}

    message_ints = []
    for message in temp_post_text:
        temp_message = message.split(" ")
        message_ints.append([vocab_to_int[i] for i in temp_message])


    # maximum message length = 6577

    # message_lens = Counter([len(x) for x in message_ints])AAA

    seq_length = 6577
    num_messages = len(temp_post_text)
    features = np.zeros([num_messages, seq_length], dtype=int)
    for i, row in enumerate(message_ints):
        # print(features[i, -len(row):])
        # features[i, -len(row):] = np.array(row)[:seq_length]
        features[i, :len(row)] = np.array(row)[:seq_length]
        # print(features[i])

    lb = LabelBinarizer()
    lbl = newsgroups_data.target
    labels = np.reshape(lbl, [-1])
    labels = lb.fit_transform(labels)

    sequence_lengths = [len(msg) for msg in message_ints]
    return features, labels, len(sorted_split_words)+1, sequence_lengths


def get_batches(x, y, sql, batch_size=1):
    for ii in range(0, len(y), batch_size):
        yield x[ii:ii + batch_size], y[ii:ii + batch_size], sql[ii:ii+batch_size]


def plot(noOfWrongPred, dataPoints):
    font_size = 14
    fig = plt.figure(dpi=100,figsize=(10, 6))
    mplt.rcParams.update({'font.size': font_size})
    plt.title("Distribution of wrong predictions", fontsize=font_size)
    plt.ylabel('Error rate', fontsize=font_size)
    plt.xlabel('Number of data points', fontsize=font_size)

    plt.plot(dataPoints, noOfWrongPred, label='Prediction', color='blue', linewidth=1.8)
    # plt.legend(loc='upper right', fontsize=14)

    plt.savefig('distribution of wrong predictions.png')
    # plt.show()



def train_test():
    features, labels, n_words, sequence_length = pre_process()

    print(features.shape)
    print(labels.shape)

    # Defining Hyperparameters

    lstm_layers = 1
    batch_size = 1
    lstm_size = 200
    learning_rate = 0.01

    # --------------placeholders-------------------------------------

    # Create the graph object
    graph = tf.Graph()
    # Add nodes to the graph
    with graph.as_default():

        tf.set_random_seed(1)

        inputs_ = tf.placeholder(tf.int32, [None, None], name="inputs")
        # labels_ = tf.placeholder(dtype= tf.int32)
        labels_ = tf.placeholder(tf.float32, [None, None], name="labels")
        sql_in = tf.placeholder(tf.int32, [None], name= 'sql_in')

        # output_keep_prob is the dropout added to the RNN's outputs, the dropout will have no effect on the calculation of the subsequent states.
        keep_prob = tf.placeholder(tf.float32, name="keep_prob")

        # Size of the embedding vectors (number of units in the embedding layer)
        embed_size = 300

        # generating random values from a uniform distribution (minval included and maxval excluded)
        embedding = tf.Variable(tf.random_uniform((n_words, embed_size), -1, 1),trainable=True)
        embed = tf.nn.embedding_lookup(embedding, inputs_)

        print(embedding.shape)
        print(embed.shape)
        print(embed[0])

        # Your basic LSTM cell
        lstm =  tf.contrib.rnn.BasicLSTMCell(lstm_size)

        # Getting an initial state of all zeros
        initial_state = lstm.zero_state(batch_size, tf.float32)

        outputs, final_state = tf.nn.dynamic_rnn(lstm, embed, initial_state=initial_state, sequence_length=sql_in)

        out_batch_size = tf.shape(outputs)[0]
        out_max_length = tf.shape(outputs)[1]
        out_size = int(outputs.get_shape()[2])
        index = tf.range(0, out_batch_size) * out_max_length + (sql_in - 1)
        flat = tf.reshape(outputs, [-1, out_size])
        relevant = tf.gather(flat, index)

        # hidden layer
        hidden = tf.layers.dense(relevant, units=25, activation=tf.nn.relu,trainable=True)

        print(hidden.shape)

        logit = tf.contrib.layers.fully_connected(hidden, num_outputs=20, activation_fn=None)

        cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=labels_))


        optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)


        saver = tf.train.Saver()

    # ----------------------------online training-----------------------------------------

    with tf.Session(graph=graph) as sess:
        tf.set_random_seed(1)
        sess.run(tf.global_variables_initializer())
        iteration = 1
        state = sess.run(initial_state)
        wrongPred = 0
        noOfWrongPreds = []
        dataPoints = []

        for ii, (x, y, sql) in enumerate(get_batches(features, labels, sequence_length, batch_size), 1):

            feed = {inputs_: x,
                    labels_: y,
                    sql_in : sql,
                    keep_prob: 0.5,
                    initial_state: state}

            predictions = tf.nn.softmax(logit).eval(feed_dict=feed)

            print("----------------------------------------------------------")
            print("sez: ",sql)
            print("Iteration: {}".format(iteration))

            isequal = np.equal(np.argmax(predictions[0], 0), np.argmax(y[0], 0))

            print(np.argmax(predictions[0], 0))
            print(np.argmax(y[0], 0))

            if not (isequal):
                wrongPred += 1

            print("nummber of wrong preds: ",wrongPred)

            if iteration%50 == 0:
                noOfWrongPreds.append(wrongPred/iteration)
                dataPoints.append(iteration)

            loss, states, _ = sess.run([cost, outputs, optimizer], feed_dict=feed)

            print("Train loss: {:.3f}".format(loss))
            iteration += 1

        saver.save(sess, "checkpoints/sentiment.ckpt")
        errorRate = wrongPred / len(labels)
        print("ERRORS: ", wrongPred)
        print("ERROR RATE: ", errorRate)
        plot(noOfWrongPreds, dataPoints)


if __name__ == '__main__':
    train_test()

РЕДАКТИРОВАТЬ

enter image description here

1 Ответ

0 голосов
/ 26 сентября 2018

Несколько вещей для рассмотрения -:

  1. График потери против итераций график.Это должно быть вниз, чтобы знать, что ваша сеть изучает. Вы можете использовать тензорная доска для создания этих графиков.Также обеспечивает точность по сравнению с итерациями.
  2. Увеличьте размер пакета с 1 до мини-пакета 64,128 в зависимости от конфигурации вашей системы (ОЗУ)
  3. Используйте двунаправленный LSTM поскольку у вас есть полные предложения перед тренировочной моделью для повышения точности.

РЕДАКТИРОВАТЬ

Ваша модель не учитывает веса правильно.Запустив свой код, модель прогнозирует только класс 0. Взгляните на свои прогнозы и прогноз1.предсказания всегда 0.

Итерация: 1 0 10 число неправильных пред: 1

Потеря поезда: 3.116

Итерация: 2 0 3 количество неправильных пред: 2

Потеря поезда: 3,163

Итерация: 3 0 17 число неправильных пределов: 3

Потеря поезда: 3,212

Итерация: 4 0 3 количество неправильныхпред: 4

Потеря поезда: 2.992

Итерация: 5 0 4 Количество неправильных пред: 5

Потеря поезда: 2.892

Итерация: 6 012 неправильных чисел: 6

Потери в поездах: 3,077

Итерация: 7 0 4 неправильных числа поездов: 7

Потери в поездах: 2,554

Итерация: 8 0 10 число неправильных предков: 8

Потеря поезда: 3.459

Итерация: 9 0 10 количество неправильных предков: 9

Потеря поезда: 2.341

Итерация: 10 0 19 неправильных чисел: 10

Потеря поезда: 3.303

Итерация: 11 0 19 неправильных чисел: 11

Потеря поезда: 3.193

Итерация: 12 0 11 число неправильных сбоев: 12

Потеря поезда: 3.323

Итерация: 13 0 19 число ошибочных сбоев: 13

Потеря поезда: 2.773

Итерация: 14 0 13 число неправильных предков: 14

Потеря поезда: 3.129

Итерация: 15 0 0 количество неправильных предков: 14

Потеря поезда: 3.992

Итерация: 16 0 17 число неправильных предков: 15

Потеря поезда: 3.010

Итерация: 17 0 12 количество неправильных предков: 16

Потеря поезда: 2.534

Итерация: 18 0 12 число неправильных pred: 17

Потеря поезда: 2.804

Итерация: 19 0 11 число неправильных pred: 18

Потеря поезда: 4,369

Итерация: 20 0 8 число неправильных предков: 19

Потеря поезда: 4,028

Итерация: 21 0 7 количество неправильных предов: 20

Потеря поезда: 3.844

Итерация: 22 0 5 число неправильных пред: 21

Потеря поезда: 3.579

Итерация: 23 0 1 числонеправильный пред: 22 * ​​1113 *

Потеря поезда: 3,418

Итерация: 24 0 8 число неправильных пред: 23

Потеря поезда: 4,337

Итерация: 25 0 10 количество неправильных пред: 24

Потеря поезда: 2.328

Итерация: 26 0 14 число неправильных пред: 25

Потеря поезда: 4.216

Итерация: 27 0 16 числонеправильные значения: 26

потеря поезда: 3,155

повторение: 28 0 1 число неправильных значений: 27

потеря поезда: 3,307

повторение: 290 6 число неправильных показателей: 28

потеря поезда: 3,744

повторение: 30 0 0 количество ошибочных показателей: 28

потеря поезда: 4,180

Итерация: 31 0 7 число неправильных пред: 29

Потеря поезда: 3.400

Итерация: 32 0 16 число неправильных пред: 30

Потеря поезда: 2.706

Итерация: 33 0 5 число неправильных предков: 31

Потеря поезда: 2.994

Итерация: 34 0 9 число неправильных предков: 32

Поездпотеря: 3,610

Итеранация: 35 0 13 число неправильных пред: 33

потеря поезда: 2,689

повторение: 36 0 4 количество неправильных предков: 34

потеря поезда: 2,755

Итерация: 37 0 4Количество неправильных пред: 35

Потеря поезда: 2.778

Итерация: 38 0 18 Количество неправильных пред: 36

Потеря поезда: 3.361

Итерация: 39 0 8 число неправильных пред: 37

потеря поезда: 3,640

повторение: 40 0 ​​8 количество неправильных предков: 38

потеря поезда: 3,276

Итерация: 41 0 19 число неправильных пред: 39

Потеря поезда: 2,796

Итерация: 42 0 1 количество неправильных предков: 40

Потеря поезда:3.189

Итерация: 43 0 12 число неправильных предков: 41

Потеря поезда: 2.901

Итерация: 44 0 7 число неправильных предков: 42

Потеря поезда: 2.913

Итерация: 45 0 10 число неправильных пред: 43

Потеря поезда: 2.875

Итерация: 46 0 5 число неправильных пред: 44

Потеря поезда: 3,005

Итерация: 47 0 2 число неправильных пределов: 45

Потеря поезда: 3,246

Итерация: 48 0 6 количество неправильныхпред: 46

Потеря поезда: 3,071

Итерация: 49 0 11 число неправильных пред: 47

Потеря поезда: 2,971

Итерация: 50 0 2 количество неправильных пред: 48

Потеря поезда: 3,192

Итерация: 51 0 12 число неправильных пределов: 49

Потеря поезда: 2,894

Итерация: 52 0 7 количество неправильных пред: 50

Потеря поезда: 2.980

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...