Я недавно начал моделировать алгоритм глубокого обучения с использованием Tensorflow в Python.Я хотел бы иметь возможность использовать SavedModel в Scala с помощью API-интерфейса Tensorflow Java.Тем не менее, вот ошибка, которую я получаю, когда стремлюсь интегрировать ее в мой код:
Program result = Failure(java.lang.IllegalArgumentException: Input to reshape is a tensor with 4 values, but the requested shape has 1
[[{{node graph/Reshape}} = Reshape[T=DT_BOOL, Tshape=DT_INT32, _output_shapes=[[?,?,1]], _device="/job:localhost/replica:0/task:0/device:CPU:0"](graph/SequenceMask/Less, graph/Reshape/shape)]])
Также вот метаданные модели, напечатанные моим кодом Scala:
model metadata: ModelMetadata(Map(serving_default -> SignatureMetadata(tensorflow/serving/predict,Map(chars -> TensorMetadata(chars:0,chars,DTypeString,List(-1, -1, -1)), words -> TensorMetadata(words:0,words,DTypeString,List(-1, -1)), nchars -> TensorMetadata(nchars:0,nchars,DTypeInt32,List(-1, -1)), nwords -> TensorMetadata(nwords:0,nwords,DTypeInt32,List(-1))),Map(pred_ids_ema -> TensorMetadata(cond_11/Merge:0,cond_11/Merge,DTypeInt32,List(-1, -1)), pred_ids -> TensorMetadata(cond/Merge:0,cond/Merge,DTypeInt32,List(-1, -1)), tags_ema -> TensorMetadata(index_to_string_Lookup_1:0,index_to_string_Lookup_1,DTypeString,List(-1, -1)), tags -> TensorMetadata(index_to_string_Lookup:0,index_to_string_Lookup,DTypeString,List(-1, -1))))))
serving signature: SignatureMetadata(tensorflow/serving/predict,Map(chars -> TensorMetadata(chars:0,chars,DTypeString,List(-1, -1, -1)), words -> TensorMetadata(words:0,words,DTypeString,List(-1, -1)), nchars -> TensorMetadata(nchars:0,nchars,DTypeInt32,List(-1, -1)), nwords -> TensorMetadata(nwords:0,nwords,DTypeInt32,List(-1))),Map(pred_ids_ema -> TensorMetadata(cond_11/Merge:0,cond_11/Merge,DTypeInt32,List(-1, -1)), pred_ids -> TensorMetadata(cond/Merge:0,cond/Merge,DTypeInt32,List(-1, -1)), tags_ema -> TensorMetadata(index_to_string_Lookup_1:0,index_to_string_Lookup_1,DTypeString,List(-1, -1)), tags -> TensorMetadata(index_to_string_Lookup:0,index_to_string_Lookup,DTypeString,List(-1, -1))))
serving signature inputs: Map(chars -> TensorMetadata(chars:0,chars,DTypeString,List(-1, -1, -1)), words -> TensorMetadata(words:0,words,DTypeString,List(-1, -1)), nchars -> TensorMetadata(nchars:0,nchars,DTypeInt32,List(-1, -1)), nwords -> TensorMetadata(nwords:0,nwords,DTypeInt32,List(-1)))
serving signature outputs: Map(pred_ids_ema -> TensorMetadata(cond_11/Merge:0,cond_11/Merge,DTypeInt32,List(-1, -1)), pred_ids -> TensorMetadata(cond/Merge:0,cond/Merge,DTypeInt32,List(-1, -1)), tags_ema -> TensorMetadata(index_to_string_Lookup_1:0,index_to_string_Lookup_1,DTypeString,List(-1, -1)), tags -> TensorMetadata(index_to_string_Lookup:0,index_to_string_Lookup,DTypeString,List(-1, -1)))
Воткод, используемый для подачи моей модели и запуска сеанса (обратите внимание, что processFeatures не полностью закодирован, отсутствует некоторая динамичность):
def processFeatures(line: String): (Array[Array[Array[Byte]]], Array[Int], Array[Array[Array[Byte]]], Array[Array[Int]]) = {
val nbWords = line.split(" ").length
val maxNbChars = line.split(" ").map(_.length).foldLeft(0) { (acc, current) =>
if (current < acc) acc
else if (current > acc) current
else acc
}
val words = Array.ofDim[Array[Byte]](1, nbWords)
words(0)(0) = line.getBytes("UTF-8")
logger.info(s"Number of words: $nbWords | Max number of characters: $maxNbChars")
val nChars = Array.ofDim[Int](1, nbWords)
nChars(0)(0) = line.split(" ")(0).length
val chars = Array.ofDim[Byte](1, nbWords, maxNbChars)
chars(0)(0) = line.getBytes("UTF-8")
(words, new Array[Int](nbWords), chars, nChars)
}
val features = processFeatures("toto")
println(Tensor.create(features._1))
println(Tensor.create(features._2))
println(Tensor.create(features._3))
println(Tensor.create(features._4))
val outputs = model.bundle.session.runner
.feed(signature.inputs("words").opName, Tensor.create(features._1))
.feed(signature.inputs("nwords").opName, Tensor.create(features._2))
.feed(signature.inputs("chars").opName, Tensor.create(features._3))
.feed(signature.inputs("nchars").opName, Tensor.create(features._4))
.fetch(output.opName)
.run()
outputs
.asScala
Может быть полезно также показать распечатанные тензорные свойства:
STRING tensor with shape [1, 1]
INT32 tensor with shape [1]
STRING tensor with shape [1, 1]
INT32 tensor with shape [1, 1]
Заранее большое спасибо за вашу помощь.
С уважением