Я использую следующую функцию для генерации путаницы:
def plot_confusion_matrix(cm, classes, normalize=False, cmap=cm.Blues, png_output=None, show=True):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
"""
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
title='Normalized confusion matrix'
else:
title='Confusion matrix'
f = plt.figure()
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
if png_output is not None:
os.makedirs(png_output, exist_ok=True)
f.savefig(os.path.join(png_output,'confusion_matrix.png'), bbox_inches='tight')
if show:
plt.show()
plt.close(f)
else:
plt.close(f)
Когда у меня есть несколько классов, я получаю аккуратную диаграмму, подобную этой:

Но когда у меня большое количество классов, я получаю это:

Я пытался использоватьтот же подход, что и в этом решении Python boxplot matplotlib автоматический размер фигуры, основанный на количестве категорий , но это не сработало.
Как я могу изменить свою матрицу путаницы, основываясь на ее размереколичество классов, как в приведенном выше решении boxplot?
ОБНОВЛЕНИЕ 1
После включения позиции тиков и динамической ширины фиг

def plot_confusion_matrix(y_true,y_pred, classes, normalize=False, cmap=cm.Blues, png_output=None, show=True):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
"""
cm = confusion_matrix(y_true,y_pred)
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
title='Normalized confusion matrix'
else:
title='Confusion matrix'
# Calculate chart area size
leftmargin = 0.5 # inches
rightmargin = 0.5 # inches
categorysize = 0.5 # inches
figwidth = leftmargin + rightmargin + (len(classes) * categorysize)
f = plt.figure(figsize=(figwidth, figwidth))
# Create an axes instance and ajust the subplot size
ax = f.add_subplot(111)
ax.set_aspect(1)
f.subplots_adjust(left=leftmargin/figwidth, right=1-rightmargin/figwidth, top=0.94, bottom=0.1)
res = ax.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar(res)
ax.set_xticks(range(len(classes)))
ax.set_yticks(range(len(classes)))
ax.set_xticklabels(classes, rotation=45, ha='right')
ax.set_yticklabels(classes)
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
ax.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
# plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
if png_output is not None:
os.makedirs(png_output, exist_ok=True)
f.savefig(os.path.join(png_output,'confusion_matrix.png'), bbox_inches='tight')
if show:
plt.show()
plt.close(f)
else:
plt.close(f)
С наилучшими пожеланиями.Клейсон Риос.