Я использую TF версии 1.12 с conda и python 3. Мой вопрос касается значения model_dir в tf.contrib.factorization.KMeansClustering: как использовать заполнитель строк для значения model_dir?
Здеськонтекст: я предварительно обучил KMeans в другой ситуации, контрольные точки находятся в разных model_dir.
Я хочу использовать предсказания этих предварительно обученных моделей внутри графика, в зависимости от каждой ситуации, поэтому мне нужно поместить в этот графикKMeansClustering, который может принимать разные model_dirs.
На графике, который я определил:
ckpt_ph = tf.placeholder(tf.string)
...
kmeans = KMeansClustering(5, model_dir=ckpt_ph,distance_metric='cosine')
def input_fn():
return tf.train.limit_epochs(tf.convert_to_tensor(x, dtype=tf.float32), num_epochs=1)
centers_idx = list(kmeans.predict(input_fn,predict_keys='cluster_index',checkpoint_path=ckpt_ph,yield_single_examples=False))[0]['cluster_index']
centers_val = kmeans.cluster_centers()
...
И я запускаю его с помощью:
...
for ind in range(nb_cases):
...
sess.run([...], feed_dict={..., ckpt_ph: km_ckpt[ind]})
...
Где km_ckpt - список предварительно обученныхПуть к контрольным точкам KMeansClustering, который я хочу использовать для каждой ситуации.
Я получаю ошибку:
Traceback (most recent call last):
File "main.py", line 28, in <module>
tf.app.run()
File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\platform\app.py", line 125, in run
_sys.exit(main(argv))
File "main.py", line 23, in main
launch_training()
File "main.py", line 14, in launch_training
train_mnist.train_model()
File "C:\Users\Denis\ML\ScatteringReconstruction\src\model\train_mnist.py", line 355, in train_model
X_r = SR(X_tensor)
File "C:\Users\Denis\ML\ScatteringReconstruction\src\model\train_mnist.py", line 316, in __call__
kmeans = KMeansClustering(FLAGS.km_k, model_dir=ckpt_ph, distance_metric='cosine')
File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\factorization\python\ops\kmeans.py", line 423, in __init__
config=config)
File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 189, in __init__
model_dir)
File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 1665, in maybe_overwrite_model_dir_and_session_config
if model_dir:
File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 671, in __bool__
raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.
Мне кажется, что проблема в том, что в KMeansClustering и KMeansClustering.predict, model_dirожидая Python bool или string, и я даю ему Тензор, но потом я не вижу, чтобы hos использовала предварительно обученные KMeansКластеризация внутри графа.Заранее спасибо за помощь!