разбор cifar-10 и изменение imgs на HSV - PullRequest
0 голосов
/ 19 февраля 2020

Я читаю набор данных cifar-10 с этим кодом и хочу найти способ изменить его HSV. Я попытался поместить списки данных и меток вне функций, но я получил эту ошибку: UnboundLocalError: local variable 'data' referenced before assignment

Как я могу вынуть эти списки, чтобы я мог перейти на HSV, а затем на гистограмму набора данных.

import pickle
import numpy as np
from os.path import join
from os import listdir
import matplotlib.pyplot as plt
from tqdm import tqdm
import struct as st

class DataReader:

    def __init__(self,root_dir,type='cifar-100'):
        self.root_dir = root_dir
        self.type = type

    def get_dict_from_pickle(self):
            self.train_dict = unpickle(join(self.root_dir,'train'))
            self.test_dict = unpickle(join(self.root_dir,'test'))

    def get_train_data(self):
        if self.type == 'cifar-100':
            self.get_dict_from_pickle()
            data = np.array(self.train_dict[b'data'])
            lbls_sub = np.array(self.train_dict[b'fine_labels'])
            lbls_class = np.array(self.train_dict[b'coarse_labels'])
            return data,lbls_class,lbls_sub
        elif self.type == 'cifar-10':
            #data = []
            #labels = []
            print("Reading")
            for file_ in tqdm(listdir(self.root_dir)):
                if file_.split('_')[0] == 'data':
                    dict = unpickle(join(self.root_dir,file_))
                    data.extend(dict[b'data'])
                    labels.extend(dict[b'labels'])

            return np.array(data),np.array(labels),None
        elif self.type =='mnist':
            return self.read_mnist()

    def get_test_data(self):
        if self.type == 'cifar-100':
            self.get_dict_from_pickle()
            data = np.array(self.test_dict[b'data'])
            lbls_sub = np.array(self.test_dict[b'fine_labels'])
            lbls_class = np.array(self.test_dict[b'coarse_labels'])
            return data,lbls_class,lbls_sub
        elif self.type == 'cifar-10':
            data = np.empty(shape=(0,3072))
            labels = []
            for file_ in listdir(self.root_dir):
                if file_.split('_')[0] == 'test':
                    dict = unpickle(join(self.root_dir,file_))
                    data = np.vstack((data,dict[b'data']))
                    print(data[data.shape[0]-1])
                    labels.append(dict[b'labels'])
            return np.array(data),np.array(labels),None

    def reshape_to_plot(self,data):
        if self.type == 'mnist':
            return data.reshape(data.shape[0],28,28).astype("uint8")
        return data.reshape(data.shape[0],3,32,32).transpose(0,2,3,1).astype("uint8")

    def plot_imgs(self,in_data,n,random=False):
        data = np.array([d for d in in_data])
        data = self.reshape_to_plot(data)
        x1 = min(n//2,5)
        if x1 == 0:
            x1 = 1
        y1 = (n//x1)
        x = min(x1,y1)
        y = max(x1,y1)
        fig, ax = plt.subplots(x,y,figsize=(5,5))
        i=0
        for j in range(x):
            for k in range(y):
                if random:
                    i = np.random.choice(range(len(data)))
                ax[j][k].set_axis_off()
                ax[j][k].imshow(data[i:i+1][0])
                i+=1
        plt.show()

    def plot_img(self,data):
        if self.type !='mnist':
            assert data.shape == (3072,)
            data = data.reshape(1,3072)
            data = data.reshape(data.shape[0],3,32,32).transpose(0,2,3,1).astype("uint8")
        elif self.type == 'mnist':
            assert data.shape == (28*28,)
            data = data.reshape(1,28,28).astype('uint8')
        fig, ax = plt.subplots(figsize=(5,5))
        ax.imshow(data[0])
        plt.show()

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

1 Ответ

0 голосов
/ 20 февраля 2020

вот что я сделал в конце, и это сработало "" "из keras.datasets import cifar10 import matplotlib.pyplot as plt import cv2

       (x_train, y_train), (x_test, y_test) = cifar10.load_data()
       for i in range(0,50000):
           hsv_image = cv2.cvtColor(x_train[i] , cv2.COLOR_RGB2HSV)
           hue ,  sat ,  val  =  hsv_image [:,:, 0 ],  hsv_image [:,:, 1 ],  
           hsv_image [:,: , 2 ]
       import numpy as np

       plt.figure(figsize=(10,8))
       plt.subplot(311)                             #plot in the first cell
       plt.subplots_adjust(hspace=.5)
       plt.title("Hue")
       plt.hist(np.ndarray.flatten(hue), bins=8)
       plt.subplot(312)                             #plot in the second cell
       plt.title("Saturation")
       plt.hist(np.ndarray.flatten(sat), bins=4)
       plt.subplot(313)                             #plot in the third cell
       plt.title("Luminosity Value")
       plt.hist(np.ndarray.flatten(val), bins=2)
       plt.show()
...