Рекурсивная функция работает с pandas dataframe, но версия pyspark dataframe генерирует ошибочные результаты при переходе состояний - PullRequest
0 голосов
/ 20 сентября 2018

Я объясняю проблему с небольшим набором данных с 4 столбцами.У меня есть кумулятивная матрица перехода вероятностей с кумулятивными вероятностями, как показано ниже.

import pandas as pd
import random
cumTransitionMatrix = pd.DataFrame({'State_1' : [.2,.6,.65,1]
                             ,'State_2' : [.3,.7,.78,1]
                             ,'State_3' : [0 , 0,  1, 1]
                             ,'State_4' : [.5,.7,.85, 1]},index ['State_1','State_2','State_3','State_4']).T

И еще один набор данных с начальными состояниями перехода

df = pd.DataFrame({'current state' : ['State_1','State_2','State_3','State_4','State_4'
                                      ,'State_3','State_1','State_4','State_2','State_4'
                                      ,'State_4','State_2','State_4','State_4','State_1']})

Здесь обратите внимание, что state_3 является конечным состоянием.Итак, если мы достигнем состояния 3, мы останемся там навсегда.

Теперь я хочу смоделировать следующие 5 состояний на основе обратного моделирования из равномерного случайного числа нет.Код, который я написал на python, работает нормально.

stateNames = dict(zip(range(len(cumTransitionMatrix.columns)), cumTransitionMatrix.columns))

def transition(value):

    transitArray = [0] * (len(cumTransitionMatrix.columns)+1)
    x=random.random()
    x=float("{0:.3f}".format(x))
    transitArray[len(cumTransitionMatrix.columns)] = x
    if x <= cumTransitionMatrix.loc[value][0]:
                transitArray[0] = 1
                nextState = stateNames[0]
    else:
        for i in range(1,len(cumTransitionMatrix.columns)):
            if (x <= cumTransitionMatrix.loc[value][i]) and (x > cumTransitionMatrix.loc[value][i-1]):
                transitArray[i] = 1
                nextState = stateNames[i]
                break
    return transitArray, nextState

df['transitArray1'], df['nextState1'] = zip(*df['current state'].apply(transition))

no_of_forecast_month = 5
for i in range(2, no_of_forecast_month+1):
    df['transitArray'+str(i)], df['nextState'+str(i)]=zip(*df['nextState'+str(i-1)].apply(transition))

Но когда я создал версию Pyspark для того же кода.Это дает мне ошибочный результат.После достижения State_3 мы отправляемся в другое состояние, может быть 1 или 2. Не могу найти причину этой ошибки.Пожалуйста, найдите ниже код Pyspark.

spark = SparkSession.builder.appName('pandasToSparkDF').master('local[*]').getOrCreate()

mySchema = StructType([ StructField("current state", StringType(), True)])
df = spark.createDataFrame(PythonDF,schema=mySchema)

stateNames = dict(zip(range(len(cumTransitionMatrix.columns)), cumTransitionMatrix.columns))

def transition(value):

    transitArray = [0] * (len(cumTransitionMatrix.columns)+1)
    x=random.random()
    x=float("{0:.3f}".format(x))
    transitArray[len(cumTransitionMatrix.columns)] = x
    if x <= cumTransitionMatrix.loc[value][0]:
                transitArray[0] = 1
                nextState = stateNames[0]
    else:
        for i in range(1,len(cumTransitionMatrix.columns)):
            if (x <= cumTransitionMatrix.loc[value][i]) and (x > cumTransitionMatrix.loc[value][i-1]):
                transitArray[i] = 1
                nextState = stateNames[i]
                break
#     return transitArray, nextState
    return nextState

generateState = udf(transition)
df=df.withColumn('nextState1',generateState('current state'))

no_of_forecast_month = 5
for i in range(2, no_of_forecast_month+1):
    df=df.withColumn('nextState'+str(i),generateState('nextState'+str(i-1)))

df.show(20)

Пожалуйста, найдите под результатом.

+-------------+----------+----------+----------+----------+----------+
|current state|nextState1|nextState2|nextState3|nextState4|nextState5|
+-------------+----------+----------+----------+----------+----------+
|      State_1|   State_2|   State_4|   State_3|   State_3|   State_4|
|      State_2|   State_2|   State_1|   State_3|   State_4|   State_2|
|      State_3|   State_3|   State_3|   State_3|   State_3|   State_3|
|      State_4|   State_1|   State_3|   State_4|   State_3|   State_3|
|      State_4|   State_1|   State_1|   State_4|   State_2|   State_1|
|      State_3|   State_3|   State_3|   State_3|   State_3|   State_3|
|      State_1|   State_4|   State_3|   State_1|   State_2|   State_2|
|      State_4|   State_4|   State_3|   State_1|   State_4|   State_4|
|      State_2|   State_1|   State_3|   State_1|   State_3|   State_3|
|      State_4|   State_4|   State_4|   State_2|   State_2|   State_1|
+-------------+----------+----------+----------+----------+----------+

Посмотрите в 1-й строке nextState4 и nextState5

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...