Я хочу оживить процесс нахождения минимальной точки функции различными методами оптимизации градиентного спуска. Для этого я использую пакеты matplotlib и celluloid. Проблема в том, что невозможно зафиксировать легенду графика в анимации, и в каждом l oop новая легенда добавляется под предыдущей легендой, как вы можете видеть на рисунке ниже. Есть ли способ исправить легенду и избежать этой проблемы?
from celluloid import Camera
fig,ax = plt.subplots(1, 1,figsize=(10, 10))
camera = Camera(fig)
for i in range(path1.shape[1])
ax.contour(x_mesh, y_mesh, z, levels=np.logspace(0, 5, 35), norm=LogNorm(), cmap=plt.cm.jet)
ax.plot(*minima_, 'r*', markersize=18)
line, = ax.plot([], [], 'k', label='Simple SGD', lw=2)
point, = ax.plot([], [], 'ko')
line.set_data(path1[::,:i])
point.set_data(path1[::,i-1:i])
line, = ax.plot([], [], 'r', label='SGD with momentum', lw=2)
point, = ax.plot([], [], 'ro')
line.set_data(*path2[::,:i])
point.set_data(*path2[::,i-1:i])
line, = ax.plot([], [], 'g', label='SGD with Nesterov', lw=2)
point, = ax.plot([], [], 'go')
line.set_data(*path3[::,:i])
point.set_data(*path3[::,i-1:i])
line, = ax.plot([], [], 'b', label='SGD with Adagrad', lw=2)
point, = ax.plot([], [], 'bo')
line.set_data(*path4[::,:i])
point.set_data(*path4[::,i-1:i])
line, = ax.plot([], [], 'c', label='SGD with Adadelta', lw=2)
point, = ax.plot([], [], 'co')
line.set_data(*path5[::,:i])
point.set_data(*path5[::,i-1:i])
line, = ax.plot([], [], 'm', label='SGD with RMSprob', lw=2)
point, = ax.plot([], [], 'mo')
line.set_data(*path6[::,:i])
point.set_data(*path6[::,i-1:i])
line, = ax.plot([], [], 'y', label='SGD with Adam', lw=2)
point, = ax.plot([], [], 'yo')
line.set_data(*path7[::,:i])
point.set_data(*path7[::,i-1:i])
line, = ax.plot([], [], 'y', label='SGD with Adamax', lw=2)
point, = ax.plot([], [], 'y*')
line.set_data(*path8[::,:i])
point.set_data(*path8[::,i-1:i])
line, = ax.plot([], [], 'k', label='SGD with Nadam', lw=2)
point, = ax.plot([], [], 'kp')
line.set_data(*path9[::,:i])
point.set_data(*path9[::,i-1:i])
line, = ax.plot([], [], 'r', label='SGD with AMSGrad', lw=2)
point, = ax.plot([], [], 'rD')
line.set_data(*path10[::,:i])
point.set_data(*path10[::,i-1:i])
ax.legend(loc='upper left')
camera.snap()
animation = camera.animate()
animation.save('2D_animation_overlap.gif', writer='imagemagick')