Тип данных M xnet: float64, но он все время говорит, что это float32. - PullRequest
2 голосов
/ 13 января 2020

Я - пользователь pytorch и тензор потока. Я наткнулся на M xnet, чтобы использовать AWS вывод мастера-мудреца c.

M xnet api набора данных глюона, похоже, очень похож на набор данных pytorch.

class CustomDataset(mxnet.gluon.data.Dataset):
    def __init__(self):
        self.train_df = pd.read_csv('/shared/KTUTOR/test_summary_data.csv')
    def __getitem__(self, idx):
        return mxnet.nd.array(self.train_df.loc[idx, ['TT', 'TF', 'FT', 'FF']], dtype='float64'), mxnet.nd.array(self.train_df.loc[idx, ['p1']], dtype='float64')
    def __len__(self):
        return len(self.train_df)

Я определил свой набор пользовательских данных, как указано выше, и установил типы данных как float64.

test_data = mxnet.gluon.data.DataLoader(CustomDataset(), batch_size=8, shuffle=True, num_workers=2)

Я обернул свой набор данных с помощью DataLoader, и до этого момента ошибок не было. Ошибка возникает, когда я передаю данные в сеть.

for epoch in range(1):
for data, label in test_data:
    print(data.dtype)
    print(label.dtype)
    with autograd.record():
        output = net(data)
        loss = softmax_cross_entropy(output, label)
    loss.backward()
    trainer.step(batch_size)

Ошибка возрастает в net (данные), и сообщение об ошибке выглядит следующим образом.

MXNetError: [07:53:55] src/operator/contrib/../elemwise_op_common.h:135: Check failed: assign(&dattr, vec.at(i)): Incompatible attr in node  at 1-th input: expected float64, got float32
Stack trace:
  [bt] (0) /root/anaconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x4b09db) 
[0x7f00f96519db] ...

Когда Я печатаю тип данных и метку, все они - float64, но M XNet говорит мне, что тип данных данных - float32. Может кто-нибудь объяснить, почему это происходит? Большое спасибо заранее.

Ответы [ 2 ]

1 голос
/ 14 января 2020

Вы должны неинтуитивно преобразовывать свои входные данные в float32 (не float64).

Хотя ошибка, по-видимому, говорит о полной противоположности этому предложению, эта неудачная проверка передается из низкоуровневой операции в сети. это наиболее вероятно в форме: (input * weight) + bias.

Поскольку input является первой переменной вычисления, он устанавливает ожидаемый тип данных для других переменных (вес и смещение) равным float64. Таким образом, проверка на самом деле жалуется, что тип данных weight равен float32, когда ожидается float64.

1 голос
/ 14 января 2020

Ваша сеть в float64 или float32? Попробуем привести веса к float64:

net = net.cast('float64')

Тем не менее, по моему опыту, я не часто тренирую модели DL в float64, float32 и float16 гораздо более распространены для обучения. , А M XNet позволяет легко использовать точность float16 для обучения либо явно , либо автоматически с помощью инструмента AMP (Automati c Mixed Precision)

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...