Единственное решение, которое я смог найти, — использовать EvalAheadIterator (итератор, который вычисляет головной элемент буфера перед вызовом iterator.next).
import scala.collection.AbstractIterator
import scala.util.control.NonFatal
/**
 * Iterator decorator that forces evaluation of the upcoming element inside
 * `hasNext`, so exceptions thrown while producing an element surface *before*
 * `next()` is called. A non-fatal failure is reported to stdout and ends the
 * iteration (`hasNext` returns false); fatal errors still propagate.
 */
class EvalAheadIterator[+A](iter: Iterator[A]) extends AbstractIterator[A] {

  // Buffered view lets us peek at (and thereby evaluate) the next element
  // without consuming it.
  private val lookahead: BufferedIterator[A] = iter.buffered

  override def hasNext: Boolean =
    lookahead.hasNext && {
      try {
        lookahead.head // force the element now; any failure surfaces here
        true
      } catch {
        case NonFatal(e) =>
          println("caught exception ahead of time")
          false
      }
    }

  // By the iterator contract, callers invoke next() only after hasNext
  // returned true, at which point the head has already been evaluated.
  override def next(): A = lookahead.next()
}
Теперь применим EvalAheadIterator внутри mapPartitions:
/**
 * Simulates lazily streaming three records from S3 for the given partition.
 * Each record is `(index, description)`; nothing is evaluated until the
 * iterator is consumed.
 */
def readFromS3(partition: Int): Iterator[(Int, String)] = {
  (0 until 3).iterator.map { idx =>
    // simulate an error only on partition 3 record 2
    if (partition == 3 && idx == 2) throw new RuntimeException("error")
    (idx, s"elem $idx on partition $partition")
  }
}
// Build an RDD whose partitions are backed by lazily-produced "S3" reads,
// then wrap each partition's iterator so record failures surface in hasNext
// instead of mid-shuffle.
val rdd = sc.parallelize(Seq(1, 2, 3, 4))
  .mapPartitionsWithIndex((partitionIndex, iter) => readFromS3(partitionIndex))
  .mapPartitions { iter => new EvalAheadIterator(iter) }
// I can do whatever I want here

// The shuffle below is what actually forces evaluation of the iterator.
val partitionedRdd = rdd.partitionBy(new org.apache.spark.HashPartitioner(2))
// I can do whatever I want here

// Consume each partition; any exception that still escapes is handled here.
partitionedRdd.foreachPartition { iter =>
  try {
    iter.foreach(println)
  } catch {
    // FIX: the original `case _ =>` was a bare catch-all that would also
    // swallow fatal errors (OutOfMemoryError, InterruptedException, ...).
    // NonFatal (already imported at the top of the file) lets those propagate.
    case NonFatal(_) => println("error caught")
  }
}