Как я могу сделать функцию общего на MLReader - PullRequest
1 голос
/ 08 марта 2019

Я работаю в Spark 1.6.3. Вот две функции, которые делают одно и то же:

def modelFromBytesCV(modelArray: Array[Byte]): CountVectorizerModel = {
  val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
  Files.write(tempPath, modelArray)
  CountVectorizerModel.read.load(tempPath.toString)
}

def modelFromBytesIDF(modelArray: Array[Byte]): IDFModel = {
  val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
  Files.write(tempPath, modelArray)
  IDFModel.read.load(tempPath.toString)
}

Я бы хотел сделать эти функции общими. Что меня зацепило, так это то, что общей чертой между объектом CountVectorizerModel и IDFModel является MLReadable [T], который сам по себе должен принимать тип CountVectorizerModel или IDFModel. Это своего рода рекурсивный цикл родительского класса, решение которого я не могу найти.

Для сравнения, создание универсальной модели легко, потому что MLWritable - это общая черта, распространяемая всеми интересующими меня моделями:

def modelToBytes[M <: MLWritable](model: M): Array[Byte] = {
  val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
  model.write.overwrite().save(tempPath.toString)
  Files.readAllBytes(tempPath)
}

Как создать универсальный считыватель, который превратит модель spark-ml в байтовый массив?

1 Ответ

2 голосов
/ 08 марта 2019

Чтобы это работало, вам нужен доступ к определенному MlReadable объекту.

import org.apache.spark.ml.util.MLReadable

def modelFromBytes[M](obj: MLReadable[M], modelArray: Array[Byte]): M = {
  val tempPath: Path = ???
  ...
  obj.read.load(tempPath.toString)
}

, который позже можно использовать как:

val bytes: Array[Byte] = ???
modelFromBytes(CountVectorizerModel, bytes)

Обратите внимание, что, несмотря на первое появление, здесь нет ничего рекурсивного - MLReadable[M] относится к объекту-компаньону, а не к классу как таковому. Так, например, CountVectorizerModel объект равен MLReadable, а CountVectorizeModel класс - нет.

Внутри Spark MLReader обрабатывает это по-другому - создает экземпляр класса, используя отражение , а затем устанавливает его Params. Однако этот путь не будет очень полезным для вас здесь *.

Если требуется совместимость с текущим API, вы можете попытаться сделать читаемый объект неявным:

def modelFromBytes[M](modelArray: Array[Byte])(implicit obj: MLReadable[M]): M = {
  ...
}

, а затем

implicit val readable: MLReadable[CountVectorizerModel] = CountVectorizerModel

modelFromBytes[CountVectorizerModel](bytes)

* Технически говоря, можно получить сопутствующий объект с помощью отражения

def modelFromBytesCV[M <: MLWritable](
    modelArray: Array[Byte])(implicit ct: ClassTag[M]): M = {
  val tempPath: Path = ???
  ...
  val cls = Class.forName(ct.runtimeClass.getName + "$");
  cls.getField("MODULE$").get(cls).asInstanceOf[MLReadable[M]]
    .read.load(tempPath.toString)) 
}

но я не думаю, что этот путь стоит здесь исследовать. В частности, мы не можем предоставить здесь строгие границы типов - использование MLWritable является хаком для ограничения человеческих ошибок, но довольно бесполезно для компилятора.

...