I have a DataFrame in PySpark like the one below:
df = spark.createDataFrame([
    (124, 10, 8),
    (124, 20, 7),
    (125, 30, 6),
    (125, 40, 5),
    (126, 50, 4),
    (126, 60, 3),
    (126, 70, 2),
    (127, 80, 1)
], ("ACC_KEY", "AMT", "value"))
df.show()
+-------+---+-----+
|ACC_KEY|AMT|value|
+-------+---+-----+
| 126| 70| 2|
| 126| 60| 3|
| 126| 50| 4|
| 124| 20| 7|
| 124| 10| 8|
| 127| 80| 1|
| 125| 40| 5|
| 125| 30| 6|
+-------+---+-----+
Expected result
+-------+---+-----+-------+-----+-------+
|ACC_KEY|AMT|value|row_now|amt_c|lkp_rev|
+-------+---+-----+-------+-----+-------+
| 126| 70| 2| 1| 70| 72|
| 126| 60| 3| 2| 72| 75|
| 126| 50| 4| 3| 75| 79|
| 124| 20| 7| 1| 20| 27|
| 124| 10| 8| 2| 27| 35|
| 127| 80| 1| 1| 80| 81|
| 125| 40| 5| 1| 40| 45|
| 125| 30| 6| 2| 45| 51|
+-------+---+-----+-------+-----+-------+
Conditions
1) When row_now = 1, then amt_c = AMT
2) When row_now != 1, then amt_c = the lag of lkp_rev, i.e. the previous row's amt_c + value (worked through below)
3) lkp_rev = amt_c + value
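For example, taking ACC_KEY 126 ordered by AMT descending, the recursion works out as:

row 1: amt_c = AMT = 70,          lkp_rev = 70 + 2 = 72
row 2: amt_c = lag(lkp_rev) = 72, lkp_rev = 72 + 3 = 75
row 3: amt_c = lag(lkp_rev) = 75, lkp_rev = 75 + 4 = 79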
I have tried the following:
import pyspark.sql.functions as f
from pyspark.sql import Window

# create row_now column
df1 = df.withColumn(
    "row_now",
    f.row_number().over(Window.partitionBy("ACC_KEY").orderBy(f.col("AMT").desc())),
)

# amt_c column creation
df2 = df1.withColumn(
    "amt_c",
    f.when(f.col("row_now") == 1, f.col("AMT")).otherwise(f.col("value") + f.col("AMT")),
)
This creates row_now correctly, but the otherwise branch for amt_c just adds value to AMT on every non-first row instead of carrying forward the previous row's lkp_rev. How can I achieve the expected result?
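Unrolling the three conditions suggests a non-recursive formulation: amt_c(n) = lkp_rev(n-1) = amt_c(n-1) + value(n-1), with amt_c(1) = the first row's AMT, so lkp_rev is just the partition's first AMT plus a running sum of value, and amt_c = lkp_rev - value. A minimal sketch of that idea with ordinary window functions (assuming the ordering within each ACC_KEY is AMT descending, as the expected output suggests):

import pyspark.sql.functions as f
from pyspark.sql import Window

w = Window.partitionBy("ACC_KEY").orderBy(f.col("AMT").desc())
# frame from the start of the partition up to the current row
w_cum = w.rowsBetween(Window.unboundedPreceding, Window.currentRow)

result = (
    df.withColumn("row_now", f.row_number().over(w))
      # first AMT in the partition plus the running total of value
      .withColumn("lkp_rev", f.first("AMT").over(w_cum) + f.sum("value").over(w_cum))
      # condition 3 rearranged: amt_c = lkp_rev - value
      .withColumn("amt_c", f.col("lkp_rev") - f.col("value"))
      .select("ACC_KEY", "AMT", "value", "row_now", "amt_c", "lkp_rev")
)
result.show()

On the sample data above this reproduces the expected values, though show() may print the rows in a different order.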