Я экспериментирую с моделью кластеризации https://github.com/astirn/IIC (уже пытался связаться с ним по этому поводу)
Он использует Mnist Dataset, как и в большинстве научных работ. Здесь они сначала определяют имя набора данных как «mnist», которого достаточно для того, чтобы тензорный поток импортировал mnist из своих стандартных онлайн-наборов данных. Затем он загружает набор данных с помощью функции tenorflow_dataset.load ()
Я создал файл tfrecord для своего набора данных, и теперь мне просто нужно заменить часть, в которой вышеупомянутый скрипт указывает на «mnist» (строка 1 в код ниже) и вместо этого укажите на мой локальный набор данных.
Должен ли я просто заменить 'mnist' на путь к файлу в первой строке ???
Код из фактического файла модели обучения:
if __name__ == '__main__':
# pick a data set
DATA_SET = 'mnist'
# define splits
DS_CONFIG = {
# mnist data set parameters
'mnist': {
'batch_size': 700,
'num_repeats': 5,
'mdl_input_dims': [24, 24, 1]}
}
# load the data set
TRAIN_SET, TEST_SET, SET_INFO = load(data_set_name=DATA_SET, **DS_CONFIG[DATA_SET])
# configure the common model elements
MDL_CONFIG = {
# mist hyper-parameters
'mnist': {
'num_classes': SET_INFO.features['label'].num_classes,
'learning_rate': 1e-4,
'num_repeats': DS_CONFIG[DATA_SET]['num_repeats'],
'save_dir': None},
}
Код из «файла подготовки данных», где он вызывает набор данных с tenorflor_dataset.load как tfds.load:
def load(data_set_name, **kwargs):
"""
:param data_set_name: data set name--call tfds.list_builders() for options
:return:
train_ds: TensorFlow Dataset object for the training data
test_ds: TensorFlow Dataset object for the testing data
info: data set info object
"""
# get data and its info
ds, info = tfds.load(name=data_set_name, split=tfds.Split.ALL, with_info=True)
спасибо за помощь