2023-03-29

PySpark add rank column to large dataset

I have a large dataframe and I want to compute a metric based on the rank of one of the columns. This metric really only depends on two columns from the dataframe, so I first select the two columns I care about, then compute the metric. Once the two relevant columns are selected, the dataframe looks something like this:

score     | truth
-----------------
0.7543    | 0
0.2144    | 0
0.5698    | 1
0.9221    | 1
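
In the real job df has many more columns and billions of rows, but a toy stand-in with the same two columns (using the values from the table above) could be built like so:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [(0.7543, 0), (0.2144, 0), (0.5698, 1), (0.9221, 1)],
    ["score", "truth"],
)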

The analytic we want to calculate is called "average percent rank", computed over the rows where truth == 1. The process is to compute the percent rank for every data point, then select the rows where truth == 1, and finally average the percent ranks of those rows. (For the small example above, the percent ranks in score order are 0, 1/3, 2/3, 1, so the average over the truth == 1 rows would be (1/3 + 1)/2 = 2/3.) However, when we try to compute this, we get OOM errors. One of the issues is that the pyspark.sql function rank requires a Window, and we need that window to span the entire dataframe (the same goes for percent_rank). Some code:

from pyspark.sql import Window, functions as F

# Unpartitioned window: every row is ranked against the entire dataframe.
w = Window.orderBy(F.col("score"))

avg_percent_rank = (
    df
    .select("score", "truth")
    .withColumn("percent_rank", F.percent_rank().over(w))
    .filter(F.col("truth") == 1)
    .agg(F.mean(F.col("percent_rank")))
)

This results in an OOM error. There are over 6 billion records, and we need this to work for datasets that may be a hundred times larger. Ultimately, the critical operation is sorting and assigning each row its index in the sorted order; percent_rank can then be derived from that index by dividing by the total number of rows.
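
For example, here is a sketch of the kind of thing I have in mind, using a distributed sort plus zipWithIndex on the RDD instead of an unpartitioned window. I have not verified it at this scale, and it assumes there are no tied scores (Spark's percent_rank would give tied rows equal values, which this does not reproduce):

pairs = df.select("score", "truth").rdd.map(lambda r: (r["score"], r["truth"]))

# Distributed sort by score, then attach a 0-based global index to each row.
indexed = pairs.sortByKey().zipWithIndex()   # ((score, truth), index)

n = indexed.count()

# With no ties, index / (n - 1) matches percent_rank's (rank - 1) / (n - 1).
avg_percent_rank = (
    indexed
    .filter(lambda x: x[0][1] == 1)        # keep only truth == 1 rows
    .map(lambda x: x[1] / (n - 1))         # percent rank from the sort index
    .mean()
)

But I'm not sure how well sortByKey and zipWithIndex behave at tens or hundreds of billions of rows.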

Is there a better approach to computing rank than using a Window function?


