Я нашел этот код, и он отлично работает. Идея - разделить мои данные и обучить кластеризации KMeans. Поэтому я создаю InitHook и итератор и использую его для обучения.
class _IteratorInitHook(tf.train.SessionRunHook):
"""Hook to initialize data iterator after session is created."""
def __init__(self):
super(_IteratorInitHook, self).__init__()
self.iterator_initializer_fn = None
def after_create_session(self, session, coord):
"""Initialize the iterator after the session has been created."""
del coord
self.iterator_initializer_fn(session)
# Run K-means clustering.
def _get_input_fn():
"""Helper function to create input function and hook for training.
Returns:
input_fn: Input function for k-means Estimator training.
init_hook: Hook used to load data during training.
"""
init_hook = _IteratorInitHook()
def _input_fn():
"""Produces tf.data.Dataset object for k-means training.
Returns:
Tensor with the data for training.
"""
features_placeholder = tf.placeholder(tf.float32,
my_data.shape)
delf_dataset = tf.data.Dataset.from_tensor_slices((features_placeholder))
delf_dataset = delf_dataset.shuffle(1000).batch(
my_data.shape[0])
iterator = delf_dataset.make_initializable_iterator()
def _initializer_fn(sess):
"""Initialize dataset iterator, feed in the data."""
sess.run(
iterator.initializer,
feed_dict={features_placeholder: my_data})
init_hook.iterator_initializer_fn = _initializer_fn
return iterator.get_next()
return _input_fn, init_hook
input_fn, init_hook = _get_input_fn()
output_cluster_dir = 'parameters/clusters'
kmeans = tf.contrib.factorization.KMeansClustering(
num_clusters=1024,
model_dir=output_cluster_dir,
use_mini_batch=False,
)
print('Starting K-means clustering...')
kmeans.train(input_fn, hooks=[init_hook])
Но если я изменю num_clusters на 512 или 256, я получу следующую ошибку:
InvalidArgumentError: сегмент_ids [0] = 600 находится вне диапазона [0, 256)
[[узел UnsortedSegmentSum (определенный в
/home/mikhail/.conda/envs/tf2/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py:1112)
]] [[Сжатие узла (определено в
/home/mikhail/.conda/envs/tf2/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py:1112)
]]
Похоже, у меня есть некоторые проблемы с разделением данных на пакеты ИЛИ мои KMeans по умолчанию используют 1024 кластера, даже если я установил другое значение!
Я не могу понять, что нужно изменить, чтобы заставить его работать правильно.
Трассировка огромна, если нужно, я могу прикрепить ее в виде файла.