Я использую API данных Tensorflow (Tensorflow 2.0) для обучения модели глубокого обучения на 4 графических процессорах. У меня 3 процессора, и загрузка процессора составляет 30%, у меня 4 графических процессора, и загрузка моего графического процессора составляет 100%, но я застрял в первой эпохе. Вот мой код:
def get_dataset():
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.05,shuffle=True,random_state=42)
Train=tensorflow.data.Dataset.from_tensor_slices((x_train,y_train))
Train = Train.batch(128,drop_remainder=True)
Train = Train.prefetch(buffer_size=10)
Test=tensorflow.data.Dataset.from_tensor_slices((x_test,y_test))
Test = Test.batch(len(x_test),drop_remainder=True)
Test = Test.prefetch(buffer_size=10)
return(Train,Test,)
train_dataset, test_dataset = get_dataset()
model.fit(train_dataset,validation_data=test_dataset, callbacks=callbacks_list, verbose=2,epochs=10)
Кто-нибудь знает, что происходит? Когда я использую один CPU и 2 GPU, то все работает нормально! Но когда я использую 3CPU и 4GPU, я застреваю в первой эпохе! (Размер набора данных = 400 тыс. Выборок)