Более конкретный пример поможет лучше понять вашу проблему. Но из того, что я понимаю, вы не решаете проблему классификации classi c, где мы будем использовать CrossEntropyLoss, но вы хотите узнать меру взаимосвязанности.
Лучший метод, который я знаю для этого, - это использование метри c обучения (обучение ранжированию является примером, который, кажется, соответствует вашей проблеме). Обучение Triplet особенно известно, но есть и другие методы, которые делают это возможным. Вы можете посмотреть здесь: https://gombru.github.io/2019/04/03/ranking_loss/, а здесь: https://towardsdatascience.com/image-similarity-using-triplet-loss-3744c0f67973
Вкратце метри c обучение отображает входные данные для функции вектор. Обучение выполняется таким образом, что аналогичные входные данные отображаются в окрестности пространства признаков, а разнородные примеры отображаются в удаленных частях пространства признаков.
Минимальный код (на основе PyTorch) для этого:
# bare bone initialization
loss_func = nn.TripletMarginLoss(swap=True)
feature_extractor = Model(feat_size=1024)
optim = Adam(feature_extractor.parameters())
for x, x_pos, x_neg in dataloader:
# x contains the data we want to correctly rank;
# x_pos contains examples that should be mapped close to x;
# x_neg contains examples that should be mapped far from x;
optim.zero_grad()
# extract features of each input
feat_x = feature_extractor(x)
feat_x_pos = feature_extractor(x_pos)
feat_x_neg = feature_extractor(x_neg)
# calculate the loss based on those inputs
loss = loss_func(feat_x, feat_x_pos, feat_x_neg)
loss.backward()
optim.step()
PyTorch делает множество действительно простых в реализации! Сложной задачей является разработка учебного конвейера, который соответствует вашим потребностям.