Defining a custom expression containing a conditional


I would like to define an expression for computing the sum of a polars.DataFrame column. The desired sum operation has the property that it's the usual polars sum operation, except in the case that there are only null elements in the column. In that case I want the sum to be null also.

This can be achieved using a conditional, e.g.

import polars as pl

frame = pl.from_dict({
    "A": [None, None, None],
    "B": [0.0, None, None]})
result = frame.select([
    # Sum the column only if it holds at least one non-null value
    pl.when(pl.col(col).is_not_null().sum() > 0)
      .then(pl.col(col).sum())
      .otherwise(None)
    for col in frame.columns
])
print(result)
>>> shape: (1, 2)
┌─────┬─────┐
│ A   ┆ B   │
│ --- ┆ --- │
│ f32 ┆ f64 │
╞═════╪═════╡
│ null┆ 0.0 │
└─────┴─────┘

The conditional expression is very long. Ideally I'd like to define something (a function?) called mysum so that I can get the same result by writing

result = frame.select(pl.all()).mysum()

OR

result = frame.select(pl.all()).map(mysum)

or similar. I simply would like to hide (encapsulate) the conditional statement, possibly passing a min_count argument (like in pandas.DataFrame.sum), so that I can use it from within the select, with_columns, and groupby contexts.

Unfortunately I am just getting started with polars and I am not sure what's the best/canonical way to achieve this.

I am also struggling to understand the difference between the following four statements (which all compute the column wise sum). I have indicated my best-guess/preliminary understanding as inline comments, but would be very glad if this could be unraveled for me.

# i) Method call on `polars.DataFrame`
frame.sum()
# ii) Select context with list of expressions? each computing the sum of a single column
frame.select([pl.col("A").sum(), pl.col("B").sum()])
# iii) Select context with a list of two `pl.col` expressions?, method call on the resulting `polars.DataFrame`?
frame.select([pl.col("A"), pl.col("B")]).sum()
# iv) Select context with a `pl.col` expression? applied to a list of columns chained with a (single?) sum expression?
frame.select(pl.col(["A", "B"]).sum())

Finally I'd like to understand whether my 'custom' sum (expression?) can be used with any/all of the four above statements.

Solution

Thanks to all the very helpful explanations, I ended up defining a custom expression which I applied using the polars.Expr.pipe method, cf. below


def nullsum(expr: pl.Expr, min_count: int = 0) -> pl.Expr:
    if min_count > 0:
        return pl.when(expr.is_not_null().sum() >= pl.lit(min_count)).then(expr.sum())
    else:
        return expr.sum()

result = frame.select(pl.col("A", "B").pipe(nullsum, min_count=1))
print(result)
>>> shape: (1, 2)
┌─────┬─────┐
│ A   ┆ B   │
│ --- ┆ --- │
│ f32 ┆ f64 │
╞═════╪═════╡
│ null┆ 0.0 │
└─────┴─────┘

There is 1 answer

Dean MacGregor On

How to achieve the ability to do result = frame.select(pl.all()).mysum()

The good thing here is that you can do even better than that because:

  1. .select(pl.all()) is completely redundant.

  2. .otherwise(None) is also redundant.

Let's take what you wrote (with some little tweaks) and put it in a function:

def mysum(frame):
    # Sum each column, yielding null when the column has no non-null values
    return frame.select(
        pl.when(pl.col(col).is_not_null().any())
        .then(pl.col(col).sum())
        for col in frame.columns
    )

Now you can do mysum(frame), which is close to what you wanted. To be able to write frame.mysum(), you need to attach your function to the pl.DataFrame class (monkey-patching) like this:

pl.DataFrame.mysum=mysum

With that you can now do

frame.mysum()
shape: (1, 2)
┌──────┬─────┐
│ A    ┆ B   │
│ ---  ┆ --- │
│ f32  ┆ f64 │
╞══════╪═════╡
│ null ┆ 0.0 │
└──────┴─────┘

To get the min_count feature

def mysum(frame, min_count=1):
    # Sum each column, yielding null when fewer than min_count values are non-null
    return frame.select(
        pl.when(pl.col(x).is_not_null().sum() >= pl.lit(min_count))
        .then(pl.col(x).sum())
        for x in frame.columns
    )

pl.DataFrame.mysum = mysum

We gave min_count a default value of 1 so you can still do frame.mysum() but you can also do frame.mysum(2) (or however many you want the min_count to be).

Don't let the parameter name frame mislead you into thinking the function only works with this particular frame. It works on any DataFrame, although after restarting the kernel you need to define the function and assign it to pl.DataFrame again. For instance you can do:


df = pl.DataFrame({
    'a': [1,2,3,None, None, 10],
    'b': [1,2,3,4, None, 10],
    'c': [None,2,None,None, None, 10]
})
df.mysum()
shape: (1, 3)
┌─────┬─────┬─────┐
│ a   ┆ b   ┆ c   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 16  ┆ 20  ┆ 12  │
└─────┴─────┴─────┘
df.mysum(5)
shape: (1, 3)
┌──────┬─────┬──────┐
│ a    ┆ b   ┆ c    │
│ ---  ┆ --- ┆ ---  │
│ i64  ┆ i64 ┆ i64  │
╞══════╪═════╪══════╡
│ null ┆ 20  ┆ null │
└──────┴─────┴──────┘

and so on.

Distinguishing between your cases

# i) Method call on `polars.DataFrame`
frame.sum()
# ii) Select context with list of expressions? each computing the sum of a single column
frame.select([pl.col("A").sum(), pl.col("B").sum()])
# iii) Select context with a list of two `pl.col` expressions?, method call on the resulting `polars.DataFrame`?
frame.select([pl.col("A"), pl.col("B")]).sum()
# iv) Select context with a `pl.col` expression? applied to a list of columns chained with a (single?) sum expression?
frame.select(pl.col(["A", "B"]).sum())

From here, i and iii are essentially the same. In iii when you do frame.select([pl.col("A"), pl.col("B")]), the .select([pl.col("A"), pl.col("B")]) is redundant because the frame only has those two columns and you're not doing anything with them except returning them. In both cases you're dispatching the DataFrame.sum() method.

ii and iv are also the same. Polars is flexible in how you apply methods to expressions. When you want to perform the same operation on multiple columns you can write, as you've done in iv, pl.col('A','B').sum(). Other than perhaps readability, there's no benefit to writing pl.col("A").sum(), pl.col("B").sum(). Also, polars accepts *args and **kwargs almost everywhere, so you rarely need to pass explicit lists. Instead of

frame.select([pl.col("A").sum(), pl.col("B").sum()])

you can just write

frame.select(pl.col("A").sum(), pl.col("B").sum())

Same for

frame.select(pl.col(["A", "B"]).sum())

it can just be written as

frame.select(pl.col("A", "B").sum())