Есть ли способ передать argparse.ArgumentParser()
.map()
из tf.data.Dataset.from_tensor_slices()
. Поскольку from_tensor_slices()
ожидает тензоры в качестве аргументов, мне было интересно, сможем ли мы получить доступ к аргументам этого argparse.Namespace class
в map()
.
Мой код выглядит следующим образом:
import tensorflow as tf
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--arg1", help="Help text 1")
parser.add_argument("--arg2", help="Help text 2")
args = parser.parse_args()
image_paths, labels = load_base_data(...)
epoch_size = len(image_paths)
image_paths = tf.convert_to_tensor(image_paths, dtype=tf.string)
labels = tf.convert_to_tensor(labels)
dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels, args))
def map_fn(path, label, args):
# path/label represent values for a single example
image = tf.image.decode_jpeg(tf.read_file(path))
# some mapping to constant size - be careful with distorting aspec ratios
image = tf.image.resize_images(out_shape)
# color normalization - just an example
image = tf.to_float(image) * (2. / 255) - 1
return image, label
# num_parallel_calls > 1 induces intra-batch shuffling
dataset = dataset.map(map_fn, num_parallel_calls=8)
Когда я запускаю следующий код, я получаю следующую ошибку:
TypeError: Failed to convert object of type <class 'argparse.Namespace'> to Tensor. Contents: Namespace(arg1=value, arg2=value, n). Consider casting elements to a supported type.