Я использую функциональный API Keras. У меня есть некоторая модель, которая выводит распределение вероятностей с помощью слоя softmax:
action_logits = Dense(units=self.action_space, activation='softmax')(prev_layer)
Затем я маскирую недопустимые действия (или классы, если хотите), умножая логиты на битовый вектор, представляющий законный actions:
mask_illegal_moves = keras.layers.multiply([action_logits, valid_actions])
Наконец, я хочу перенормировать логиты, теперь, когда я установил вывод для некоторых действий равным 0. Это кажется очень простым делом, но я не могу получить это на работу. Например, другой слой softmax не дал желаемых результатов. Более того, поиск в Google любого слоя «нормализации» в основном привел меня к BatchNorm, который меня здесь не интересует.
Любые советы будут с благодарностью приняты!