Самый быстрый способ подсчитать количество ненулевых элементов в выводе get_weights () - PullRequest
1 голос
/ 21 июня 2019

Я хотел бы рассчитать количество ненулевых значений в весах нейронной сети.

Я попробовал следующий код, но получил ValueError.Это может быть связано с тем, что каждый массив имеет различную форму.

h = model.get_weights()  # return a list of numpy arrays
merged_h = []
for l in h:
    merged_h += l
nzcounts = np.count_nonzero(merged_h)

ValueError: operands could not be broadcast together with shapes (0,) (3,3,3,32) 

Интересно, есть ли другие способы вычисления числа ненулевых элементов в выходных данных get_weights()?Спасибо!

1 Ответ

1 голос
/ 21 июня 2019

По сути, проблема в том, что model.get_weights() возвращает список массивов. Я думаю, что самый простой способ сделать это - применить np.count_nonzero() к каждому из этих массивов независимо, а затем суммировать результаты.

np.sum([np.count_nonzero(x) for x in model.get_weights()])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...