sparkdl TypeError при загрузке хорошего файла h5 - PullRequest
0 голосов
/ 07 мая 2019

У меня довольно простая последовательная модель Keras, и я хотел бы загрузить ее для вывода на Spark Dataframe.Для этого я надеялся использовать sparkdl.KerasTransformer.Если я обучу модель, я могу загрузить ее с h5 и использовать tensorflow.keras.models.load_model и выполнить вывод на numpy.ndarray без проблем.Однако, когда я загружаю его через sparkdl.KerasTransformer и применяю его к кадру данных, я получаю:

TypeError: индексы кортежа должны быть целыми числами или слайсами, а не списком

Здесьэто минимальный пример, который включает два типа слоев, которые я хотел бы использовать.

import numpy, pandas, tensorflow.keras, sparkdl

def build_model():
  n0 = tensorflow.keras.layers.BatchNormalization(input_shape=(3,),name='n0')
  s = tensorflow.keras.layers.Dense(1,activation='sigmoid',name='s')

  m = tensorflow.keras.models.Sequential()
  m.add(n0)
  m.add(s)
  m.build(input_shape=(3,))
  return m

# get some data (yes its noise, but that's not the issue here)
X = numpy.random.randn(100,3)
y = numpy.random.choice([0,1],size=100)

# build and fit a model
model = build_model()
model.compile(optimizer='adadelta',loss='binary_crossentropy')
history = model.fit(X,y,batch_size=32,epochs=8,verbose=0)

# save the model
model.save(model_filename)

# load the model and compare predictions (no error loading or executing the model through Keras)
m1 = tensorflow.keras.models.load_model(model_filename)
pred = model.predict(X)
pred1 = m1.predict(X)
print(numpy.abs(pred-pred1).max()) # predictions between trained and loaded model agree

# convert the data to a spark DF
df = pandas.DataFrame({"features":X.tolist(),"targets":y,"scores":pred[:,0]})
sparkDF = spark.createDataFrame(df)

# load the model as a sparkdl.KerasTransformer
transformer = sparkdl.KerasTransformer(inputCol="features",outputCol="scoreUDF",modelFile=model_filename)

# apply the model to the dataframe THIS PRODUCES THE ERROR
sparkDF1 = transformer.transform(sparkDF) 

Похоже, виновник находится в ...python3.6/site-packages/keras/layers/normalization.py

---> 94 dim = input_shape[self.axis]

Но, кроме взлома кода и перекомпиляции, я не могу найти обходного пути.

Я использую Python 3.6 и Spark 2.4.0 со следующими версиями библиотеки:

  • тензор потока: 1.12.0
  • керас: 2.1.6-tf
  • numpy: 1.14.3
  • панды: 0.23.0

Любые советы / помощь приветствуются.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...