Распечатать значение индекса argMax? - PullRequest
1 голос
/ 12 июля 2020

Я хотел бы напечатать фактическое значение индекса argMax (самая высокая вероятность) в этом сценарии:

const result = await model.predict(t4d); 
result.print(); // puts out: Tensor [[0.9899636, 0.0100364],]
result.as1D().argMax().print(); // prints either 0 or 1

Помимо индекса, я хотел бы напечатать фактическое значение 0, XXXX позади argMax (). Есть какие-нибудь советы по этому поводу?

Тест, который не работает:

const confidence = result.dataSync<'float32'()>;
console.log(confidence);

Извините, если на этот вопрос уже неоднократно отвечали, я потратил несколько часов на поиск!

1 Ответ

2 голосов
/ 17 июля 2020

Значение индекса argMax может быть получено после получения данных тензоров

const result = await model.predict(t4d);
const index = await result.as1D().argMax().data()[0]
const predict = await result.data()

// get the value
const value = result[index]

Другой возможностью было бы использовать topk

const topk = result.as1D().topk()
// get the highest value
value = topk.value() // the value here is a tensor
...