Получение значения потерь равным 0 при обучении нейронной сети - PullRequest
0 голосов
/ 30 мая 2018

Я не уверен, должен ли я вставлять весь код, но вот он:

import tensorflow as tf 
import numpy as np  
import requests 
from sklearn.model_selection import train_test_split

BATCH_SIZE = 20


#Get data
birthdata_url = 'http://springer.bme.gatech.edu/Ch17.Logistic/Logisticdat/lowbwt.dat'
birth_file = requests.get(birthdata_url)
birth_data = birth_file.text.split('\r\n')[5:]
birth_data = np.array([[x for x in y.split(' ') if len(x)>=1] for y in birth_data[1:] if len(y)>=1])

#Get x and y vals
y_vals = np.array([x[1] for x in birth_data]).reshape((-1,1))
x_vals = np.array([x[2:10] for x in birth_data])

#Split data
x_train, x_test, y_train, y_test = train_test_split(x_vals,y_vals,test_size=0.3)


#Placeholders
x_data = tf.placeholder(dtype=tf.float32,shape=[None,8])
y_data = tf.placeholder(dtype=tf.float32,shape=[None,1])


#Define our Neural Network

def init_weight(shape):
    return tf.Variable(tf.truncated_normal(shape=shape,stddev=0.1))

def init_bias(shape):
    return tf.Variable(tf.constant(0.1,shape=shape))

def fully_connected(inp_layer,weights,biases):
    return tf.nn.relu(tf.matmul(inp_layer,weights)+biases)

def nn(x):
    w1 = init_weight([8,25])
    b1 = init_bias([25])
    layer1 = fully_connected(x,w1,b1)

    w2 = init_weight([25,10])
    b2 = init_bias([10])
    layer2 = fully_connected(layer1,w2,b2)

    w3 = init_weight([10,3])
    b3 = init_bias([3])
    layer3 = fully_connected(layer2,w3,b3)

    w4 = init_weight([3,1])
    b4 = init_bias([1])
    final_output = fully_connected(layer3,w4,b4)


    return final_output




#Predicted values.
y_ = nn(x_data)


#Loss and training step.
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_data,logits=y_))
train_step = tf.train.AdamOptimizer(0.1).minimize(loss)

#Initalize session and global variables
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Accuracy
def get_accuracy(logits,labels):
    batch_predicitons = np.argmax(logits,axis=1)
    num_correct = np.sum(np.equal(batch_predicitons,labels))
    return(100*num_correct/batch_predicitons.shape[0])



loss_vec = []
for i in range(500):
    #Get random indexes and create batches. 
    rand_index = np.random.choice(len(x_train),size=BATCH_SIZE)

    #x and y batch.
    rand_x = x_train[rand_index]
    rand_y = y_train[rand_index]

    #Run the training step.
    sess.run(train_step,feed_dict={x_data:rand_x,y_data:rand_y})

    #Get the current loss. 
    temp_loss = sess.run(loss,feed_dict={x_data:x_test,y_data:y_test})
    loss_vec.append(temp_loss)

    if(i+1)%20==0:
        print("Current Step is: {}, Loss: {}"
            .format((i+1),
            temp_loss))

    #print("-----Test Accuracy: {}-----".format(get_accuracy(logits=sess.run(y_,feed_dict={x_data:x_test}),labels=y_test)))

Когда я запускаю свою программу, я всегда получаю значение потери равным 0 при обучении.Я не уверен, где проблема потенциально в.Но у меня есть несколько идей.

1) Может быть, я создаю пакеты данных?Это кажется необычным способом, но, насколько мне известно, он должен работать, получая случайные индексы, как в rand_index = np.random.choice(len(x_train),size=BATCH_SIZE).

2) Это не имеет смысла, но может ли это быть из-за того, что данные являются «небольшими данными»?

3) Любая простая ошибка в коде?

4) Или у меня действительно потеря как 0. (что является наиболее невероятным случаем)

Я был бы более чем счастлив, если бы вы могли дополнительно указать, чего мне следует избегать вкод выше.

Спасибо.

1 Ответ

0 голосов
/ 30 мая 2018

Я запустил ваш код, и это ошибки, которые я обнаружил.

  1. Похоже, ваши входные данные являются строками.Вы должны преобразовать его в float.
  2. Не используйте relu на последнем слое.Он должен вводиться непосредственно в функцию потерь без нелинейности.
  3. Вы должны использовать функцию sigmoid_cross_entropy_with_logits вместо softmax.Sigmoid - для бинарной классификации, а softmax - для мультиклассовой классификации.
  4. Возможно, ваша скорость обучения слишком высока.Я бы попробовал пониже.
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...