
I am trying to count the distinct number of entities over different date ranges.

I need to understand how Spark performs this operation.

val distinct_daily_cust_12month = sqlContext.sql(
  s"""select distinct day_id, txn_type, customer_id
      from ${db_name}.fact_customer
      where day_id >= '${start_last_12month}' and day_id <= '${start_date}' and txn_type not in (6, 99)""")

val category_mapping = sqlContext.sql(s"select * from datalake.category_mapping")

val daily_cust_12month_ds = distinct_daily_cust_12month
  .join(broadcast(category_mapping), distinct_daily_cust_12month("txn_type") === category_mapping("id"))
  .select("category", "sub_category", "customer_id", "day_id")

daily_cust_12month_ds.createOrReplaceTempView("daily_cust_12month_ds")

val total_cust_metrics = sqlContext.sql(s"""select 'total' as category,
count(distinct(case when day_id='${start_date}' then customer_id end)) as yest,
count(distinct(case when day_id>='${start_week}' and day_id<='${end_week}' then customer_id end)) as week,
count(distinct(case when day_id>='${start_month}' and day_id<='${start_date}' then customer_id end)) as mtd,
count(distinct(case when day_id>='${start_last_month}' and day_id<='${end_last_month}' then customer_id end)) as ltd,
count(distinct(case when day_id>='${start_last_6month}' and day_id<='${start_date}' then customer_id end)) as lsm,
count(distinct(case when day_id>='${start_last_12month}' and day_id<='${start_date}' then customer_id end)) as ltm
from daily_cust_12month_ds
""")

No errors, but this is taking a lot of time. I want to know if there is a better way to do this in Spark.

2 Answers

3

Count distinct works by hash-partitioning the data, counting the distinct elements within each partition, and finally summing the per-partition counts. In general it is a heavy operation because of the full shuffle, and there is no silver bullet for it in Spark or, most likely, in any fully distributed system: operations involving distinct are inherently difficult in a distributed setting.
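To make this concrete, here is a minimal sketch, assuming a DataFrame df with a single column foo (the same name used in the explain example further down):

import org.apache.spark.sql.functions.countDistinct

// Exact distinct count via the aggregate function.
val exact = df.select(countDistinct("foo")).first().getLong(0)

// Logically the same computation, modulo NULL handling (count(distinct foo) ignores NULLs):
// de-duplicate on the column first, then count. Either way a shuffle keyed on the column
// values is required, which is the Exchange hashpartitioning step in the plan below.
val via_dedup = df.select("foo").na.drop().distinct().count()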

In some cases there are faster ways to do it:

  • If approximate values are acceptable, approx_count_distinct will usually be much faster, as it is based on HyperLogLog and the amount of data to be shuffled is much smaller than with the exact implementation (see the sketch after this list).
  • If you can design your pipeline so that the data source is already partitioned in a way that there can't be any duplicates between partitions, the slow step of hash-partitioning the data frame is not needed.
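As a rough illustration of the first point, an approximate count on the question's daily_cust_12month_ds data frame could look like this (customer_id and the 0.05 tolerance are just example choices):

import org.apache.spark.sql.functions.{approx_count_distinct, col}

// HyperLogLog-based estimate; the second argument is the maximum allowed relative
// standard deviation of the estimate (0.05 is just an example tolerance).
val approx_cust = daily_cust_12month_ds
  .agg(approx_count_distinct(col("customer_id"), 0.05).alias("approx_distinct_customers"))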

P.S. To understand how count distinct works, you can always use explain:

import org.apache.spark.sql.functions.countDistinct
df.select(countDistinct("foo")).explain()

Example output:

== Physical Plan ==
*(3) HashAggregate(keys=[], functions=[count(distinct foo#3)])
+- Exchange SinglePartition
   +- *(2) HashAggregate(keys=[], functions=[partial_count(distinct foo#3)])
      +- *(2) HashAggregate(keys=[foo#3], functions=[])
         +- Exchange hashpartitioning(foo#3, 200)
            +- *(1) HashAggregate(keys=[foo#3], functions=[])
               +- LocalTableScan [foo#3]

1 Comment

Thanks ollik1. I guess I can't tune it further. Also, when I use explain or toDebugString, the plan is often displayed truncated with "...". Is there a way I can view the complete plan?
0

I recommend using

f.size(f.collect_set(f.col('test'))).alias('n_distinct_values_in_test')

In place of

f.countDistinct()

This is much faster.
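For completeness, a rough Scala equivalent of the same idea, assuming a DataFrame df with a column test:

import org.apache.spark.sql.functions.{col, collect_set, size}

// collect_set gathers the distinct values of the column into an array (one array for the
// whole frame when there is no groupBy), and size returns that array's length.
val n_distinct = df.agg(size(collect_set(col("test"))).alias("n_distinct_values_in_test"))

Keep in mind that collect_set materializes all distinct values in memory, so this approach is most comfortable when the column's cardinality is moderate.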
