Я пытаюсь создать несколько противоборствующих примеров для случайного леса, добавив немного шума для проверки изображений и проверки, не обманывают ли они модель.Я сталкиваюсь с ошибкой при попытке найти градиент функции потерь (который я определил как функцию кросс-энтропии), как показано ниже.Я вставил сообщение об ошибке ниже.Есть идеи, что я делаю не так?
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=False)
mnist1 = input_data.read_data_sets("/tmp/data/", one_hot=True)
# Parameters
num_steps = 500 # Total steps to train
batch_size = 1024 # The number of samples per batch
num_classes = 10 # The 10 digits
num_features = 784 # Each image is 28x28 pixels
num_trees = 10
max_nodes = 1000
tf.reset_default_graph()
# Input and Target data
X = tf.placeholder(tf.float32, shape=[None, num_features])
# For random forest, labels must be integers (the class id)
Y = tf.placeholder(tf.float64, shape=[None])
Y1 = tf.placeholder(tf.float32, [None, 10]) # 0-9 digits recognition => 10 classes
# Random Forest Parameters
hparams = tensor_forest.ForestHParams(num_classes=num_classes,
num_features=num_features,
num_trees=num_trees,
max_nodes=max_nodes).fill()
# Build the Random Forest
forest_graph = tensor_forest.RandomForestGraphs(hparams)
# Get training graph and loss
train_op = forest_graph.training_graph(X, Y)
loss_op = forest_graph.training_loss(X, Y)
# Measure the accuracy
infer_op, _, _ = forest_graph.inference_graph(X)
correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
y_pred = tf.argmax(infer_op, 1)
y_pred1 = tf.nn.softmax(infer_op)
# Minimize error using cross entropy
loss = tf.reduce_mean(-tf.reduce_sum(Y1*tf.log(y_pred1)))
# loss = tf.reduce_mean(tf.square( tf.cast(y_pred, tf.float64) - Y))
grad_input = tf.gradients(loss, X)[0]
# Initialize the variables (i.e. assign their default value) and forest resources
init_vars = tf.group(tf.global_variables_initializer(),
resources.initialize_resources(resources.shared_resources()))
with tf.Session() as sess:
sess.run(init_vars)
#Training
for i in range(1, num_steps + 1):
# Prepare Data
# Get the next batch of MNIST data (only images are needed, not labels)
batch_x, batch_y = mnist.train.next_batch(batch_size)
_, l = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})
if i % 50 == 0 or i == 1:
acc = sess.run(accuracy_op, feed_dict={X: batch_x, Y: batch_y})
print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))
# Test Model
test_x, test_y = mnist.test.images, mnist.test.labels
print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y}))
# infer,correct_pred = sess.run([infer_op,correct_prediction], feed_dict={X: mnist.train.images, Y: mnist.train.labels})
#prediction for one image
y_prediction = sess.run(y_pred, feed_dict={X: mnist.test.images[12:13],Y: mnist.test.labels[12:13]})
#adversial
grad = sess.run([grad_input], feed_dict={X:mnist.test.images, Y1:mnist1.test.labels})
Сообщение об ошибке выглядит следующим образом:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-3-3fd00c3b2207> in <module>()
22
23 #adversial
---> 24 grad = sess.run([grad_input], feed_dict={X:mnist.test.images, Y1:mnist1.test.labels})
25 # grad = np.array(grad)
26 # grad = grad.reshape(10000,784)
/home/black-book/anaconda3/envs/py2tensor/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
898 try:
899 result = self._run(None, fetches, feed_dict, options_ptr,
--> 900 run_metadata_ptr)
901 if run_metadata:
902 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/home/black-book/anaconda3/envs/py2tensor/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
1118 # Create a fetch handler to take care of the structure of fetches.
1119 fetch_handler = _FetchHandler(
-> 1120 self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
1121
1122 # Run request and get response.
/home/black-book/anaconda3/envs/py2tensor/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in __init__(self, graph, fetches, feeds, feed_handles)
425 """
426 with graph.as_default():
--> 427 self._fetch_mapper = _FetchMapper.for_fetch(fetches)
428 self._fetches = []
429 self._targets = []
/home/black-book/anaconda3/envs/py2tensor/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in for_fetch(fetch)
243 elif isinstance(fetch, (list, tuple)):
244 # NOTE(touts): This is also the code path for namedtuples.
--> 245 return _ListFetchMapper(fetch)
246 elif isinstance(fetch, dict):
247 return _DictFetchMapper(fetch)
/home/black-book/anaconda3/envs/py2tensor/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in __init__(self, fetches)
350 """
351 self._fetch_type = type(fetches)
--> 352 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
353 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
354
/home/black-book/anaconda3/envs/py2tensor/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in for_fetch(fetch)
240 if fetch is None:
241 raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
--> 242 type(fetch)))
243 elif isinstance(fetch, (list, tuple)):
244 # NOTE(touts): This is also the code path for namedtuples.
TypeError: Fetch argument None has invalid type <type 'NoneType'>