
I'm working with the PySpark SQL API and trying to group rows that share a value in one column into a single row, collecting the remaining columns into a list. It's similar to a pivot, but instead of spreading the values across new columns, it gathers them into an array.

Current output:

group_id | member_id | name
55       | 123       | jake
55       | 234       | tim 
65       | 345       | chris

Desired output:

group_id | members
55       | [[123, 'jake'], [234, 'tim']]
65       | [[345, 'chris']]
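
A DataFrame matching the input above can be built like this (a quick sketch; it assumes an active SparkSession named spark, and the column types are inferred from the sample values):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Sample data matching the tables above
df = spark.createDataFrame(
    [(55, 123, "jake"), (55, 234, "tim"), (65, 345, "chris")],
    ["group_id", "member_id", "name"]
)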

1 Answer


You need to group by the group_id column and use pyspark.sql.functions.collect_list() as the aggregation function.
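
To see the aggregation part in isolation, collecting a single column per group looks like this (a minimal sketch using the df from the question):

from pyspark.sql.functions import collect_list

# Collect only member_id per group, without the name column
df.groupBy("group_id").agg(collect_list("member_id").alias("member_ids")).show()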

As for combining the member_id and name columns, you have two options:

Option 1: Use pyspark.sql.functions.array:

from pyspark.sql.functions import array, collect_list

df1 = df.groupBy("group_id")\
    .agg(collect_list(array("member_id", "name")).alias("members"))

df1.show(truncate=False)
#+--------+-------------------------------------------------+
#|group_id|members                                          |
#+--------+-------------------------------------------------+
#|55      |[WrappedArray(123, jake), WrappedArray(234, tim)]|
#|65      |[WrappedArray(345, chris)]                       |
#+--------+-------------------------------------------------+

This returns an array of arrays of strings (displayed as WrappedArray in this version's show() output). The integers are converted to strings because a Spark array can't contain mixed types.

df1.printSchema()
#root
# |-- group_id: integer (nullable = true)
# |-- members: array (nullable = true)
# |    |-- element: array (containsNull = true)
# |    |    |-- element: string (containsNull = true)
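
One consequence is that if you need the member_id values back as integers, you have to index into the inner arrays and cast. For example, with a SQL expression (a sketch that assumes Spark 2.4+ for the transform higher-order function):

from pyspark.sql.functions import expr

# Pull the first element of each inner array back out as an integer
df1.select(
    "group_id",
    expr("transform(members, m -> cast(m[0] as int))").alias("member_ids")
).show()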

Option 2: Use pyspark.sql.functions.struct:

from pyspark.sql.functions import collect_list, struct 

df2 = df.groupBy("group_id")\
    .agg(collect_list(struct("member_id", "name")).alias("members"))

df2.show(truncate=False)
#+--------+-----------------------+
#|group_id|members                |
#+--------+-----------------------+
#|65      |[[345,chris]]          |
#|55      |[[123,jake], [234,tim]]|
#+--------+-----------------------+

This returns an array of structs, with named fields for member_id and name.

df2.printSchema()
#root
# |-- group_id: integer (nullable = true)
# |-- members: array (nullable = true)
# |    |-- element: struct (containsNull = true)
# |    |    |-- member_id: integer (nullable = true)
# |    |    |-- name: string (nullable = true)

What's useful about the struct method is that you can access the fields of the nested structs by name using the dot accessor:

df2.select("group_id", "members.member_id").show()
#+--------+----------+
#|group_id| member_id|
#+--------+----------+
#|      65|     [345]|
#|      55|[123, 234]|
#+--------+----------+
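
If you ever need to go back from the aggregated form to one row per member, you can explode the struct array and expand the fields by name (a sketch building on df2 above):

from pyspark.sql.functions import explode

# Flatten back to one row per member, recovering the original columns
df2.select("group_id", explode("members").alias("member"))\
    .select("group_id", "member.member_id", "member.name")\
    .show()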