Как сделать простую байесовскую нейронную сеть для мультиклассовой классификации в pymc2 - PullRequest
0 голосов
/ 14 января 2019

Я хочу построить модель BNN для набора данных iris в pymc2.

Я определил свою модель, и я пытался тренироваться, но точность была только 0.33 для данных поезда и теста.

Это мой текущий код

iterations = 2000
iris = load_iris()
X = iris.data[:, :]
Y = iris.target

X = scale(X)
X = X.astype(float)
Y = Y.astype(float)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.5)

w11 = pm.Normal('w11', mu=0., tau=1.)
w12 = pm.Normal('w12', mu=0., tau=1.)
w21 = pm.Normal('w21', mu=0., tau=1.)
w22 = pm.Normal('w22', mu=0., tau=1.)
w31 = pm.Normal('w31', mu=0., tau=1.)
w32 = pm.Normal('w32', mu=0., tau=1.)

x1 = X_train[:, 0]
x2 = X_train[:, 1]
# x3 = X_train[:, 2]
# x4 = X_train[:, 3]

x3 = pm.Lambda('x3', lambda w1=w11, w2=w12: np.tanh(w1 * x1 + w2 * x2))
x4 = pm.Lambda('x4', lambda w1=w21, w2=w22: np.tanh(w1 * x1 + w2 * x2))


@pm.deterministic
def activation(x=w31 * x3 + w32 * x4):
    return 1. / (1. + np.exp(-x)) #sigmoid


y = pm.Categorical('y', activation, observed=True, value=Y_train)

model = pm.Model([w11, w12, w21, w22, w31, w32, y])
inference = pm.MCMC(model)
inference.sample(iterations)

y_pred_train = pm.Categorical('y_pred_train', activation)
print("Accuracy on train data = {}".format((y_pred_train.value == Y_test).mean()))

x1 = X_test[:, 0]
x2 = X_test[:, 1]

inference.sample(iterations)

y_pred_test = pm.Categorical('y_pred_test', activation)
print("Accuracy on test data = {}".format((y_pred_test.value == Y_test).mean()))

Моя сетевая архитектура

my_network_architecture

Я не уверен, что проблема в категориальном распределении для y, y_pred_train и y_pred_test. И я не могу отследить эти переменные с помощью inference.trace("y")[:], чтобы посмотреть, что именно находится внутри.

Мои текущие результаты

Accuracy on train data = 0.3333333333333333
Accuracy on test data = 0.3333333333333333

Есть ли у вас какие-либо предложения по улучшению этого показателя?

...