
I have a dataframe (Spark):

id  value 
3     0
3     1
3     0
4     1
4     0
4     0

I want to create a new dataframe:

3 0
3 1
4 1

I need to remove, for each id, all the rows after the first row whose value is 1. I tried window functions on the Spark DataFrame (Scala) but couldn't find a solution. It seems like I am going in the wrong direction.

I am looking for a solution in Scala.

Update: output using monotonically_increasing_id():

scala> val data = Seq((3,0),(3,1),(3,0),(4,1),(4,0),(4,0)).toDF("id", "value")
data: org.apache.spark.sql.DataFrame = [id: int, value: int]

scala> val dataWithIndex = data.withColumn("idx", monotonically_increasing_id())
dataWithIndex: org.apache.spark.sql.DataFrame = [id: int, value: int, idx: bigint]

scala> val minIdx = dataWithIndex.filter($"value" === 1).groupBy($"id").agg(min($"idx")).toDF("r_id", "min_idx")
minIdx: org.apache.spark.sql.DataFrame = [r_id: int, min_idx: bigint]

scala> dataWithIndex.join(minIdx,($"r_id" === $"id") && ($"idx" <= $"min_idx")).select($"id", $"value").show
+---+-----+
| id|value|
+---+-----+
|  3|    0|
|  3|    1|
|  4|    1|
+---+-----+

The solution won't work if we apply a sort transformation to the original dataframe: in that case monotonically_increasing_id() is generated based on the original DF rather than the sorted DF. I had missed that requirement before.
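
One workaround sketch (my own assumption, not verified for every case): do the sort first, generate the index on the already-sorted dataframe, and cache it so the ids are frozen. Since orderBy produces range-ordered partitions, monotonically_increasing_id() should then follow the sorted order:

// Hypothetical: sortCol stands for whatever column defines the desired order
val sortedData = data.sort($"sortCol")
val sortedWithIndex = sortedData
  .withColumn("idx", monotonically_increasing_id())
  .cache() // freeze the generated ids so re-evaluation cannot change them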

All suggestions are welcome.

  • And what did you try so far? Commented Apr 2, 2016 at 15:50
  • @eliasah I tried some experiments based on the answer here: stackoverflow.com/questions/32148208/…, but no success so far. Commented Apr 2, 2016 at 16:09
  • Is your DF sorted? Commented Apr 2, 2016 at 20:31
  • @TheArchetypalPaul yes, it is sorted Commented Apr 3, 2016 at 3:13
  • It's because you are calling show each time. The evaluations are lazy in my code below -- the original val dataWithIndex is only evaluated once my final show is called. But you call show each time, forcing the re-evaluation. Stop calling show, or immediately call cache after creating dataWithIndex. Commented Apr 3, 2016 at 4:11

4 Answers


One way is to use monotonically_increasing_id() and a self-join:

val data = Seq((3,0),(3,1),(3,0),(4,1),(4,0),(4,0)).toDF("id", "value")
data.show
+---+-----+
| id|value|
+---+-----+
|  3|    0|
|  3|    1|
|  3|    0|
|  4|    1|
|  4|    0|
|  4|    0|
+---+-----+

Now we generate a column named idx with an increasing Long:

val dataWithIndex = data.withColumn("idx", monotonically_increasing_id())
// dataWithIndex.cache()

Now we get the min(idx) for each id where value = 1:

val minIdx = dataWithIndex
               .filter($"value" === 1)
               .groupBy($"id")
               .agg(min($"idx"))
               .toDF("r_id", "min_idx")

Now we join the min(idx) back to the original DataFrame:

dataWithIndex.join(
  minIdx,
  ($"r_id" === $"id") && ($"idx" <= $"min_idx")
).select($"id", $"value").show
+---+-----+
| id|value|
+---+-----+
|  3|    0|
|  3|    1|
|  4|    1|
+---+-----+

Note: monotonically_increasing_id() generates its value based on the partition of the row. This value may change each time dataWithIndex is re-evaluated. In my code above, because of lazy evaluation, it's only when I call the final show that monotonically_increasing_id() is evaluated.

If you want to force the value to stay the same, for example so you can use show to evaluate the above step-by-step, uncomment this line above:

//  dataWithIndex.cache()
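
As a quick illustration (a sketch; the exact idx values depend on the partitioning), caching pins the generated ids down so repeated actions agree with each other:

val dataWithIndex = data.withColumn("idx", monotonically_increasing_id())
dataWithIndex.cache()  // materialize idx once; later actions reuse the cached values
dataWithIndex.show()   // first action computes and caches idx
dataWithIndex.show()   // shows the same idx values as the first show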

4 Comments

Yeah, don't look too deeply at the column generated by monotonically_increasing_id() -- you may get different values every time you look at it -- the numbers you see are based on the partitioning scheme. Just run the code, don't look at the intermediate values. It works.
If, for sanity's sake, you want to see the same values every time -- add the line dataWithIndex.cache(). But it doesn't change the overall results -- it just makes it so you can look at each step under the microscope and not feel like you are going crazy.
Thanks @DavidGriffin. I didn't get the output properly, that's why I checked the intermediate results. I have updated the output in the question itself. Could you please take a look?
Answered. It's because each time you call show it forces a re-evaluation. It's like quantum mechanics -- you are changing the value by observing it. If you run the code just like I have it -- with only the last show -- it gets the proper result.

Hi, I found a solution using Window and a self-join.

val data = Seq((3,0,2),(3,1,3),(3,0,1),(4,1,6),(4,0,5),(4,0,4),(1,0,7),(1,1,8),(1,0,9),(2,1,10),(2,0,11),(2,0,12)).toDF("id", "value","sorted")

data.show
+---+-----+------+
| id|value|sorted|
+---+-----+------+
|  3|    0|     2|
|  3|    1|     3|
|  3|    0|     1|
|  4|    1|     6|
|  4|    0|     5|
|  4|    0|     4|
|  1|    0|     7|
|  1|    1|     8|
|  1|    0|     9|
|  2|    1|    10|
|  2|    0|    11|
|  2|    0|    12|
+---+-----+------+

val sort_df = data.sort($"sorted")

sort_df.show
+---+-----+------+
| id|value|sorted|
+---+-----+------+
|  3|    0|     1|
|  3|    0|     2|
|  3|    1|     3|
|  4|    0|     4|
|  4|    0|     5|
|  4|    1|     6|
|  1|    0|     7|
|  1|    1|     8|
|  1|    0|     9|
|  2|    1|    10|
|  2|    0|    11|
|  2|    0|    12|
+---+-----+------+

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val window = Window.partitionBy("id").orderBy($"sorted")

val sort_idx = sort_df.select($"*", rowNumber.over(window).as("count_index")) // rowNumber is row_number in Spark 2.x

val minIdx = sort_idx.filter($"value" === 1).groupBy("id").agg(min("count_index")).toDF("idx", "min_idx")

val result_id = sort_idx.join(minIdx, ($"id" === $"idx") && ($"count_index" <= $"min_idx"))

result_id.show

+---+-----+------+-----------+---+-------+
| id|value|sorted|count_index|idx|min_idx|
+---+-----+------+-----------+---+-------+
|  1|    0|     7|          1|  1|      2|
|  1|    1|     8|          2|  1|      2|
|  2|    1|    10|          1|  2|      1|
|  3|    0|     1|          1|  3|      3|
|  3|    0|     2|          2|  3|      3|
|  3|    1|     3|          3|  3|      3|
|  4|    0|     4|          1|  4|      3|
|  4|    0|     5|          2|  4|      3|
|  4|    1|     6|          3|  4|      3|
+---+-----+------+-----------+---+-------+

Still looking for a more optimized solution. Thanks.
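
For what it's worth, a more compact variant is possible with a single window and no self-join: count the 1s strictly before each row (in sort order, per id) and keep only the rows that have none, which retains everything up to and including the first 1. A sketch, assuming Spark 2.x for Window.unboundedPreceding:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Frame over the rows of the same id strictly before the current row
val before = Window.partitionBy("id").orderBy("sorted")
  .rowsBetween(Window.unboundedPreceding, -1)

val result = sort_df
  .withColumn("ones_before", sum(when($"value" === 1, 1).otherwise(0)).over(before))
  .filter($"ones_before".isNull || $"ones_before" === 0) // null means first row of the id
  .drop("ones_before")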



You can simply use groupBy like this:

val df2 = df1.groupBy("id","value").count().select("id","value")

Here your df1 is:

id  value 
3     0
3     1
3     0
4     1
4     0
4     0

And the resulting dataframe is df2, which keeps one row per distinct (id, value) pair:

id  value 
3     0
3     1
4     1
4     0
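
Equivalently (just a sketch, with the same semantics as the groupBy above), dropDuplicates keeps one row per distinct pair:

val df2 = df1.dropDuplicates(Seq("id", "value"))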


Use the isin method and filter as below:

val data = Seq((3,0,2),(3,1,3),(3,0,1),(4,1,6),(4,0,5),(4,0,4),(1,0,7),(1,1,8),(1,0,9),(2,1,10),(2,0,11),(2,0,12)).toDF("id", "value","sorted")
val idFilter = List(1, 2)
data.filter($"id".isin(idFilter:_*)).show
+---+-----+------+
| id|value|sorted|
+---+-----+------+
|  1|    0|     7|
|  1|    1|     8|
|  1|    0|     9|
|  2|    1|    10|
|  2|    0|    11|
|  2|    0|    12|
+---+-----+------+

Example: filter based on value:
val valFilter = List(0)
data.filter($"value".isin(valFilter:_*)).show
+---+-----+------+
| id|value|sorted|
+---+-----+------+
|  3|    0|     2|
|  3|    0|     1|
|  4|    0|     5|
|  4|    0|     4|
|  1|    0|     7|
|  1|    0|     9|
|  2|    0|    11|
|  2|    0|    12|
+---+-----+------+
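
To keep everything except the listed values, the same condition can be negated (a sketch using the ! operator on Column):

val valFilter = List(0)
data.filter(!$"value".isin(valFilter:_*)).show // rows whose value is not in the list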

