I have a dataset like the first table below. (The dataset has the same number of elements per ID across the different columns; however, the number of elements varies by ID.)
I would like to transform this dataset into the second table below. That is, I want to 'explode'/expand the comma-separated cell values per ID into multiple rows while preserving the other columns.
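To make the structure concrete, here is a toy version of the input built directly in PySpark (the column names ID, col1, col2, col3 match my code further down, but the values are made up purely for illustration; spark is the active SparkSession):
# hypothetical input: one row per ID, comma-separated elements in each column,
# and the same number of elements per ID across the columns
df = spark.createDataFrame(
    [("A", "1,2,3", "x,y,z", "10,20,30"),   # ID "A": 3 elements per column
     ("B", "4,5",   "u,v",   "40,50")],     # ID "B": 2 elements per column
    ["ID", "col1", "col2", "col3"],
)
# desired output: one row per element, with the other columns preserved
#   A, 1, x, 10
#   A, 2, y, 20
#   A, 3, z, 30
#   B, 4, u, 40
#   B, 5, v, 50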
Now I have tried to explode the columns with the following script:
from pyspark.sql import functions as F
df = df.withColumn("1", F.explode(F.split(F.col("col1"), ",")))\
    .withColumn("2", F.explode(F.split(F.col("col2"), ",")))\
    .withColumn("3", F.explode(F.split(F.col("col3"), ",")))
However, this code explodes the first column, col1, into n rows, the second column, col2, into m rows, and the third column, col3, into p rows, so I end up with n*m*p rows (the Cartesian product) instead of only n rows.
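With the toy df from above this is easy to verify; the snippet below is just a compact rewrite of the same chained explodes (the variable names are mine):
from pyspark.sql import functions as F

df_exploded = df
for c in ["col1", "col2", "col3"]:
    # each explode multiplies the current number of rows by the number of elements in c
    df_exploded = df_exploded.withColumn(c, F.explode(F.split(F.col(c), ",")))

df_exploded.groupBy("ID").count().show()
# ID "A": 3 * 3 * 3 = 27 rows, ID "B": 2 * 2 * 2 = 8 rows -- instead of 3 and 2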
I have also tried splitting the dataset into one part per column, exploding each part, dropping duplicates per part, and joining the parts back together afterwards. However, the join is very memory-inefficient and this never actually finished executing:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window.partitionBy("ID").orderBy("ID")

# explode each column on its own and tag every exploded row with a per-ID row number
# appended to the ID, so the parts can be joined back together on that key
part1 = df.select("ID", F.explode(F.split(F.col("col1"), ",")).alias("col1"))\
    .withColumn("ID", F.concat("ID", F.row_number().over(w).cast("string")))\
    .dropDuplicates()
part2 = df.select("ID", F.explode(F.split(F.col("col2"), ",")).alias("col2"))\
    .withColumn("ID", F.concat("ID", F.row_number().over(w).cast("string")))\
    .dropDuplicates()
df = part1.join(part2, part1.ID == part2.ID).select(part1.ID, part1.col1, part2.col2)
I have also come across the following one-liner (for a single column) for the same problem:
from pyspark.sql import functions as F
df = df.select(F.expr("stack(col1)"))
However, the error I receive is a data type mismatch in "stack(col1)", because "stack requires at least 2 arguments."
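From the Spark SQL documentation I gather that stack expects the number of output rows as its first argument, followed by the values to spread across those rows (stack(n, expr1, ..., exprk)), which is why a single argument is rejected. A minimal example with plain literals, assuming spark is the active SparkSession:
spark.sql("SELECT stack(2, 1, 2, 3, 4)").show()
# produces two rows, (1, 2) and (3, 4), in columns col0 and col1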
Is there any way I could achieve my desired output?
Any feedback would be much appreciated. Thanks.

