Вот основной код:
features = torch.Tensor(feature_num, in_feature) # feature_num=500,000+
labels = torch.Tensor(feature_num, dtype=torch.long) # each value range from 0 to label_num-1
weights = torch.Tensor(label_num, in_feature, out_feature) # in_feature=out_feature=200, label_num=500+
# transform each feature according to its label
mlp_weights = torch.index_select(weights, 0, labels)
out = torch.bmm(features.unsequeeze(dim=1), mlp_weights).unsequeeze(dim=1)
Я просто хочу преобразовать каждый feature
, используя спецификацию c weight
. Но это вызывает ошибку out of memory , когда выполняется 'index_select' на моей машине с 12 ГБ графического процессора. Я попробовал два решения:
- разделить
features
на несколько кусков, затем сделать то же самое, но это также вызывает ошибку OOM. - l oop в каждой метке и заменить
torch.bmm
с torch.matmul
, но это очень медленно.
Так что любой может дать несколько советов. Спасибо.