Сохранение конфигурации keras с пользовательской метрической функцией в JSON - PullRequest
0 голосов
/ 24 апреля 2019

Я пытаюсь сохранить свою конфигурацию для модели Keras. Я хотел бы иметь возможность прочитать конфигурацию из файла, чтобы иметь возможность воспроизвести обучение.

Перед реализацией пользовательской метрики в функции я мог бы просто сделать это, как показано ниже, без mean_pred. Сейчас я сталкиваюсь с проблемой TypeError: Object of type 'function' is not JSON serializable.

Здесь Я прочитал, что можно получить имя функции в виде строки с помощью custom_metric_name = mean_pred.__name__. Я хотел бы не только сохранить имя, но и по возможности сохранить ссылку на функцию.

Возможно, мне следует, как уже упоминалось, здесь также подумать не только о сохранении моей конфигурации в файле .py, но и об использовании ConfigObj. Если это не решит мою текущую проблему, я реализую это позже.

Минимальный рабочий пример проблемы:

import keras.backend as K
import json

def mean_pred(y_true, y_pred):
    return K.mean(y_pred)

config = {'epochs':500,
          'loss':{'class':'categorical_crossentropy'},
          'optimizer':'Adam',
          'metrics':{'class':['accuracy', mean_pred]}
          }

# Do the training etc...

config_filename = 'config.txt'
with open(config_filename, 'w') as f:
    f.write(json.dumps(config))

Огромная благодарность за помощь в решении этой проблемы, а также за другие подходы к сохранению моей конфигурации наилучшим образом.

1 Ответ

0 голосов
/ 26 апреля 2019

Чтобы решить мою проблему, я сохранил имя функции в виде строки в файле конфигурации, а затем извлек функцию из словаря, чтобы использовать ее в качестве метрики в модели. Можно также использовать: 'class':['accuracy', mean_pred.__name__], чтобы сохранить имя функции в виде строки в конфигурации. Это также работает для нескольких пользовательских функций и для большего количества ключей к метрикам (например, определение метрик для 'reg', как 'class' при выполнении регрессии и классификации).

import keras.backend as K
import json
from collections import defaultdict

def mean_pred(y_true, y_pred):
    return K.mean(y_pred)


config = {'epochs':500,
          'loss':{'class':'categorical_crossentropy'},
          'optimizer':'Adam',
          'metrics':{'class':['accuracy', 'mean_pred']}
          }


custom_metrics= {'mean_pred':mean_pred}

metrics = defaultdict(list)
for metric_type, metric_functions in config['metrics'].items():
    for function in metric_functions:
        if function in custom_metrics.keys():
            metrics[metric_type].append(custom_metrics[function])
        else:
            metrics[metric_type].append(function)

# Do the training, use metrics

config_filename = 'config.txt'
with open(config_filename, 'w') as f:
    f.write(json.dumps(config))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...