Я пытаюсь создать файл .npy, разбивая массивы на подгруппы train и test, но, пытаясь это сделать, я сталкиваюсь с этой ошибкой, которая указывает на validation.py и _split.py в папке Sklearn , Я использую Python 3.6.7
import numpy as np
import librosa
import os
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn import utils
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from keras.utils import to_categorical
DATA_PATH = "./Angry/"
X=0
y=0
split_ratio=0.6
random_state=42
def get_train_test():
labels, indices, _ = get_labels(DATA_PATH)
# Getting first arrays
X= np.load(labels[0] + '.npy')
y= np.zeros(X.shape[0])
# Append all of the dataset into one single array, same goes for y
for i, label in enumerate(labels[1:]):
x = np.load(label + '.npy')
X = np.vstack((X, x))
y = np.append(y, np.full(x.shape[0], fill_value= (i + 1)))
assert X.shape[0] == len(y)
return train_test_split(X, y, test_size= (1 - split_ratio), random_state=random_state, shuffle=True)
def wav2mfcc(file_path, max_pad_len=11):
wave, sr = librosa.load(file_path, mono=True, sr=None)
wave = wave[::3]
mfcc = librosa.feature.mfcc(wave, sr=16000)
pad_width = max_pad_len - mfcc.shape[1]
mfcc = np.pad(mfcc, pad_width=((0, 0), (0, pad_width)),mode='constant')
return mfcc
def get_labels(path=DATA_PATH):
labels = os.listdir(path)
label_indices = np.arange(0, len(labels))
return labels, label_indices, to_categorical(label_indices)
def save_data_to_array(path=DATA_PATH, max_pad_len=11):
labels, _, _ = get_labels(path)
for label in labels:
# Init mfcc vectors
mfcc_vectors = []
wavfiles = [path + label + '/' + wavfile for wavfile in os.listdir(path + '/' + label)]
for wavfile in wavfiles:
mfcc = wav2mfcc(wavfile, max_pad_len=max_pad_len)
mfcc_vectors.append(mfcc)
np.save(label + '.npy', mfcc_vectors)
if __name__=='__main__':
feature_dim_2 = 11
# Save data to array file first
save_data_to_array(max_len=feature_dim_2)
# # Loading train set and test set
X_train, X_test, y_train, y_test = get_train_test()
# # Feature dimension
feature_dim_1 = 20
channel = 1
epochs = 50
batch_size = 100
verbose = 1
num_classes = 3
# Reshaping to perform 2D convolution
X_train = X_train.reshape(X_train.shape[0], feature_dim_1,feature_dim_2, channel)
X_test = X_test.reshape(X_test.shape[0], feature_dim_1, feature_dim_2, channel)
y_train_hot = to_categorical(y_train)
y_test_hot = to_categorical(y_test)
TypeError Traceback (последний последний вызов)
cell_name in async-def-wrapper()
~/anaconda3/envs/tensorflow/lib/python3.6/site-packages/sklearn
/model_selection/_split.py in train_test_split(*arrays, **options)
2182 test_size = 0.25
2183
--> 2184 arrays = indexable(*arrays)
2185
2186 if shuffle is False:
~/anaconda3/envs/tensorflow/lib/python3.6/site-packages/sklearn/utils/validation.py
в индексируемых (* итерируемых)
258 """Make arrays indexable for cross-validation.
259
--> 260 Checks consistent length, passes through None, and ensures that everything
261 can be indexed by converting sparse matrices to csr and converting
262 non-interable objects to arrays.
~/anaconda3/envs/tensorflow/lib/python3.6/site-packages/sklearn/utils/validation.py
in check_consistent_length (* массивы)
229 memory = Memory(cachedir=memory, verbose=0)
230 else:
--> 231 memory = Memory(location=memory, verbose=0)
232 elif not hasattr(memory, 'cache'):
233 raise ValueError("'memory' should be None, a string or have the same"
~/anaconda3/envs/tensorflow/lib/python3.6/site-packages/sklearn/utils/validation.py
дюйм (.0)
229 memory = Memory(cachedir=memory, verbose=0)
230 else:
--> 231 memory = Memory(location=memory, verbose=0)
232 elif not hasattr(memory, 'cache'):
233 raise ValueError("'memory' should be None, a string or have the same"
~/anaconda3/envs/tensorflow/lib/python3.6/site-packages/sklearn/utils/validation.py
в _num_samples (x)
140 raise TypeError("Singleton array %r cannot be considered"
141 " a valid collection." % x)
--> 142 return x.shape[0]
143 else:
144 return len(x)
TypeError: Singleton array array(0) cannot be considered a valid collection.