Я пытаюсь обучить модель, определенную здесь , используя функцию train
, определенную в том же скрипте. Я попытался использовать файл записи tf, предоставленный авторами в по этой ссылке (для этого единственными изменениями, которые я сделал, были изменение init_path=None
и запись пути к файлу tfrecord в файле txt train_list=train_tfs.txt
, оба определены в shift_params.py
shift_v1
).
Однако, даже попробовав этот простой тест, я получаю следующую ошибку:
OutOfRangeError (see above for traceback): RandomShuffleQueue '_1_shuffle_batch_join/random_shuffle_queue' is closed and has insufficient elements (requested 15, current size 0)
[[Node: shuffle_batch_join = QueueDequeueManyV2[component_types=[DT_UINT8, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64, DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](shuffle_batch_join/random_shuffle_queue, shuffle_batch_join/n)]]
Насколько я поймите, эта ошибка означает, что данные не загружаются. Итак, я предполагаю, что проблема заключается в функции read_example
в shift_dset.py
, как показано ниже (я сохранил только часть кода, которая загружает изображения ims
и звук sound
. Для полного кода , пожалуйста, отметьте здесь .)
def read_example(rec_queue, pr, input_types):
reader = tf.TFRecordReader()
k, serialized_example = reader.read(rec_queue)
full = pr.full_im_dim
feats = {}
feats['im_0'] = tf.FixedLenFeature([], dtype=tf.string)
feats['im_1'] = tf.FixedLenFeature([], dtype=tf.string)
feats['sound'] = tf.FixedLenFeature([], dtype=tf.string)
if pr.variable_frame_count:
feats['num_frames'] = tf.FixedLenFeature([1], dtype=tf.int64)
example = tf.parse_single_example(serialized_example, features = feats)
total_frames = pr.total_frames if not pr.variable_frame_count \
else tf.cast(example['num_frames'][0], tf.int32)
assert not pr.variable_frame_count
def f(x):
x.set_shape((full*total_frames/2, full, 3))
return tf.reshape(x, (total_frames/2, full, full, 3))
im_parts = map(f, [tf.image.decode_jpeg(example['im_0'], channels = 3, name = 'decode_im1'),
tf.image.decode_jpeg(example['im_1'], channels = 3, name = 'decode_im2')])
samples = tf.decode_raw(example['sound'], tf.int16)
samples.set_shape((pr.full_samples_len*2))
samples = tf.reshape(samples, (pr.full_samples_len, 2))
samples = tf.cast(samples, 'float32') / np.iinfo(np.dtype(np.int16)).max
num_slice_frames = pr.sampled_frames
num_samples = int(pr.samples_per_frame * float(num_slice_frames))
if pr.do_shift:
choices = []
max_frame = total_frames - num_slice_frames
frames1 = ([0] if pr.fix_frame else xrange(max_frame))
for frame1 in frames1:
found = False
for frame2 in reversed(range(max_frame)):
inv1 = xrange(frame1, frame1 + num_slice_frames)
inv2 = xrange(frame2, frame2 + num_slice_frames)
if len(set(inv1).intersection(inv2)) <= pr.max_intersection:
found = True
choices.append([frame1, frame2])
if pr.fix_frame:
break
if pr.skip_notfound:
pass
else:
assert found
print 'Number of frame choices:', len(choices)
choices = tf.constant(np.array(choices), dtype = tf.int32)
idx = tf.random_uniform([1], 0, shape(choices, 0), dtype = tf.int64)[0]
start_frame_gt = choices[idx, 0]
shift_frame = choices[idx, 1]
elif ut.hastrue(pr, 'use_first_frame'):
shift_frame = start_frame_gt = tf.constant(0, dtype = tf.int32)
else:
shift_frame = start_frame_gt = tf.random_uniform(
[1], 0, total_frames - num_slice_frames, dtype = tf.int32)[0]
if pr.augment_ims:
print 'Augment:', pr.augment_ims
r = tf.random_uniform(
[2], 0, pr.full_im_dim - pr.crop_im_dim, dtype = tf.int32)
x, y = r[0], r[1]
else:
if hasattr(pr, 'resize_dims'):
y = pr.resize_dims[0]/2 - pr.crop_im_dim/2
x = pr.resize_dims[1]/2 - pr.crop_im_dim/2
else:
y = x = pr.full_im_dim/2 - pr.crop_im_dim/2
offset = [start_frame_gt, y, x, 0]
d = pr.crop_im_dim
size_im = [num_slice_frames, d, d, 3]
slice_parts = []
for j in xrange(len(im_parts)):
num_frames_in_part = total_frames/len(im_parts)
part_start = j*num_frames_in_part
frame_offset = tf.maximum(0, tf.minimum(start_frame_gt - part_start, num_frames_in_part))
end_offset = tf.maximum(0, tf.minimum(start_frame_gt + num_slice_frames - part_start, num_frames_in_part))
num_frames_in_part_slice = tf.maximum(0, end_offset - frame_offset)
offset = [frame_offset, y, x, 0]
d = pr.crop_im_dim
size_im = [num_frames_in_part_slice, d, d, 3]
p = tf.slice(im_parts[j], offset, size_im)
slice_parts.append(p)
ims_slice = tf.concat(slice_parts, 0)
ims_slice.set_shape([num_slice_frames, pr.crop_im_dim, pr.crop_im_dim, 3])
ims = ims_slice
if pr.augment_ims:
ims = tf.cond(tf.cast(tf.random_uniform([1], 0, 2, dtype = tf.int64)[0], tf.bool),
lambda : tf.map_fn(tf.image.flip_left_right, ims),
lambda : ims)
def slice_samples(frame):
start = round_int(pr.samples_per_frame * cast_float(frame))
offset = [start, 0]
size = [num_samples, 2]
r = tf.slice(samples, offset, size, name = 'slice_sample')
r.set_shape([num_samples] + list(shape(r)[1:]))
return r
if 'samples' in input_types:
samples_gt = slice_samples(start_frame_gt)
samples_shift = slice_samples(shift_frame)
else:
samples_gt = samples_shift = tf.zeros((1, 1), dtype = tf.int16)
samples_exs = tf.concat([ed(samples_shift, 0), ed(samples_gt, 0)], 0)
return ims, _, samples_exs, _, _, _
Я также пытался загрузить свои собственные tfrecords, изменив эту функцию, и получил аналогичную ошибку. Кроме того, когда не указаны пути tfrecord (т.е. train_list=train_tfs.txt
не содержит пути к tfrecord, поэтому невозможно что-либо загрузить), я получаю ту же ошибку, которая указывает на то, что что-то не так с тем, как данные загружается.
Заранее благодарим вас за любую помощь.
Как воспроизвести тот же тест
- Клонировать код в здесь
- Загрузите образец данных в здесь
- Измените параметры в
shift_v1
из shift_params.py
: init_path=None
и содержимое train_tfs.txt
(он должен содержать путь к загруженному tf-файлу) Из папки src
выполните следующую команду:
python - c "import shift_params , shift_net; shift _net .train (shift_params.shift_v1 (num_gpus = 3), [0, 1, 2], restore = False) "
Сведения о среде
Я использую следующие пакеты в среде anaconda (Linux):
- Python 2.7.18
- tensorflow, t Ensorflow-gpu и tenorflow-base 1.9.0
- numpy, matplotlib, подушка и scipy