Как использовать строковый заполнитель для model_dir в tf.contrib.factorization.KMeansClustering? - PullRequest
0 голосов
/ 05 февраля 2019

Я использую TF версии 1.12 с conda и python 3. Мой вопрос касается значения model_dir в tf.contrib.factorization.KMeansClustering: как использовать заполнитель строк для значения model_dir?

Здеськонтекст: я предварительно обучил KMeans в другой ситуации, контрольные точки находятся в разных model_dir.

Я хочу использовать предсказания этих предварительно обученных моделей внутри графика, в зависимости от каждой ситуации, поэтому мне нужно поместить в этот графикKMeansClustering, который может принимать разные model_dirs.

На графике, который я определил:

ckpt_ph = tf.placeholder(tf.string)
...
kmeans = KMeansClustering(5, model_dir=ckpt_ph,distance_metric='cosine')
def input_fn():
    return tf.train.limit_epochs(tf.convert_to_tensor(x, dtype=tf.float32), num_epochs=1)
centers_idx = list(kmeans.predict(input_fn,predict_keys='cluster_index',checkpoint_path=ckpt_ph,yield_single_examples=False))[0]['cluster_index']
centers_val = kmeans.cluster_centers()
...

И я запускаю его с помощью:

...
for ind in range(nb_cases):
    ...
    sess.run([...], feed_dict={..., ckpt_ph: km_ckpt[ind]})
...

Где km_ckpt - список предварительно обученныхПуть к контрольным точкам KMeansClustering, который я хочу использовать для каждой ситуации.

Я получаю ошибку:

Traceback (most recent call last):
  File "main.py", line 28, in <module>
    tf.app.run()
  File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\platform\app.py", line 125, in run
    _sys.exit(main(argv))
  File "main.py", line 23, in main
    launch_training()
  File "main.py", line 14, in launch_training
    train_mnist.train_model()
  File "C:\Users\Denis\ML\ScatteringReconstruction\src\model\train_mnist.py", line 355, in train_model
    X_r = SR(X_tensor)
  File "C:\Users\Denis\ML\ScatteringReconstruction\src\model\train_mnist.py", line 316, in __call__
    kmeans = KMeansClustering(FLAGS.km_k, model_dir=ckpt_ph, distance_metric='cosine')
  File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\factorization\python\ops\kmeans.py", line 423, in __init__
    config=config)
  File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 189, in __init__
    model_dir)
  File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 1665, in maybe_overwrite_model_dir_and_session_config
    if model_dir:
  File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 671, in __bool__
    raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

Мне кажется, что проблема в том, что в KMeansClustering и KMeansClustering.predict, model_dirожидая Python bool или string, и я даю ему Тензор, но потом я не вижу, чтобы hos использовала предварительно обученные KMeansКластеризация внутри графа.Заранее спасибо за помощь!

...