Matplotlib 3d scatter animation - Как правильно обновлять - PullRequest
2 голосов
/ 12 февраля 2020

Я пытаюсь построить частицы в трехмерной анимации рассеяния, используя matplotlib. Я попытался изменить официальный пример анимации 3D-графика для достижения этой цели. Тем не менее, мой код не оживляет точки, а отображает их все сразу. Я не могу понять, в чем проблемы. Любая помощь или советы будут с благодарностью.

MRE:

import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d as p3
import matplotlib.animation as animation
import numpy as np




def Gen_RandPrtcls():
    n = 10
    x = np.random.normal(size=(n,3))*5
    # m = np.repeat(1. / n, n)

    # Computing trajectory
    data = [x]
    nbr_iterations = 300
    for iteration in range(nbr_iterations):
        # data.append(data[-1] + GravAccel(data[-1], m))
        data.append(data[-1]*1.01)

    return data


def update_prtcls(num, dataPrtcls, parts):
    for prtcl, data in zip(parts, dataPrtcls):
        # NOTE: there is no .set_data() for 3 dim data...
        prtcl.set_data(data[:num][:,0:1])
        prtcl.set_3d_properties(data[:num][:,2])
    return parts

# Attaching 3D axis to the figure
fig = plt.figure()
ax = p3.Axes3D(fig)

# Fifty parts of random 3-D parts
data = Gen_RandPrtcls()

# NOTE: Can't pass empty arrays into 3d version of plot()
parts = [ax.plot(dat[:,0], dat[:,1], dat[:,2], marker='.', linestyle="None")[0] for dat in data]

# Setting the axes properties
ax.set_xlim3d([-10.0, 10.0])
ax.set_xlabel('X')

ax.set_ylim3d([-10.0, 10.0])
ax.set_ylabel('Y')

ax.set_zlim3d([-10.0, 10.0])
ax.set_zlabel('Z')

ax.set_title('3D Test')

# Creating the Animation object
prtcl_ani = animation.FuncAnimation(fig, update_prtcls, 25, fargs=(data, parts),
                                   interval=50, blit=False)

plt.show()

1 Ответ

2 голосов
/ 12 февраля 2020

Вы структурировали data в другом порядке, чем вы предполагали.

import numpy as np

def Gen_RandPrtcls(n_particles, n_iterations):
    x = np.random.normal(size=(n_particles, 3))*5

    # Computing trajectory
    data = [x]
    for iteration in range(n_iterations):
        # data.append(data[-1] + GravAccel(data[-1], m))
        data.append(data[-1]*1.01)
    return data

data = Gen_RandPrtcls(n_particles=10, n_iterations=300)
data = np.array(data)  # (n_iterations, n_particles, 3)

В первом измерении data - это iterations, во втором - другое particles и в третьем - spacial coordinates.

. В вашем текущем обновлении вы вычерчиваете всю итерацию для частиц до num data[:, 0:num, :], а не итерации до num data[0:num, :, :] для всех частиц.

Я внес небольшие изменения в ваш код. Я строю траектории всех частиц одновременно, начиная только с первой итерации. Так что мне не нужно л oop по частицам. (Если все частицы должны отображаться с одинаковым цветом [маркер, ...], это работает нормально. В противном случае у вас будет LineObject для каждой частицы. Но логика c должна быть одинаковой).

import numpy as np
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d as p3
import matplotlib.animation as animation

fig = plt.figure()
ax = p3.Axes3D(fig)

# Plot the first position for all particles
h = ax.plot(*data[0].T, marker='.', linestyle='None')[0]
# Equivalent to
# h = ax.plot(data[0, :, 0], data[0, :, 1], data[0, :, 2], 
#             marker='.', linestyle='None')[0]

# Setting the axes properties
ax.set_xlim3d([-100.0, 100.0])
ax.set_xlabel('X')

ax.set_ylim3d([-100.0, 100.0])
ax.set_ylabel('Y')

ax.set_zlim3d([-100.0, 100.0])
ax.set_zlabel('Z')
ax.set_title('3D Test')

def update_particles(num):
    # Plot the iterations up to num for all particles
    h.set_xdata(data[:num, :, 0].ravel())
    h.set_ydata(data[:num, :, 1].ravel())
    h.set_3d_properties(data[:num, :, 2].ravel())
    return h

prtcl_ani = animation.FuncAnimation(fig, update_particles, frames=301, 
                                    interval=10)

Вот результат. Надеюсь, что поможет.

Редактировать:

Если вам нужны разные цвета для частиц, вам нужно нанести их отдельно:

colormap = plt.cm.tab20c
colors = [colormap(i) for i in np.linspace(0, 1, n_particles)]
h_particles = [ax.plot(*data[:1, i].T, marker='.', c=colors[i], ls='None')[0]
               for i in range(n_particles)]


def update_particles(num):
    for i, h in enumerate(h_particles):
        h.set_xdata(data[:num, i, 0])
        h.set_ydata(data[:num, i, 1])
        h.set_3d_properties(data[:num, i, 2])
    return h_particles
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...