
I have two DataFrames and I'm using collect_set() in agg() after a groupby(). What's the best way to flatten the resulting nested arrays after aggregating?

from pyspark.sql.functions import collect_set

schema = ['col1', 'col2', 'col3', 'col4']

a = [[1, [23, 32], [11, 22], [9989]]]

df1 = spark.createDataFrame(a, schema=schema)

b = [[1, [34], [43, 22], [888, 777]]]

df2 = spark.createDataFrame(b, schema=schema)

df = df1.union(
        df2
    ).groupby(
        'col1'
    ).agg(
        collect_set('col2').alias('col2'),
        collect_set('col3').alias('col3'),
        collect_set('col4').alias('col4')
    )

df.collect()

I'm getting this as output:

[Row(col1=1, col2=[[34], [23, 32]], col3=[[11, 22], [43, 22]], col4=[[9989], [888, 777]])]

But I want this as output:

[Row(col1=1, col2=[23, 32, 34], col3=[11, 22, 43], col4=[9989, 888, 777])]

2 Answers


You can use udf:

from itertools import chain
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType

# Flatten the nested arrays and deduplicate, to match the desired output
flatten = udf(lambda x: list(set(chain.from_iterable(x))), ArrayType(IntegerType()))

df.withColumn('col2_flat', flatten('col2'))
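
To produce the full desired row, the same UDF can be applied to each aggregated column in turn. A minimal sketch, assuming the df from the question:

for c in ['col2', 'col3', 'col4']:
    df = df.withColumn(c, flatten(c))

df.collect()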



Without a UDF, this should work (Spark 2.4+, where flatten and array_distinct are available):

from pyspark.sql.functions import array_distinct, flatten

df.withColumn('col2_flat', array_distinct(flatten('col2')))

It flattens the nested arrays and then removes duplicates.
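
Applied to all three aggregated columns at once, a minimal sketch (again assuming the df from the question) could be:

from pyspark.sql.functions import array_distinct, flatten

df.select(
    'col1',
    *[array_distinct(flatten(c)).alias(c) for c in ['col2', 'col3', 'col4']]
).collect()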

