Я хочу использовать numpy where
для матрицы весов текущей нейронной сети для обновления значений, превышающих пороговое значение.
обновление до -> вес сообщения в порядке.
но обновление сообщения-> вес не работает.
Кто-то, кто хорошо справляется с NumPy, пожалуйста, помогите мне !!
>>> import numpy as np
>>> v = np.arange(3*2*2).reshape((3,2,2))
"""
array([[[ 0, 1],
[ 2, 3]],
[[ 4, 5],
[ 6, 7]],
[[ 8, 9],
[10, 11]]])
"""
>>> w = np.arange(3*2*2*3*2*2).reshape((3,2,2,3,2,2))
"""
array([[[[[[ 0, 1],
[ 3, 3]],
[[ 6, 5],
[ 9, 7]],
[[ 12, 9],
[ 15, 11]]],
[[[ 18, 13],
[ 21, 15]],
[[ 24, 17],
[ 27, 19]],
[[ 30, 21],
[ 33, 23]]]],
~~~~~~~~~~~~~~~~~~~~~~~~~~
[[144, 137],
[147, 139]],
[[150, 141],
[153, 143]]]]]])
"""
>>> w[np.where(v>3)] += 1 # pre -> post is OK.
>>> w[:,:,:, np.where(v>3)] += 1 # post -> pre is not working!! I can't understand this result.
# incremented all elements!!