
I'm brand new to PySpark and I'm trying to convert some Python code that derives a new variable, COUNT_IDX. The new variable has an initial value of 1 and is incremented by 1 when a condition is met; otherwise it keeps the value it had on the previous record.

The condition to increment is: TRIP_CD is not equal to the previous record's TRIP_CD, or SIGN is not equal to the previous record's SIGN, or time_diff is not equal to 1.

Python code (pandas dataframe):

# start every row at 1, then walk the frame and bump the counter
# whenever a new group starts
df['COUNT_IDX'] = 1

for i in range(1, len(df)):
    if ((df['TRIP_CD'].iloc[i] != df['TRIP_CD'].iloc[i - 1])
            or (df['SIGN'].iloc[i] != df['SIGN'].iloc[i - 1])
            or df['time_diff'].iloc[i] != 1):
        df.iloc[i, df.columns.get_loc('COUNT_IDX')] = df['COUNT_IDX'].iloc[i - 1] + 1
    else:
        df.iloc[i, df.columns.get_loc('COUNT_IDX')] = df['COUNT_IDX'].iloc[i - 1]

Here are the expected results:

TRIP_CD   SIGN   time_diff  COUNT_IDX
2711      -      1          1
2711      -      1          1
2711      +      2          2
2711      -      1          3
2711      -      1          3
2854      -      1          4
2854      +      1          5

In PySpark, I initialized COUNT_IDX as 1. Then, using window functions, I took the lags of TRIP_CD and SIGN (TRIP_lag and SIGN_lag) and calculated the time difference in seconds (seconds_diff), and then tried:

df = sqlContext.sql('''
   select TRIP, TRIP_CD, SIGN, TIME_STAMP, seconds_diff,
   case when TRIP_CD != TRIP_lag or SIGN != SIGN_lag or seconds_diff != 1
        then (lag(COUNT_IDX) over (partition by TRIP order by TRIP, TIME_STAMP)) + 1
        else (lag(COUNT_IDX) over (partition by TRIP order by TRIP, TIME_STAMP))
        end as COUNT_IDX from df''')

This is giving me something like:

TRIP_CD   SIGN   time_diff  COUNT_IDX
2711      -      1          1
2711      -      1          1
2711      +      2          2
2711      -      1          2
2711      -      1          1
2854      -      1          2
2854      +      1          2

If COUNT_IDX is updated on a previous record, the current record isn't picking up that change in its calculation. It's as if COUNT_IDX is never overwritten, or isn't being evaluated row by row. Any ideas on how I can get around this?

1 Answer


You need a cumulative sum here:

-- cumulative sum
SUM(CAST(
  -- if at least one condition has been satisfied
  -- we take 1, otherwise 0
  TRIP_CD != TRIP_lag OR SIGN != SIGN_lag OR seconds_diff != 1 AS LONG
)) OVER W
...
WINDOW W AS (PARTITION BY TRIP ORDER BY TIME_STAMP)
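
A minimal, runnable sketch of the same idea with the DataFrame API, built on the sample data from the question (the TRIP and TIME_STAMP values are invented for illustration, and the 1 + COALESCE wrapping is one way to handle the NULL comparisons on the first row of each trip so the index starts at 1):

from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()

# sample data from the question; TRIP and TIME_STAMP values are made up
rows = [(1, 1, 2711, '-', 1), (1, 2, 2711, '-', 1), (1, 3, 2711, '+', 2),
        (1, 4, 2711, '-', 1), (1, 5, 2711, '-', 1), (1, 6, 2854, '-', 1),
        (1, 7, 2854, '+', 1)]
df = spark.createDataFrame(rows, ['TRIP', 'TIME_STAMP', 'TRIP_CD', 'SIGN', 'time_diff'])

w = Window.partitionBy('TRIP').orderBy('TIME_STAMP')

# materialize the lag columns, as in the question
df = (df
      .withColumn('TRIP_lag', F.lag('TRIP_CD').over(w))
      .withColumn('SIGN_lag', F.lag('SIGN').over(w)))

# 1 if a new group starts on this row, 0 otherwise (NULL on the first row of a trip)
new_group = ((F.col('TRIP_CD') != F.col('TRIP_lag'))
             | (F.col('SIGN') != F.col('SIGN_lag'))
             | (F.col('time_diff') != 1)).cast('long')

# running total of the flags, shifted up by 1 so the index starts at 1
result = df.withColumn('COUNT_IDX',
                       F.lit(1) + F.sum(F.coalesce(new_group, F.lit(0))).over(w))

result.orderBy('TIME_STAMP').show()
# COUNT_IDX should come out as 1, 1, 2, 3, 3, 4, 5 (the expected results)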

2 Comments

This is a creative solution; however, I haven't quite gotten it to work yet. Are you putting this in a withColumn statement to create a new column with the cumulative sum, or is this supposed to be in SQL? Thanks!
This is intended to replace the part of your SQL query between case when and end. The window definition can be inlined if you prefer. Since there are some columns missing from the data you show, it is just pseudocode though.
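
For concreteness, splicing that into the question's query (window inlined, with 1 + COALESCE added so the NULL lags on a trip's first row still give a starting index of 1, and assuming TRIP_lag, SIGN_lag and seconds_diff already exist in df as described in the question) would look roughly like this:

df = sqlContext.sql('''
   select TRIP, TRIP_CD, SIGN, TIME_STAMP, seconds_diff,
   1 + sum(coalesce(cast(
         TRIP_CD != TRIP_lag or SIGN != SIGN_lag or seconds_diff != 1
       as long), 0))
       over (partition by TRIP order by TIME_STAMP) as COUNT_IDX
   from df''')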
