Я создаю программу, которая дает обучающим изображениям фигур, а затем набор тестовых данных фигур, чтобы увидеть, насколько точным является машинное обучение. Моя проблема исходит из части программы поезда, где я продолжаю получать ошибку индекса для моего кода:
import cv2 # working with, mainly resizing, images
import numpy as np # dealing with arrays
import os # dealing with directories
from random import shuffle # mixing up or currently ordered data that might lead our network astray in training.
from tqdm import tqdm
TRAIN_DIR = '/Users/a/desktop/KerasShapes'
TEST_DIR = '/Users/a/desktop/test shapes'
IMG_SIZE = 50
LR = 1e-3
MODEL_NAME = 'square/triangle/-{}-{}.model'.format(LR, '3conv-basic') # just so we remember which saved model is which, sizes must match
def label_img(img):
word_label = img.split('.')[-2]
# conversion to one-hot array [square,triangle,star,optagon,heptagon,circle]
#
if word_label == 'square': return ([1,0,0,0,0,0])
#
elif word_label == 'triangle': return ([0,1,0,0,0,0])
#
elif word_label == 'star': return ([0,0,1,0,0,0])
#
elif word_label == 'optagon': return ([0,0,0,1,0,0])
#
elif word_label == 'heptagon': return ([0,0,0,0,1,0])
#
`elif` word_label == 'circle': return ([0,0,0,0,0,1])
def create_train_data():
training_data = []
for img in tqdm(os.listdir(TRAIN_DIR)):
label = label_img(img)
path = os.path.join(TRAIN_DIR,img)
img = cv2.imread(path,cv2.IMREAD_GRAYSCALE)
img = cv2.resize(img, (IMG_SIZE,IMG_SIZE))
training_data.append([np.array(img),np.array(label)])
shuffle(training_data)
np.save('train_data.npy', training_data)
return training_data
def process_test_data():
testing_data = []
for img in tqdm(os.listdir(TEST_DIR)):
path = os.path.join(TEST_DIR,img)
img_num = img.split('.')[0]
img = cv2.imread(path,cv2.IMREAD_GRAYSCALE)
img = cv2.resize(img, (IMG_SIZE,IMG_SIZE))
testing_data.append([np.array(img), img_num])
shuffle(testing_data)
np.save('test_data.npy', testing_data)
return testing_data
train_data = create_train_data()
информация отладчика Когда я запускаю код, я получаю IndexError: индекс списка вне допустимого диапазона в отладчике .
PS заранее благодарю вас за любую помощь, которую вы можете оказать. Я очень новичок в python.