Я не понимаю, "ошибка формы" с помощью mxnet - PullRequest
0 голосов
/ 10 ноября 2018

Исходя из Кераса, я пытаюсь воспроизвести мою простую модель с MXNet для прогнозирования с использованием модуля.

Я использую этот простой набор данных: https://archive.ics.uci.edu/ml/machine-learning-databases/wine/wine.data

У меня есть 13 входных данных (от алкоголя до пролина), которые я хочу отправить в модель, и мне нужно классифицировать первый столбец, который называется «тип вина», поэтому я создаю массив nd.array, который имеет 3 записи .


x = data.values[: , 1:14]
y = data.values[:, 0]

X = mx.nd.array(x)
Y = []
for i, v in enumerate(y):
    d = [0,0,0]
    d[int(v)-1] = 1
    Y.append(d)
Y = mx.nd.array(Y)
Y.shape, X.shape
# ((178, 3), (178, 13))

Затем я создаю модель и NDIterator:


net = mx.symbol.Variable('winechemical')
net = mx.symbol.FullyConnected(net, num_hidden=64)
net = mx.symbol.Activation(net, act_type='relu')
net = mx.symbol.FullyConnected(net, num_hidden=32)
net = mx.symbol.Activation(net, act_type='relu')
net = mx.symbol.FullyConnected(net, num_hidden=16)
net = mx.symbol.SoftmaxOutput(net, name='wineclass')

model = Module(symbol=net, context=mx.cpu(),
                  data_names=['winechemical'],
                  label_names=['wineclass_label'])

gen = mx.io.NDArrayIter(X, label=Y, 
                        batch_size=10, 
                        shuffle=True, data_name='winechemical', 
                        label_name='wineclass_label')

Но когда я пытаюсь "обучить" модель, используя метод "подгонки", я получаю эту ошибку:

model.fit(gen, num_epoch=5)

[...]
Error in operator wineclass: Shape inconsistent, Provided = [10,3], inferred shape=[10]

Я почти уверен, что не понимаю, какую форму использовать, потому что я из Кераса, который использует другую форму ... Но где я не прав?

Спасибо за вашу помощь.

Ответы [ 2 ]

0 голосов
/ 30 ноября 2018

Вы уже нашли решение самостоятельно. Но если вы снова столкнетесь с подобной проблемой, вы можете использовать mx.visualization.print_summary () Эта функция очень полезна для проверки форм различных слоев в модели.

0 голосов
/ 10 ноября 2018

Господи, прости ... Я не видел, чтобы я давал 16 выходов вместо 3 ...

...