Я следую за NMT с вниманием (https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb) учебник и применяю его для собственного случая использования. К сожалению, когда я пытаюсь построить весы внимания, у меня возникают проблемы с выравниванием оси X, есливвод слишком длинный (например, 14 вместо 7).
В этом кодовом блоке построение графика работает, как и ожидалось:
import numpy as np
from matplotlib import pyplot as plt
def plot_attention():
attention = np.array([[7.78877574e-10, 4.04739769e-10, 6.65854022e-05, 1.63362725e-04,
2.85054208e-04, 8.50252633e-04, 4.58042100e-02],
[9.23501700e-02, 5.69618285e-01, 1.80586591e-01, 9.78111699e-02,
2.71992851e-02, 9.59911197e-03, 2.54837354e-03]])
sentence = ['<start>', 'hace', 'mucho', 'frio', 'aqui', '.', '<end>']
predicted_sentence = ['it', 's']
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(1, 1, 1)
ax.matshow(attention, cmap='viridis')
fontdict = {'fontsize': 14}
ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)
ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)
plt.show()
plot_attention()
, но сбольше элементов в списке «предложение», кажется, смещено:
def plot_attention():
attention = np.array([[7.78877574e-10, 4.04739769e-10, 6.65854022e-05, 1.63362725e-04,
2.85054208e-04, 8.50252633e-04, 4.58042100e-02, 7.78877574e-10, 4.04739769e-10, 6.65854022e-05, 1.63362725e-04,
2.85054208e-04, 8.50252633e-04, 4.58042100e-02],
[9.23501700e-02, 5.69618285e-01, 1.80586591e-01, 9.78111699e-02,
2.71992851e-02, 9.59911197e-03, 2.54837354e-03, 7.78877574e-10, 4.04739769e-10, 6.65854022e-05, 1.63362725e-04,
2.85054208e-04, 8.50252633e-04, 4.58042100e-02]])
sentence = ['<start>', 'hace', 'mucho', 'frio', 'aqui', '.', '<end>', '<start>', 'hace', 'mucho', 'frio', 'aqui', '.', '<end>']
predicted_sentence = ['it', 's']
fig = plt.figure(figsize=(20,10))
ax = fig.add_subplot(1, 1, 1)
ax.matshow(attention, cmap='viridis')
fontdict = {'fontsize': 14}
ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)
ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)
plt.show()
plot_attention()
Я ожидаю, что ось х будет идеально выровнена и что все элементы оси х показаны (не каждый второй, как сейчас)