Я пытаюсь реализовать входной конвейер с помощью tf.data.Элементы находятся в матрице, экспортированной из matlab, в то время как метки находятся в других файлах, для чтения которых требуются определенные функции.
Имена файлов, которые должны быть загружены, могут быть вычислены с использованием числа.
Вот как я это реализовал
def load_files(k):
mesh_file = file_path(k, "off", flags.dataset_mesh)
mat_file = file_path(k, "mat", flags.dataset_mat)
mesh = pymesh.load_mesh(mesh_file)
mat = scipy.io.loadmat(mat_file)
return mesh.vertices, mat
def generator_fn():
return (load_files(x) for x in range(1000000 + 1))
def input_fn() -> Dataset:
dataset = tf.data.Dataset.from_generator(generator_fn,
output_types=(tf.as_dtype(tf.float32), tf.as_dtype(tf.float32)), )
dataset = dataset.batch(batch_size=flags.batch_size).repeat()
dataset = dataset.cache()
dataset = dataset.prefetch(buffer_size=flags.prefetch_buffer_size)
return dataset
Проблема в том, что использование графического процессора очень низкое, около 5% (2080 ti).Я не уверен, где узкое место.Я тестирую с простым MLP, но использование графического процессора, кажется, не меняется, несмотря на слои или нейроны для каждого слоя, который я добавляю.
Я выполняю тренинг следующим образом:
model = keras.Sequential([
keras.layers.Flatten(input_shape=(n_input,)),
keras.layers.Dense(1024, activation=tf.nn.relu),
.
.
.
keras.layers.Dense(1024, activation=tf.nn.relu),
keras.layers.Dense(n_output, activation=None)
])
model.compile(optimizer='adam', loss='mean_squared_error')
model.fit(input_fn().make_one_shot_iterator(), steps_per_epoch=1000000, epochs=1)
Итак, я думаю, что проблема может заключаться в том, как я передаю данные (проблема не должна быть только чтением файла, так как я на SSD NVMe), о том, как я тренируюсь, или о том, что это просто простая сеть, несмотря на добавленные мной слои.
Однако я бы хотелчтобы узнать, есть ли более эффективный способ подачи данных.
Я использую tensorflow-gpu 2.0.0a0
, я запустил эталонный тест от lambda-labs и смог использовать gpu при 100%