Как проверить исключения, выброшенные внутри карты в Scala - PullRequest
3 голосов
/ 28 февраля 2020

У меня есть следующая Scala функция

def throwError(spark: SparkSession,df:DataFrame): Unit = {
        import spark.implicits._
        throw new IllegalArgumentException(s"Illegal arguments")
        val predictionAndLabels = df.select("prediction", "label").map {
            case Row(prediction: Double, label: Double) => (prediction, label)
            case other => throw new IllegalArgumentException(s"Illegal arguments")
        }
        predictionAndLabels.show()
}

Я хочу проверить исключения, вызванные вышеуказанной функцией, но мои тесты не пройдены.

"Testing" should "throw error for datetype" in withSparkSession {
    spark => {

     // Creating a dataframe 
      val someData = Seq(
        Row(8, Date.valueOf("2016-09-30")),
        Row(9, Date.valueOf("2017-09-30")),
        Row(10, Date.valueOf("2018-09-30"))
      )

      val someSchema = List(
        StructField("prediction", IntegerType, true),
        StructField("label", DateType , true)
      )

      val someDF = spark.createDataFrame(
        spark.sparkContext.parallelize(someData),
        StructType(someSchema)
      )

    // Testing exception
     val caught = intercept[IllegalArgumentException] {
        throwError(spark,someDF)
      }

     assert(caught.getMessage.contains("Illegal arguments"))
   }
}

Если я перееду throw new IllegalArgumentException(s"Illegal arguments") вне вызова функции map тест проходит.

Как можно проверить исключение, выброшенное функцией throwError?

1 Ответ

1 голос
/ 02 марта 2020

Перехват исключений на уровне строк невозможен с помощью sparkDF. Если вы используете RDD, можно добиться того, что вы пытаетесь сделать.

Проверьте этот блог: https://www.nicolaferraro.me/2016/02/18/exception-handling-in-apache-spark/

Обходной путь для вашей проблемы:

def throwError(spark: SparkSession,df:DataFrame): Unit = {
        import spark.implicits._
        val countOfRowsBeforeCheck = df.count()
        val predictionAndLabels = df.select("prediction", "label").flatMap {
            case Row(prediction: Double, label: Double) => Iterator((prediction, label))
            case other => Iterator.empty
        }
        val countOfRowsAfterCheck = predictionAndLabels.count()
        if(countOfRowsAfterCheck != countOfRowsBeforeCheck){
            throw new IllegalArgumentException(s"Illegal arguments")
        }

        predictionAndLabels.show()
}

Надеюсь, эта помощь !!

...