У меня довольно простая последовательная модель Keras, и я хотел бы загрузить ее для вывода на Spark Dataframe.Для этого я надеялся использовать sparkdl.KerasTransformer
.Если я обучу модель, я могу загрузить ее с h5 и использовать tensorflow.keras.models.load_model
и выполнить вывод на numpy.ndarray
без проблем.Однако, когда я загружаю его через sparkdl.KerasTransformer
и применяю его к кадру данных, я получаю:
TypeError: индексы кортежа должны быть целыми числами или слайсами, а не списком
Здесьэто минимальный пример, который включает два типа слоев, которые я хотел бы использовать.
import numpy, pandas, tensorflow.keras, sparkdl
def build_model():
n0 = tensorflow.keras.layers.BatchNormalization(input_shape=(3,),name='n0')
s = tensorflow.keras.layers.Dense(1,activation='sigmoid',name='s')
m = tensorflow.keras.models.Sequential()
m.add(n0)
m.add(s)
m.build(input_shape=(3,))
return m
# get some data (yes its noise, but that's not the issue here)
X = numpy.random.randn(100,3)
y = numpy.random.choice([0,1],size=100)
# build and fit a model
model = build_model()
model.compile(optimizer='adadelta',loss='binary_crossentropy')
history = model.fit(X,y,batch_size=32,epochs=8,verbose=0)
# save the model
model.save(model_filename)
# load the model and compare predictions (no error loading or executing the model through Keras)
m1 = tensorflow.keras.models.load_model(model_filename)
pred = model.predict(X)
pred1 = m1.predict(X)
print(numpy.abs(pred-pred1).max()) # predictions between trained and loaded model agree
# convert the data to a spark DF
df = pandas.DataFrame({"features":X.tolist(),"targets":y,"scores":pred[:,0]})
sparkDF = spark.createDataFrame(df)
# load the model as a sparkdl.KerasTransformer
transformer = sparkdl.KerasTransformer(inputCol="features",outputCol="scoreUDF",modelFile=model_filename)
# apply the model to the dataframe THIS PRODUCES THE ERROR
sparkDF1 = transformer.transform(sparkDF)
Похоже, виновник находится в ...python3.6/site-packages/keras/layers/normalization.py
---> 94 dim = input_shape[self.axis]
Но, кроме взлома кода и перекомпиляции, я не могу найти обходного пути.
Я использую Python 3.6 и Spark 2.4.0 со следующими версиями библиотеки:
- тензор потока: 1.12.0
- керас: 2.1.6-tf
- numpy: 1.14.3
- панды: 0.23.0
Любые советы / помощь приветствуются.