
Suppose I have a dataframe in Spark as shown below -

val df = Seq(
(0,0,0,0.0),
(1,0,0,0.1),
(0,1,0,0.11),
(0,0,1,0.12),
(1,1,0,0.24),
(1,0,1,0.27),
(0,1,1,0.30),
(1,1,1,0.40)
).toDF("A","B","C","rate")

Here is how it looks -

scala> df.show()
+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  0|  0|  0| 0.0|
|  1|  0|  0| 0.1|
|  0|  1|  0|0.11|
|  0|  0|  1|0.12|
|  1|  1|  0|0.24|
|  1|  0|  1|0.27|
|  0|  1|  1| 0.3|
|  1|  1|  1| 0.4|
+---+---+---+----+

A, B and C are the advertising channels in this case. 0 and 1 represent the absence and presence of a channel, respectively. 2^3 = 8 gives the 8 combinations in the dataframe.

I want to filter the records from this dataframe that show the presence of exactly 2 channels at a time (AB, AC, BC). Here is how I want my output to be -

+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  1|  1|  0|0.24|
|  1|  0|  1|0.27|
|  0|  1|  1| 0.3|
+---+---+---+----+

I can get this output by writing 3 separate statements -

scala> df.filter($"A" === 1 && $"B" === 1 && $"C" === 0).show()
+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  1|  1|  0|0.24|
+---+---+---+----+


scala> df.filter($"A" === 1 && $"B" === 0 && $"C" === 1).show()
+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  1|  0|  1|0.27|
+---+---+---+----+


scala> df.filter($"A" === 0 && $"B" === 1 && $"C" === 1).show()
+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  0|  1|  1| 0.3|
+---+---+---+----+

However, I want to achieve this with either a single statement or a function that produces the output. I was thinking of using a case statement to match the values; however, in general my dataframe might consist of more than 3 channels -

scala> df.show()
+---+---+---+---+----+
|  A|  B|  C|  D|rate|
+---+---+---+---+----+
|  0|  0|  0|  0| 0.0|
|  0|  0|  0|  1| 0.1|
|  0|  0|  1|  0| 0.1|
|  0|  0|  1|  1|0.59|
|  0|  1|  0|  0| 0.1|
|  0|  1|  0|  1|0.89|
|  0|  1|  1|  0|0.39|
|  0|  1|  1|  1| 0.4|
|  1|  0|  0|  0| 0.0|
|  1|  0|  0|  1|0.99|
|  1|  0|  1|  0|0.49|
|  1|  0|  1|  1| 0.1|
|  1|  1|  0|  0|0.79|
|  1|  1|  0|  1| 0.1|
|  1|  1|  1|  0| 0.1|
|  1|  1|  1|  1| 0.1|
+---+---+---+---+----+

In this scenario I would want my output to be -
+---+---+---+---+----+
|  A|  B|  C|  D|rate|
+---+---+---+---+----+
|  0|  0|  1|  1|0.59|
|  0|  1|  0|  1|0.89|
|  0|  1|  1|  0|0.39|
|  1|  0|  0|  1|0.99|
|  1|  0|  1|  0|0.49|
|  1|  1|  0|  0|0.79|
+---+---+---+---+----+

which shows the rates for each pair of channels (AB, AC, AD, BC, BD, CD).

Kindly help.

1 Answer


One way is to sum the channel columns and keep only the rows where the sum equals 2.

import org.apache.spark.sql.functions._

df.withColumn("res", $"A" + $"B" + $"C")
  .filter($"res" === lit(2))
  .drop("res")
  .show()

The output is:

+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  1|  1|  0|0.24|
|  1|  0|  1|0.27|
|  0|  1|  1| 0.3|
+---+---+---+----+
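If the number of channel columns isn't fixed, the same idea generalizes by summing the indicator columns dynamically. Here is a sketch, assuming (as in the question) that `df` is already defined and every column except "rate" is a 0/1 channel flag:

```scala
import org.apache.spark.sql.functions.col

// Collect every channel column (everything except "rate").
val channelCols = df.columns.filter(_ != "rate")

// Build the expression A + B + C + ... and keep rows where exactly
// two channels are present.
val presence = channelCols.map(col).reduce(_ + _)
df.filter(presence === 2).show()
```

This works unchanged for the 4-channel (A, B, C, D) dataframe, and changing `=== 2` to `=== 3` would select triples of channels instead.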