RBF-Нейронная сеть не может классифицировать набор данных MNIST - PullRequest
0 голосов
/ 24 января 2019

Я реализовал классификатор нейронной сети RBF. Я использую свою реализацию для классификации набора данных MNIST, но он не обучается и всегда просто прогнозирует один класс. Я был бы очень признателен, если бы кто-то мог помочь мне определить проблему с моей реализацией.

Я должен отметить, что реализация довольно медленная из-за того, что она работает пример за примером, но я не знаю, как сделать так, чтобы она работала партия за партией. (Я новичок в tenorflow и Python в целом)

Моя реализация выглядит следующим образом:

class RBF_NN:
    def __init__(self, M, K, L, lr):
    #Layer sizes
    self.M = M #input layer size - number of features
    self.K = K #RBF layer size
    self.L = L #output layer size - number of classes

    #
    x = tf.placeholder(tf.float32,shape=[M])

    matrix = tf.reshape(tf.tile(x,multiples=[K]),shape=[K,M])

    prototypes_input = tf.placeholder(tf.float32,shape=[K,M])
    prototypes = tf.Variable(prototypes_input) # prototypes - representatives of the data

    r = tf.reduce_sum(tf.square(prototypes-matrix),1)

    s = tf.Variable(tf.random.uniform(shape=[K],maxval=1)) #scaling factors

    h = tf.exp(-r/(2*tf.pow(s,2)))

    W = tf.Variable(tf.random.uniform(shape=[K,L],maxval=1))
    b = tf.Variable(tf.constant(0.1, shape=[L]))

    o = tf.matmul(tf.transpose(tf.expand_dims(h,1)),W) + b 

    pred_class = tf.argmax(o,1)

    y = tf.placeholder(shape=[L], dtype=tf.float32)

    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=o, labels=y))
    optimizer = tf.train.AdamOptimizer(lr).minimize(loss)

    self.x = x
    self.prototypes_input = prototypes_input
    self.prototypes = prototypes
    self.r = r
    self.s = s
    self.h = h
    self.W = W
    self.b = b
    self.o = o
    self.y = y
    self.loss = loss
    self.optimizer = optimizer
    self.pred_class = pred_class

    def fit(self,X,y,prototypes,epoch_count,print_step,sess):
    for epoch in range(epoch_count):
        epoch_loss = 0
        for xi,yi in zip(X,y):
            iter_loss, _ = sess.run((self.loss,self.optimizer),feed_dict={self.x: xi, self.y: yi, self.prototypes_input:prototypes})
            epoch_loss = epoch_loss + iter_loss
        epoch_loss = epoch_loss/len(X)
        if epoch%print_step == 0:
            print("Epoch loss",(epoch+1),":",epoch_loss)

    def predict(self,x,sess):
    return sess.run((self.pred_class),feed_dict={self.x:x})[0]

    def get_prototypes(self,sess):
    return sess.run((self.prototypes))

Использование:

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
y_train = to_one_hot(y_train,10)
y_test = to_one_hot(y_test,10)
x_train = np.asarray([np.asarray(x).reshape(-1) for x in x_train])
x_test = np.asarray([np.asarray(x).reshape(-1) for x in x_test])
M = 784
K = 1000
L = 10
lr = 0.01
rbfnn = RBF_NN(M,K,L,lr)

#Selecting prototypes from the train set
idx = np.random.randint(len(x_train), size=K)
prototypes = x_train[idx,:]

init = tf.global_variables_initializer()
sess = tf.InteractiveSession()

sess.run(init,feed_dict={rbfnn.prototypes_input:prototypes})

rbfnn.fit(x_train,y_train,prototypes,epoch_count=1, print_step=1,sess=sess)

y_test_p = []
for xi,yi in zip(x_test,y_test):
    yp = rbfnn.predict(xi,sess=sess)
    y_test_p.append(yp)

y_test_t = [np.argmax(yi) for yi in y_test]


acc = accuracy_score(y_test_t,y_test_p,)
precc = precision_score(y_test_t,y_test_p, average='macro')
recall = recall_score(y_test_t,y_test_p, average = 'macro')
f1 = f1_score(y_test_t,y_test_p,average='macro')

print("Accuracy:",acc)
print("Precision:",precc)
print("Recall:",recall)
print("F1 score:",f1)

sess.close()

1 Ответ

0 голосов
/ 24 января 2019

Реализация в порядке.Тем не менее, он кажется очень чувствительным к данным.Он начнет хорошо учиться, если будут добавлены следующие строки:

x_train = (x_train-x_train.min())/(x_train.max()-x_train.min())
x_test = (x_test-x_test.min())/(x_test.max()-x_test.min())

Таким образом, данные нормализуются, так что интервал каждой функции составляет от 0 до 1.

...