На основе обсуждения в комментариях, вот способ, которым вы можете обрезать слой (весовую матрицу) вашей нейронной сети. По сути, метод делает выбор k%
наименьших весов (элементов матрицы) на основе их нормы и устанавливает их на ноль. Таким образом, соответствующая матрица может быть обработана как разреженная матрица, так что мы можем выполнить умножение плотно-разреженной матрицы, которое может быть быстрее при сокращении достаточного количества весов.
def weight_pruning(w: tf.Variable, k: float) -> tf.Variable:
"""Performs pruning on a weight matrix w in the following way:
- The absolute value of all elements in the weight matrix are computed.
- The indices of the smallest k% elements based on their absolute values are selected.
- All elements with the matching indices are set to 0.
Args:
w: The weight matrix.
k: The percentage of values (units) that should be pruned from the matrix.
Returns:
The unit pruned weight matrix.
"""
k = tf.cast(tf.round(tf.size(w, out_type=tf.float32) * tf.constant(k)), dtype=tf.int32)
w_reshaped = tf.reshape(w, [-1])
_, indices = tf.nn.top_k(tf.negative(tf.abs(w_reshaped)), k, sorted=True, name=None)
mask = tf.scatter_nd_update(tf.Variable(tf.ones_like(w_reshaped, dtype=tf.float32), name="mask", trainable=False), tf.reshape(indices, [-1, 1]),tf.zeros([k], tf.float32))
return w.assign(tf.reshape(w_reshaped * mask, tf.shape(w)))