Pyspark юнит-тестирование - PullRequest
0 голосов
/ 23 февраля 2020
import unittest
import warnings
from datetime import datetime

from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
from pyspark.sql.types import StringType, StructField, StructType, TimestampType, FloatType

from ohlcv_service.ohlc_gwa import datetime_col


class ReusedPySparkTestCase(unittest.TestCase):
    sc_values = {}

    @classmethod
    def setUpClass(cls):
        conf = (SparkConf().setMaster('local[2]')
                .setAppName(cls.__name__)
                .set('deploy.authenticate.secret', '111111'))
        cls.sc = SparkContext(conf=conf)
        cls.sc_values[cls.__name__] = cls.sc
        cls.spark = (SparkSession.builder
                     .master('local[2]')
                     .appName('local-testing-pyspark-context')
                     .getOrCreate())

    @classmethod
    def tearDownClass(cls):
        print('....calling stop tearDownClass, the content of sc_values=', cls.sc_values, '\n')
        for key, sc in cls.sc_values.items():
            print('....closing=', key, '\n')
            sc.stop()

        cls.sc_values.clear()


class TestDateTimeCol(ReusedPySparkTestCase):

    def setUp(self):
        # Ignore ResourceWarning: unclosed socket.socket!
        warnings.simplefilter("ignore", ResourceWarning)

    def test_datetime_col(self):
        test_data_frame = self.create_data_frame(rows=[['GWA',
                                                        '2b600c2a-782f-4ccc-a675-bbbd7d91fde4',
                                                        '02fb81fa-91cf-4eab-a07e-0df3c107fbf8',
                                                        '2019-06-01T00:00:00.000Z',
                                                        0.001243179008694,
                                                        0.001243179008694,
                                                        0.001243179008694,
                                                        0.001243179008694,
                                                        0.001243179008694]],
                                                 columns=[StructField('indexType', StringType(), False),
                                                          StructField('id', StringType(), False),
                                                          StructField('indexId', StringType(), False),
                                                          StructField('timestamp', StringType(), False),
                                                          StructField('price', FloatType(), False),
                                                          StructField('open', FloatType(), False),
                                                          StructField('high', FloatType(), False),
                                                          StructField('low', FloatType(), False),
                                                          StructField('close', FloatType(), False)])
        expected = self.create_data_frame(rows=[['GWA',
                                                 '2b600c2a-782f-4ccc-a675-bbbd7d91fde4',
                                                 '02fb81fa-91cf-4eab-a07e-0df3c107fbf8',
                                                 '2019-06-01T00:00:00.000Z',
                                                 '1559347200',
                                                 0.001243179008694,
                                                 0.001243179008694,
                                                 0.001243179008694,
                                                 0.001243179008694,
                                                 0.001243179008694]],
                                          columns=[StructField('indexType', StringType(), False),
                                                   StructField('id', StringType(), False),
                                                   StructField('indexId', StringType(), False),
                                                   StructField('timestamp', StringType(), False),
                                                   StructField('datetime', TimestampType(), True),
                                                   StructField('price', FloatType(), False),
                                                   StructField('open', FloatType(), False),
                                                   StructField('high', FloatType(), False),
                                                   StructField('low', FloatType(), False),
                                                   StructField('close', FloatType(), False)])
        print(expected)
        convert_to_datetime = datetime_col(test_data_frame)
        self.assertEqual(expected, convert_to_datetime)

    def create_data_frame(self, rows, columns):
        rdd = self.sc.parallelize(rows)
        df = self.spark.createDataFrame(rdd.collect(), test_schema(columns=columns))
        return df


def test_schema(columns):
    return StructType(columns)


if __name__ == '__main__':
    unittest.main()

Ошибка

TimestampType can not accept object '1559347200' in type <class 'str'>

Функция datetime_col

def datetime_col(df):
      return df.select("indexType", "id", "indexId", "timestamp",
                     (F.col("timestamp").cast(TimestampType)).alias("datetime"),
                     "price", "open", "high", "low", "close")

Функции col типа datetime преобразуют метку времени из строки в метку времени формат. Это работает должным образом в ноутбуке EMR-Zeppelin, но когда я пытаюсь выполнить модульное тестирование, выдает вышеуказанную ошибку. Версия spark и pyspark в моей локальной версии 2.3.1. Как устранить эту ошибку. Когда я пытаюсь конвертировать spark df в pandas df, он конвертирует метку времени как + 12.

1 Ответ

0 голосов
/ 23 февраля 2020

Я не могу воспроизвести вашу проблему в вашей настройке EMR, вы не публикуете много информации, и я не смог бы ее настроить в любом случае. Но в вашем тестовом примере есть несколько проблем, с которыми я могу попытаться помочь.

Сообщение об ошибке, которое вы видите, происходит потому, что вы не можете привести string или более правильно int непосредственно к Timestamp , Вам нужно использовать to_unixtime. Нечто подобное работает нормально.

expected = self.create_data_frame(rows=[['GWA',
                                         '2b600c2a-782f-4ccc-a675-bbbd7d91fde4',
                                         '02fb81fa-91cf-4eab-a07e-0df3c107fbf8',
                                         '2019-06-01T00:00:00.000Z',
                                         None,
                                         0.001243179008694,
                                         0.001243179008694,
                                         0.001243179008694,
                                         0.001243179008694,
                                         0.001243179008694]],
                                  columns=[StructField('indexType', StringType(), False),
                                           StructField('id', StringType(), False),
                                           StructField('indexId', StringType(), False),
                                           StructField('timestamp', StringType(), False),
                                           StructField('datetime', TimestampType(), True),
                                           StructField('price', FloatType(), False),
                                           StructField('open', FloatType(), False),
                                           StructField('high', FloatType(), False),
                                           StructField('low', FloatType(), False),
                                           StructField('close', FloatType(), False)])
expected = expected.withColumn('datetime', from_unixtime(F.lit(1559347200)).cast(TimestampType()))

Вторая проблема заключается в том, что ваша функция datetime_col может нормально работать в кластере (как я говорю, я не могу воспроизвести), но она не работает локально. Следующий способ, безусловно, будет работать в обоих случаях.

def datetime_col(df):
    return df.select("indexType", "id", "indexId", "timestamp",
                     (to_timestamp(F.col("timestamp"), 
                                   "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'")).alias("datetime"),
                     "price", "open", "high", "low", "close")

Вам нужно установить часовой пояс, чтобы все работало нормально (@ your setupClass).

cls.spark.conf.set("spark.sql.session.timeZone", "UTC")

И, наконец, в вашем assert вам нужно collect данных, чтобы сравнивать содержимое ваших фреймов данных.

self.assertEqual(expected.collect(), convert_to_datetime.collect())

Надеюсь, это поможет.

...