Если у вас есть ваши окончательные результаты в тензоре probabilities
размера формы данных теста × количество классов, я бы сделал что-то вроде этого:
best_scores, best_classes = probabilities.max(dim=1)
per_class_examples = []
for class_id in range(50):
# mask telling where class_id class is
class_positions = best_classes == class_id
# make sure there are at least three examples,
# if not, rather take less
k = min(3, class_positions.sum())
if k == 0:
per_class_examples.append([])
else:
# set zero score to everything that is not class_id
_, best_examples = torch.topk(best_scores * class_positions, k, dim=1)
per_class_examples.append(best_examples)