Поскольку мне нужно обучить мою модель на Imag eNet с 1024 пакетными размерами, я должен использовать более 2 графических процессоров, поэтому я использую tf.distribute.MirroredStrategy () для trian, эта стратегия может использовать только fit (), а не fit_generator (). Однако Imag eNet имеет большой объем данных, я использую tenorflow.keras.preprocessing.image.ImageDataGenerator.flow_from_directory () для обработки данных ,, но это делает некоторые ошибки, мой код приведен ниже: (я опустил некоторые неважные функции)
import os, sys, argparse
import numpy as np
from multiprocessing import cpu_count
from multi_gpu import ParallelModel
import tensorflow.keras.backend as K
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau, LearningRateScheduler, TerminateOnNaN
import tensorflow as tf
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
num_epochs = 5
batch_size_per_replica = 256
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: %d' % strategy.num_replicas_in_sync) # 输出设备数量
batch_size_all = batch_size_per_replica * strategy.num_replicas_in_sync
def preprocess(x):
x = np.expand_dims(x, axis=0)
x /= 255.0
x -= 0.5
x *= 2.0
return x
def main(args):
log_dir = args.log_dir#'logs/'
checkpoint = ModelCheckpoint(args.model_save_dir + 'ep{epoch:03d}-val_loss{val_loss:.3f}-val_acc{val_acc:.3f}-val_top_k_categorical_accuracy{val_top_k_categorical_accuracy:.3f}.h5',
monitor='val_acc',
mode='max',
verbose=1,
save_weights_only=False,
save_best_only=True,
period=1)
logging = TensorBoard(log_dir=args.model_save_dir, histogram_freq=0, write_graph=False, write_grads=False, write_images=False, update_freq='batch')
terminate_on_nan = TerminateOnNaN()
learn_rates = [0.05, 0.01, 0.005, 0.001, 0.0005, 0.0001]
lr_scheduler = LearningRateScheduler(lambda epoch: learn_rates[epoch // 30])
def make_train_generator():
train_datagen = ImageDataGenerator(preprocessing_function=preprocess,
zoom_range=0.25,
#shear_range=0.2,
#channel_shift_range=0.1,
#rotation_range=0.1,
width_shift_range=0.05,
height_shift_range=0.05,
horizontal_flip=True)
train_generator = train_datagen.flow_from_directory(
args.train_data_path,
target_size=(224, 224),
batch_size=args.batch_size)
total_train_samples = train_generator.samples
return train_generator
train_dataset = tf.data.Dataset.from_generator(make_train_generator, output_types=tf.float32, output_shapes=tf.TensorShape([224,224]))
def make_test_generator():
test_datagen = ImageDataGenerator(preprocessing_function=preprocess)
test_generator = test_datagen.flow_from_directory(
args.val_data_path,
target_size=(224, 224),
batch_size=args.batch_size)
total_test_samples = test_generator.samples
return test_generator
test_dataset = tf.data.Dataset.from_generator(make_test_generator, output_types=tf.float32, output_shapes=tf.TensorShape([224,224]))
optimizer = get_optimizer(args.optim_type, args.learning_rate)
with strategy.scope():
# prepare model
model = get_model(args.model_type)
model.compile(
optimizer=optimizer,
metrics=['accuracy', 'top_k_categorical_accuracy'],
loss='categorical_crossentropy')
model.summary()
print('Train on {} samples, val on {} samples, with batch size {}.'.format(train_generator.samples, test_generator.samples, args.batch_size))
model.fit(
train_dataset,
batch_size=None,
epochs=args.total_epoch,
initial_epoch=args.init_epoch,
validation_data=test_dataset,
callbacks=[logging, checkpoint, lr_scheduler, terminate_on_nan])
# Finally store model
model.save(log_dir + 'trained_final.h5')
но это дает мне странную ошибку:
Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 1352, in _do_call
return fn(*args)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 1337, in _run_fn
target_list, run_metadata)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 1430, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: Cannot use rebatching fallback when 0th dimensions of dataset components are not fully known. Component 0 has shape: dim { size: -1 } dim { size: -1 } dim { size: -1 } dim { size: -1 }
[[{{node ExperimentalRebatchDataset_2_1}}]]
[[MultiDeviceIteratorInit_1/_4593]]
(1) Invalid argument: Cannot use rebatching fallback when 0th dimensions of dataset components are not fully known. Component 0 has shape: dim { size: -1 } dim { size: -1 } dim { size: -1 } dim { size: -1 }
[[{{node ExperimentalRebatchDataset_2_1}}]]
0 successful operations.
1 derived errors ignored.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "yolo3/models/backbones/imagenet_training/train_imagenet.py", line 242, in
main(args)
File "yolo3/models/backbones/imagenet_training/train_imagenet.py", line 200, in main
callbacks=[logging, checkpoint, lr_scheduler, terminate_on_nan])
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py", line 711, in fit
use_multiprocessing=use_multiprocessing)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_distributed.py", line 681, in fit
steps_name='steps_per_epoch')
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 203, in model_iteration
val_iterator = _get_iterator(val_inputs, model._distribution_strategy)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 546, in _get_iterator
inputs, distribution_strategy)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/distribute/distributed_training_utils.py", line 546, in get_iterator
initialize_iterator(iterator, distribution_strategy)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/distribute/distributed_training_utils.py", line 554, in initialize_iterator
K.get_session((init_op,)).run(init_op)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 945, in run
run_metadata_ptr)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 1168, in _run
feed_dict_tensor, options, run_metadata)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 1346, in _do_run
run_metadata)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 1371, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: Cannot use rebatching fallback when 0th dimensions of dataset components are not fully known. Component 0 has shape: dim { size: -1 } dim { size: -1 } dim { size: -1 } dim { size: -1 }
[[node ExperimentalRebatchDataset_2_1 (defined at /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py:1657) ]]
[[MultiDeviceIteratorInit_1/_4593]]
(1) Invalid argument: Cannot use rebatching fallback when 0th dimensions of dataset components are not fully known. Component 0 has shape: dim { size: -1 } dim { size: -1 } dim { size: -1 } dim { size: -1 }
[[node ExperimentalRebatchDataset_2_1 (defined at /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py:1657) ]]
0 successful operations.
1 derived errors ignored.
Original stack trace for 'ExperimentalRebatchDataset_2_1':
File "yolo3/models/backbones/imagenet_training/train_imagenet.py", line 242, in
main(args)
File "yolo3/models/backbones/imagenet_training/train_imagenet.py", line 200, in main
callbacks=[logging, checkpoint, lr_scheduler, terminate_on_nan])
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py", line 711, in fit
use_multiprocessing=use_multiprocessing)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_distributed.py", line 681, in fit
steps_name='steps_per_epoch')
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 203, in model_iteration
val_iterator = _get_iterator(val_inputs, model._distribution_strategy)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 546, in _get_iterator
inputs, distribution_strategy)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/distribute/distributed_training_utils.py", line 545, in get_iterator
iterator = distribution_strategy.make_dataset_iterator(dataset)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/distribute/distribute_lib.py", line 978, in make_dataset_iterator
return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/distribute/mirrored_strategy.py", line 522, in _make_dataset_iterator
split_batch_by=self._num_replicas_in_sync)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/distribute/input_lib.py", line 757, in __init__
input_context=input_context)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/distribute/input_lib.py", line 553, in __init__
input_context=input_context)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/distribute/input_lib.py", line 511, in __init__
cloned_dataset = input_ops._clone_dataset(dataset) # pylint: disable=protected-access
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/distribute/input_ops.py", line 57, in _clone_dataset
remap_dict = _clone_helper(dataset._variant_tensor.op, variant_tensor_ops)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/distribute/input_ops.py", line 97, in _clone_helper
op_def=_get_op_def(op_to_clone))
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py", line 3261, in create_op
attrs, op_def, compute_device)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py", line 3330, in _create_op_internal
op_def=op_def)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py", line 1657, in __init__
self._traceback = tf_stack.extract_stack()
Я не знаю, что вызвало ошибку. Я не нашел соответствующего решения в сети. Я надеюсь, что кто-то может помочь мне решить это, большое спасибо!