Градиент случайной потери леса в тензорном потоке - PullRequest
0 голосов
/ 08 июня 2018

Я пытаюсь создать несколько противоборствующих примеров для случайного леса, добавив немного шума для проверки изображений и проверки, не обманывают ли они модель.Я сталкиваюсь с ошибкой при попытке найти градиент функции потерь (который я определил как функцию кросс-энтропии), как показано ниже.Я вставил сообщение об ошибке ниже.Есть идеи, что я делаю не так?

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=False)
mnist1 = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Parameters
num_steps = 500 # Total steps to train
batch_size = 1024 # The number of samples per batch
num_classes = 10 # The 10 digits
num_features = 784 # Each image is 28x28 pixels
num_trees = 10
max_nodes = 1000

tf.reset_default_graph()
# Input and Target data
X = tf.placeholder(tf.float32, shape=[None, num_features])
# For random forest, labels must be integers (the class id)
Y = tf.placeholder(tf.float64, shape=[None])
Y1 = tf.placeholder(tf.float32, [None, 10]) # 0-9 digits recognition => 10 classes


# Random Forest Parameters
hparams = tensor_forest.ForestHParams(num_classes=num_classes,
                                      num_features=num_features,
                                      num_trees=num_trees,
                                      max_nodes=max_nodes).fill()

# Build the Random Forest
forest_graph = tensor_forest.RandomForestGraphs(hparams)
# Get training graph and loss
train_op = forest_graph.training_graph(X, Y)
loss_op = forest_graph.training_loss(X, Y)

# Measure the accuracy
infer_op, _, _ = forest_graph.inference_graph(X)
correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
y_pred = tf.argmax(infer_op, 1)
y_pred1 = tf.nn.softmax(infer_op)


# Minimize error using cross entropy
loss = tf.reduce_mean(-tf.reduce_sum(Y1*tf.log(y_pred1)))
# loss = tf.reduce_mean(tf.square( tf.cast(y_pred, tf.float64) - Y))
grad_input = tf.gradients(loss, X)[0]

# Initialize the variables (i.e. assign their default value) and forest resources
init_vars = tf.group(tf.global_variables_initializer(),
    resources.initialize_resources(resources.shared_resources()))

with tf.Session() as sess:
    sess.run(init_vars)
    #Training
    for i in range(1, num_steps + 1):
        # Prepare Data
        # Get the next batch of MNIST data (only images are needed, not labels)
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        _, l = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})
        if i % 50 == 0 or i == 1:
            acc = sess.run(accuracy_op, feed_dict={X: batch_x, Y: batch_y})
            print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))

    # Test Model
    test_x, test_y = mnist.test.images, mnist.test.labels
    print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y}))
#     infer,correct_pred = sess.run([infer_op,correct_prediction], feed_dict={X:  mnist.train.images, Y:  mnist.train.labels})
    #prediction for one image
    y_prediction = sess.run(y_pred, feed_dict={X:  mnist.test.images[12:13],Y:  mnist.test.labels[12:13]})

    #adversial 
    grad = sess.run([grad_input], feed_dict={X:mnist.test.images, Y1:mnist1.test.labels})

Сообщение об ошибке выглядит следующим образом:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-3fd00c3b2207> in <module>()
     22 
     23     #adversial
---> 24     grad = sess.run([grad_input], feed_dict={X:mnist.test.images, Y1:mnist1.test.labels})
     25 #     grad = np.array(grad)
     26 #     grad = grad.reshape(10000,784)

/home/black-book/anaconda3/envs/py2tensor/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    898     try:
    899       result = self._run(None, fetches, feed_dict, options_ptr,
--> 900                          run_metadata_ptr)
    901       if run_metadata:
    902         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/home/black-book/anaconda3/envs/py2tensor/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1118     # Create a fetch handler to take care of the structure of fetches.
   1119     fetch_handler = _FetchHandler(
-> 1120         self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
   1121 
   1122     # Run request and get response.

/home/black-book/anaconda3/envs/py2tensor/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in __init__(self, graph, fetches, feeds, feed_handles)
    425     """
    426     with graph.as_default():
--> 427       self._fetch_mapper = _FetchMapper.for_fetch(fetches)
    428     self._fetches = []
    429     self._targets = []

/home/black-book/anaconda3/envs/py2tensor/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in for_fetch(fetch)
    243     elif isinstance(fetch, (list, tuple)):
    244       # NOTE(touts): This is also the code path for namedtuples.
--> 245       return _ListFetchMapper(fetch)
    246     elif isinstance(fetch, dict):
    247       return _DictFetchMapper(fetch)

/home/black-book/anaconda3/envs/py2tensor/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in __init__(self, fetches)
    350     """
    351     self._fetch_type = type(fetches)
--> 352     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
    353     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
    354 

/home/black-book/anaconda3/envs/py2tensor/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in for_fetch(fetch)
    240     if fetch is None:
    241       raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
--> 242                                                                  type(fetch)))
    243     elif isinstance(fetch, (list, tuple)):
    244       # NOTE(touts): This is also the code path for namedtuples.

TypeError: Fetch argument None has invalid type <type 'NoneType'>
...