Вы определенно можете подсчитать количество умножений, сложений для обратного прохода вручную, но я думаю, что это исчерпывающий процесс для сложных моделей. обратный счет флопа для CNN и других моделей. Я предполагаю, что причина связана с тем, что вывод более важен с точки зрения различных вариантов CNN и других моделей глубокого обучения в приложении.
Обратный проход важен только во время обучения, и для большинства простых моделей обратный и прямой провалы должны быть близки с некоторыми постоянными коэффициентами.
Итак, я попробовал хитрый подход, чтобы вычислить градиенты для всей модели re snet на графике, чтобы получить флопы учитываются как для прямого прохода, так и для вычисления градиента, а затем вычитаются прямые флопы. Это не точное измерение, может пропустить многие операции для сложного графика / модели.
Но это может дать оценку провала для большинства моделей.
[Следующий фрагмент кода работает с tenorflow 2.0]
import tensorflow as tf
def get_flops():
for_flop = 0
total_flop = 0
session = tf.compat.v1.Session()
graph = tf.compat.v1.get_default_graph()
# forward
with graph.as_default():
with session.as_default():
model = tf.keras.applications.ResNet50() # change your model here
run_meta = tf.compat.v1.RunMetadata()
opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
# We use the Keras session graph in the call to the profiler.
flops = tf.compat.v1.profiler.profile(graph=graph,
run_meta=run_meta, cmd='op', options=opts)
for_flop = flops.total_float_ops
# print(for_flop)
# forward + backward
with graph.as_default():
with session.as_default():
model = tf.keras.applications.ResNet50() # change your model here
outputTensor = model.output
listOfVariableTensors = model.trainable_weights
gradients = tf.gradients(outputTensor, listOfVariableTensors)
run_meta = tf.compat.v1.RunMetadata()
opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
# We use the Keras session graph in the call to the profiler.
flops = tf.compat.v1.profiler.profile(graph=graph,
run_meta=run_meta, cmd='op', options=opts)
total_flop = flops.total_float_ops
# print(total_flop)
return for_flop, total_flop
for_flops, total_flops = get_flops()
print(f'forward: {for_flops}')
print(f'backward: {total_flops - for_flops}')
Out:
51112224
102224449
forward: 51112224
backward: 51112225