Я пытаюсь сделать классификатор, и я продолжаю получать эту ошибку, которая действительно смущает меня.Поскольку я действительно новичок в машинном обучении, я ничего не могу найти в интернете для этого.
ОШИБКА
AssertionError: Incoming Tensor shape must be 4-D
Код
IMG_SIZE = 64
tf.reset_default_graph()
convnet = input_data(shape=[1,IMG_SIZE,IMG_SIZE,1],name='input')
convnet = conv_2d(convnet, 32, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)
convnet = conv_2d(convnet, 64, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)
convnet = conv_2d(convnet, 128, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)
convnet = conv_2d(convnet, 64, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)
convnet = conv_2d(convnet, 32, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)
convnet = fully_connected(convnet, 1024, activation='relu')
convnet = dropout(convnet, 0.8)
convnet = fully_connected(convnet, 2, activation='softmax')
convnet = regression(convnet, optimizer='adam', learning_rate=LR, loss='categorical_crossentropy', name='targets')
model = tflearn.DNN(convnet, tensorboard_dir='log', tensorboard_verbose=0)
model.fit({'input': X_train}, {'targets': y_train}, n_epoch=10,
validation_set=({'input': X_test}, {'targets': y_test}),
snapshot_step=500, show_metric=True, run_id=MODEL_NAME)
если я дам convnet = input_data(shape=[None,IMG_SIZE,IMG_SIZE,1],name='input')
, это выдаст мне эту ошибку
Exception in thread Thread-3:
Traceback (most recent call last):
File "C:\Users\zeele\Miniconda3\lib\threading.py", line 916, in _bootstrap_inner
self.run()
File "C:\Users\zeele\Miniconda3\lib\threading.py", line 864, in run
self._target(*self._args, **self._kwargs)
File "C:\Users\zeele\Miniconda3\lib\site-packages\tflearn\data_flow.py", line 187, in fill_feed_dict_queue
data = self.retrieve_data(batch_ids)
File "C:\Users\zeele\Miniconda3\lib\site-packages\tflearn\data_flow.py", line 222, in retrieve_data
utils.slice_array(self.feed_dict[key], batch_ids)
File "C:\Users\zeele\Miniconda3\lib\site-packages\tflearn\utils.py", line 187, in slice_array
return X[start]
TypeError: 'generator' object is not subscriptable