Мои прогнозы под тензором pred
, а pred.shape
равно (4254, 10, 3)
.Итак, у нас есть 4254
матрицы размерности (10, 3)
.Давайте посмотрим на одну из этих матриц.
W = array([[0.04592975, 0.09632163, 0.85774857],
[0.03408821, 0.27141285, 0.6944989 ],
[0.02538731, 0.4691383 , 0.50547445],
[0.01959289, 0.6456455 , 0.33476162],
[0.01333424, 0.7494791 , 0.23718661],
[0.0109237 , 0.77042925, 0.218647 ],
[0.01438793, 0.7796771 , 0.20593494],
[0.01474626, 0.6817438 , 0.30350992],
[0.02189695, 0.57687664, 0.40122634],
[0.03810155, 0.5130332 , 0.44886518]], dtype=float32)
Как видно из приведенного выше примера, есть 10 векторов, которые представляют горячее представление метки.Например, np.argmax([0.04592975, 0.09632163, 0.85774857]) = 2
.
Почему я продолжаю серию из 10 векторов?Я работаю над проблемой прогнозирования временных рядов, где во время t_0
я предсказываю следующие 10 меток для времени t_1
для времени t_10
.
Для каждой из этих матриц мне было бы интересно вернуть оригинальные метки.Поэтому для матрицы W
я должен получить массив array([2, 2, 2, 1, 1, 1, 1, 1, 1, 1])
.
Давайте определим пороговый массив threshold_array = np.array([0.6, 0.65, 0.70, 0.75, 0.80, 0.80, 0.80, 0.80, 0.80, 0.80])
и возьмем обратно labels = array([2, 2, 2, 1, 1, 1, 1, 1, 1, 1])
.Предположим, что нейтральная позиция равна 1
, а действие - 0
или 2
.Цель здесь - изменить labels
в соответствии с threshold_array
и нашей матрицей W
.
Если я возьму W[0]
, мы знаем, что np.argmax(W[0]) = 2
и W[0][2] = 0.85774857
.Как W[0][2] >= threshold_array[0]
, тогда labels[0]
останется 2
.
Этот другой пример немного отличается.Если я возьму W[2]
, мы знаем, что np.argmax(W[2]) = 2
и W[2][2] = 0.50547445
.Как W[2][2] < threshold_array[2]
, тогда labels[2]
будет изменено с 2
на 0
.
Если я применяю эту стратегию к каждому вектору из W
, labels
теперь устанавливается на array([2, 2, 0, 1, 1, 1, 1, 1, 1, 1])
.Обратите внимание, что только действие может стать нейтральной позицией, а не обратной.
Как можно кодировать в python эту стратегию для каждой матрицы W
внутри pred
, чтобы получить матрицу меток измерения (4254, 10)
?