Я использую spicy.optimize.fmin_l_bfgs_b
для оптимизации.
Калькулятор для вычисления loss
и grad
реализован классом
from calculator import loss_calculator
class one_batch:
def __init__(self, setup_dict):
self.setup_dict = setup_dict
def calculate(self):
temp_instance = loss_calculator(self.setup_dict, self.parameters)
self.loss, self.grad = temp_instance.result()
def objective_function(self, parameter):
self.parameters = parameters
self.calculate()
###########################################################
# I want to check if some convergence occures here. If #
# the convergence occurs, stop l-bfgs-b optimization. #
###########################################################
return self.loss, self.grad
Для оптимизации, экземплярone_batch
будет создаваться каждый раз при загрузке новой партии примеров.
from scipy.optimize import fmin_l_bfgs_b as optimizer
model_vector = initial_vector
for n in range(niter):
setup_dict = setup_dict # load the batch of examples
temp_batch = one_batch(setup_dict)
model_update = optimizer(temp_batch.objective_function, x0=model_vector)
model_vector = model_update
Как вы можете видеть из кода, я хочу добиться:
Когда optimizer
вызывает temp_batch.objective_function
для минимизации функции потерь, если происходит некоторая конвергенция, яхотите «сломать» процесс оптимизации, закодированный в spicy.optimize.fmin_l_bfgs_b
.
Как я могу это сделать?