tf.decode_csv - вопрос об аргументе - PullRequest
0 голосов
/ 25 июня 2019

Я работаю с моделью классификации изображений с использованием TensorFlow и проверяю весь код, чтобы убедиться, что я его понимаю;Я понимаю все это, кроме одной части входной функции.

Внутри входной функции CSV-файл (имена файлов данных Train / Eval) превращается в два тензора, по одному для каждого столбца;и сами изображения преобразуются в двоичные данные.

В родительской функции make_input_fn csv_row не является аргументом.Внутри этой родительской функции находится _input_fn, а внутри нее, в свою очередь, вложена функция decode_csv.

Так что я не понимаю, вот что: csv_row не является аргументом в make_input_fn, но это аргумент decode_csvфункция;как код узнает - из-за отсутствия лучшего способа выразить его - что такое csv_row?

Я видел похожий код, используемый в других местах, поэтому я знаю, что он правильный, но я просто хочу понять, как это работает.

Любая помощь высоко ценится.


def make_input_fn(csv_of_filenames, batch_size, mode, augment = False):
    def _input_fn():
        def decode_csv(csv_row):
            filename, label = tf.decode_csv(records = csv_row, record_defaults = [[""],[""]])
            image_bytes = tf.read_file(filename = filename)
            return image_bytes, label

        # Create tf.data.dataset from filename
        dataset = tf.data.TextLineDataset(filenames = csv_of_filenames).map(map_func = decode_csv)     

        if augment: 
            dataset = dataset.map(map_func = read_and_preprocess_with_augment)
        else:
            dataset = dataset.map(map_func = read_and_preprocess)

        if mode == tf.estimator.ModeKeys.TRAIN:
            num_epochs = None # indefinitely
            dataset = dataset.shuffle(buffer_size = 10 * batch_size)
        else:
            num_epochs = 1 # end-of-input after this

        dataset = dataset.repeat(count = num_epochs).batch(batch_size = batch_size)
        return dataset.make_one_shot_iterator().get_next()
    return _input_fn
...