30 наборов кажутся слишком низкими для задачи, так как вход и выход имеют слишком много измерений (30).Если вам необходимо отобразить эти высокоразмерные данные, вам понадобятся тысячи примеров (больше наборов).
Я бы предложил смоделировать преобразования для генерации нескольких тысяч выборок.Затем используйте небольшую нейронную сеть, чтобы предсказать Y по X. Поскольку входные данные не имеют пространственных или временных измерений, вместо этого представляют дискретные точки, я не думаю, что сверточные или рекуррентные модели будут полезны.
Итак,начать с малого MLP со средним квадратом потери ошибок.Однако, если выходные точки всегда целочисленные, вы можете смоделировать их как проблему классификации, учитывая, что диапазон невелик.
Я добавил небольшую модель нейронной сети в кератах, чтобы предсказать преобразование.
import numpy as np
import keras
import tensorflow
from keras.layers import Input, Dense, Reshape
from keras.models import Model
X = np.random.randint(-100, 100, (3000, 10, 3)) # 10 3d points
Y = 2*(X + 5)/7 # this is our simple transformation operation
print(X.shape)
print(Y.shape)
in_m = Input(shape=(30,)) # input layer
f1_fc = Dense(100, activation = 'relu')(in_m) # first fc layer
f2_fc = Dense(30, activation = 'linear')(f1_fc) # second fc layer
simple_model = Model(in_m, f2_fc)
simple_model.summary()
simple_model.compile(loss='mse', metrics=['mae'], optimizer='adam')
X_flat = np.reshape(X, (3000, 30))
Y_flat = np.reshape(Y, (3000, 30))
hist = simple_model.fit(X_flat, Y_flat, epochs = 100, validation_split = 0.2, batch_size = 20)
Выход:
(3000, 10, 3)
(3000, 10, 3)
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_9 (InputLayer) (None, 30) 0
_________________________________________________________________
dense_15 (Dense) (None, 100) 3100
_________________________________________________________________
dense_16 (Dense) (None, 30) 3030
=================================================================
Total params: 6,130
Trainable params: 6,130
Non-trainable params: 0