Как создать матрицу смешения в TF, когда метки запекаются в объекте DataSet? - PullRequest
0 голосов
/ 29 апреля 2020

Я использовал Keras API в TF 2.0 для создания CNN и использовал концепцию tf.datasets для этого. Мои данные (изображения) поступают из нескольких папок в структуре каталогов, где папка является меткой, а данные находятся в папке.

Я могу создать набор данных так же, как я это делал для обучения сети и передать его в функцию прогнозирования, как показано ниже:

    predictions = model.predict(test_labeled_ds, steps=40)
    arg_max = np.argmax(predictions, axis =-1)

Но здесь я хочу иметь возможность отображать матрицу путаницы с кодом, похожим на

  tf.math.confusion_matrix(
    **test_labeled_ds,**
    arg_max,
    num_classes=4,
    weights=None,
    dtype=tf.dtypes.int32,
    name=None
)

Но я, очевидно, хочу только передать истинные метки данных для сравнения с arg_max, а не со всем объектом набора данных. Неудивительно, что он жалуется.

Итак, как мне передать только метки в объекте DataSet в качестве аргумента матрице путаницы? Вы могли бы подумать, что есть очевидный способ сделать это вместо того, чтобы снова выделять их в большем количестве кода.

...