I have a Spark dataframe in Scala, created as below:
val df = Seq(
(0,0,0,0.0,0),
(1,0,0,0.1,1),
(0,1,0,0.11,1),
(0,0,1,0.12,1),
(1,1,0,0.24,2),
(1,0,1,0.27,2),
(0,1,1,0.3,2),
(1,1,1,0.4,3)
).toDF("A","B","C","rate","total")
Here is what it looks like:
scala> df.show
+---+---+---+----+-----+
| A| B| C|rate|total|
+---+---+---+----+-----+
| 0| 0| 0| 0.0| 0|
| 1| 0| 0| 0.1| 1|
| 0| 1| 0|0.11| 1|
| 0| 0| 1|0.12| 1|
| 1| 1| 0|0.24| 2|
| 1| 0| 1|0.27| 2|
| 0| 1| 1| 0.3| 2|
| 1| 1| 1| 0.4| 3|
+---+---+---+----+-----+
A, B and C are channels in this case; 0 and 1 represent the absence and presence of a channel, respectively. The 2^3 = 8 possible combinations appear as the rows of the dataframe, and the column 'total' gives the row-wise sum of the three channel indicators.
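(Equivalently, 'total' could be derived from the three indicator columns; a minimal sketch against the df above:)

import org.apache.spark.sql.functions.col

// Sketch: recompute 'total' as the row-wise sum of the channel indicators;
// this reproduces the 'total' column shown above.
val withTotal = df.withColumn("total", col("A") + col("B") + col("C"))
withTotal.show()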
The individual probability of each channel occurring on its own is given by the rows where total equals 1:
scala> val oneChannelCase = df.filter($"total" === 1).toDF()
scala> oneChannelCase.show()
+---+---+---+----+-----+
| A| B| C|rate|total|
+---+---+---+----+-----+
| 1| 0| 0| 0.1| 1|
| 0| 1| 0|0.11| 1|
| 0| 0| 1|0.12| 1|
+---+---+---+----+-----+
However, I am interested only in the pair-wise probabilities of these channels, which are given by the rows where total equals 2:
scala> val probs = df.filter($"total" === 2).toDF()
scala> probs.show()
+---+---+---+----+-----+
| A| B| C|rate|total|
+---+---+---+----+-----+
| 1| 1| 0|0.24| 2|
| 1| 0| 1|0.27| 2|
| 0| 1| 1| 0.3| 2|
+---+---+---+----+-----+
What I would like to do is append 3 new columns to this "probs" dataframe that show the individual probabilities. Below is the output I am looking for:
A  B  C  rate  prob_A  prob_B  prob_C
1  1  0  0.24  0.1     0.11    0
1  0  1  0.27  0.1     0       0.12
0  1  1  0.3   0       0.11    0.12
To make things clearer: the first row of the output has A=1, B=1, C=0, so the individual probabilities A=0.1, B=0.11 and C=0 are appended to the probs dataframe. Similarly, the second row has A=1, B=0, C=1, so the individual probabilities A=0.1, B=0 and C=0.12 are appended.
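In other words, the value to append for a channel is that channel's rate from oneChannelCase. As a sketch of the lookup I have in mind (the name rateMap is just illustrative; with only 3 channels, collecting this tiny lookup table to the driver seems harmless):

import org.apache.spark.sql.functions.col

// Illustrative sketch: for each channel, pick the rate of the row where
// only that channel is present (total == 1).
val rateMap: Map[String, Double] = Seq("A", "B", "C").map { c =>
  c -> oneChannelCase.filter(col(c) === 1).select("rate").head.getDouble(0)
}.toMap
// rateMap: Map(A -> 0.1, B -> 0.11, C -> 0.12)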
Here is what I have tried (registering df as a temp view so it is visible to spark.sql):
scala> val channels = df.columns.filter(v => !(v.contains("rate") || v.contains("total")))
channels: Array[String] = Array(A, B, C)
scala> val pivotedProb = channels.map(v => f"case when $v = 1 then rate else 0 end as prob_${v}")
scala> val param = pivotedProb.mkString(",")
scala> df.createOrReplaceTempView("df")
scala> val probs = spark.sql(f"select *, $param from df")
scala> probs.show()
+---+---+---+----+-----+------+------+------+
| A| B| C|rate|total|prob_A|prob_B|prob_C|
+---+---+---+----+-----+------+------+------+
| 0| 0| 0| 0.0| 0| 0.0| 0.0| 0.0|
| 1| 0| 0| 0.1| 1| 0.1| 0.0| 0.0|
| 0| 1| 0|0.11| 1| 0.0| 0.11| 0.0|
| 0| 0| 1|0.12| 1| 0.0| 0.0| 0.12|
| 1| 1| 0|0.24| 2| 0.24| 0.24| 0.0|
| 1| 0| 1|0.27| 2| 0.27| 0.0| 0.27|
| 0| 1| 1| 0.3| 2| 0.0| 0.3| 0.3|
| 1| 1| 1| 0.4| 3| 0.4| 0.4| 0.4|
+---+---+---+----+-----+------+------+------+
This gives the wrong output: in the pair-wise rows, the row's own combined rate (e.g. 0.24) is copied into the prob_ column of every present channel, instead of that channel's individual rate (0.1, 0.11, 0.12).
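I suspect I need to look up each channel's individual rate rather than reuse the row's own rate; something along these lines, using the rateMap sketched above (a rough sketch, and I am not sure it is the idiomatic way):

import org.apache.spark.sql.functions.{col, lit, when}

// Sketch: append each channel's *individual* rate when that channel is
// present in the pair-wise row, else 0.0.
val withProbs = Seq("A", "B", "C").foldLeft(df.filter(col("total") === 2)) {
  (acc, c) =>
    acc.withColumn(s"prob_$c", when(col(c) === 1, lit(rateMap(c))).otherwise(lit(0.0)))
}
withProbs.show()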
Kindly help.