Как получить доступ к тензору (y_true) в функции потерь? - PullRequest
0 голосов
/ 03 августа 2020

У меня есть словарь типа:

dict = {"class_1" : array(12, 13, 14, 1686, 124 ,123,....), "class_2" : array(12312,312,3,34,3...), ...}

здесь, он содержит вложения с использованием bert, класса (строкового типа). Поэтому во время функции потерь я хочу минимизировать разницу между встраиванием фактического и прогнозируемого классов.

def loss_function(y_true, y_pred):
    # what can I do for finding the class here 
    # I need to find the classes y_pred is pointing to 
    # then need to find mse between embedding vector of y_pred class and y_true classes.

Итак, основная проблема в том, как мне найти значения, на которые y_true указывает на каждой итерации. Я не могу выполнить какую-либо функцию массива, потому что это тензор. Мне нужно выполнить несколько задач в функции потерь:

  • найти имя класса, где y_true равно 1
  • найти имя класса, где y_pred равно 1
  • получить их вложения из dict и усредняют их, потому что задача - это классификация по нескольким меткам
  • вычислить mse между ними

Шаг 1, 2 являются проблемными c. Пожалуйста, помогите мне и заранее спасибо.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...