I have a DataFrame df that has, among others, a groupID column; that is, each observation belongs to a specific group. In total there are 8 groups. I would like to sample a certain percentage of observations (say, 20%) from each groupID. Here is my approach:

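// note: Array.range(0, 7) produces 0 to 6; the end is exclusive
val sample_df = for (i <- Array.range(0, 7)) yield {
  // keep only the rows of group i
  val sel_df = df.filter($"groupID" === i)
  // sample about 20% of them, without replacement
  sel_df.sample(false, 0.2, seed1)
}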

The result of this code is:

Array[org.apache.spark.sql.DataFrame] = Array([text: string, groupID: int], [text: string, groupID: int])

I applied flatMap() on sample_df, but I got an error:

val flat_df = sample_df.flatMap(x => x)
         <console>:59: error: type mismatch;
         found: org.apache.spark.sql.DataFrame
         required: scala.collection.GenTraversableOnce[?]

How can I get a sampled dataframe?

3 Answers

2

As far as I understand, you are trying to get an RDD of Row. For that, you can simply call:

val rows: RDD[Row] = sample_df.rdd

To explain the error: the function passed to flatMap must return something traversable, such as a collection or an Option, but x => x returns a DataFrame, which is not a Scala collection.
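To make the type requirement concrete, here is a small sketch in plain Scala (no Spark needed; the values are made up for illustration). flatMap accepts x => x only when each element is itself a collection or an Option:

```scala
// Each element is a List, so the identity function returns something traversable.
val nested: List[List[Int]] = List(List(1, 2), List(3))
val flat: List[Int] = nested.flatMap(x => x)

// Option also qualifies: Some contributes its value, None contributes nothing.
val maybe: List[Option[Int]] = List(Some(1), None, Some(3))
val present: List[Int] = maybe.flatMap(x => x)

println(flat)    // List(1, 2, 3)
println(present) // List(1, 3)
```

A DataFrame is neither of these, which is why calling flatMap(x => x) on an Array of DataFrames does not compile.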

Also, to get all data to the driver, you can call:

val rows: Array[Row] = sample_df.collect

2 Comments

Hi, thanks for the reply. Unfortunately, sample_df is an Array of DataFrames (org.apache.spark.sql.DataFrame), and the .rdd method does not work on it. What I need is to flatten this array into a single DataFrame. That is why I applied flatMap in the first place.
Right, sorry about that. Then Rockie Yang's answer is the correct one.
1

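I guess you want to sample evenly from each group. If so, you can merge the per-group samples into a single DataFrame with unionAll (in Spark 2.0 and later, union does the same thing):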

sample_df.reduceLeft((result, df) => result.unionAll(df))
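For completeness, here is a minimal end-to-end sketch of this approach, written for Spark 2.0 or later (hence union rather than unionAll). The local SparkSession, the toy data, the group IDs 0 to 7, and the seed value are all assumptions made for illustration, not part of the original question:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

// Assumption: a local session for experimentation; in spark-shell, `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("group-sampling").getOrCreate()
import spark.implicits._

// Hypothetical toy data: 8 groups (IDs 0 to 7) with 100 rows each.
val df: DataFrame = (0 until 800).map(i => (s"text$i", i % 8)).toDF("text", "groupID")
val seed = 42L // arbitrary seed, chosen only for reproducibility

// Sample roughly 20% of each group, without replacement.
val samples: Seq[DataFrame] = (0 until 8).map { i =>
  df.filter($"groupID" === i).sample(false, 0.2, seed)
}

// Fold the per-group samples into a single DataFrame.
val sampled: DataFrame = samples.reduce(_ union _)

// Show how many rows were drawn from each group.
sampled.groupBy("groupID").count().orderBy("groupID").show()
```

Note that sample draws each row independently, so the per-group counts will be close to 20% but usually not exactly 20%.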

0

It seems to me you just want to take a 20% sample of the entire dataframe? If so, then there is no reason to create 8 different dataframes and then union them back.

df.sample(false, 0.2, seed)

will do the trick. If you want a different fraction for each groupID, check out df.stat.sampleBy. If you want to be sure that the sample contains exactly 20% of each group, you'll have to convert to a pair RDD and use exact stratified sampling, like:

df.rdd.map(row => (row(groupIDIndex), row)).sampleByKeyExact(false, Map(0 -> 0.2, 1 -> 0.2, ..., 8 -> 0.2), seed)
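Both options can be sketched as follows. This is only an illustration under assumptions: the local SparkSession, the toy data with group IDs 0 to 7, the seed, and the step that rebuilds a DataFrame from the sampled RDD are not from the original answer. The key is read with getAs[Int] so that it has the same type as the keys of the fractions map:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

// Assumption: a local session for experimentation; in spark-shell, `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("stratified-sampling").getOrCreate()
import spark.implicits._

// Hypothetical toy data: 8 groups (IDs 0 to 7) with 100 rows each.
val df: DataFrame = (0 until 800).map(i => (s"text$i", i % 8)).toDF("text", "groupID")
val seed = 42L // arbitrary seed, chosen only for reproducibility

// One sampling fraction per group ID; the fractions may differ between groups.
val fractions: Map[Int, Double] = (0 until 8).map(id => id -> 0.2).toMap

// Approximate stratified sample that stays in the DataFrame API.
val approx: DataFrame = df.stat.sampleBy("groupID", fractions, seed)

// Exact stratified sample: key each row by its group ID, sample per key,
// then drop the keys and rebuild a DataFrame with the original schema.
val keyed: RDD[(Int, Row)] = df.rdd.map(row => (row.getAs[Int]("groupID"), row))
val exactRows: RDD[Row] = keyed.sampleByKeyExact(false, fractions, seed).values
val exact: DataFrame = spark.createDataFrame(exactRows, df.schema)

// Show how many rows were drawn from each group.
exact.groupBy("groupID").count().orderBy("groupID").show()
```

sampleByKeyExact needs extra passes over the data to guarantee the per-group sample size, so sampleBy is likely the cheaper choice when an approximate 20% per group is good enough.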
