Custom `polars` expression involving multiple columns and filtering

Consider the following function, which acts on a `polars.DataFrame`:

import polars as pl


def f(frame: pl.DataFrame) -> pl.DataFrame:
    """Select columns 'A' and 'B' from `frame`, keeping only the rows where column 'C' is greater than 7."""
    return frame.filter(pl.col("C") > 7.0).select("A", "B")


if __name__ == "__main__":
    frame = pl.from_dict({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
    result = frame.pipe(f)
    print(result)

shape: (2, 2)
┌─────┬─────┐
│ A   ┆ B   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 2   ┆ 5   │
│ 3   ┆ 6   │
└─────┴─────┘

To abstract away from specific column names, I can write this as:

import polars as pl


def g(a: pl.Series, b: pl.Series, c: pl.Series) -> pl.DataFrame:
    # Rebuild a DataFrame from the three Series; its column names are the Series names,
    # so the filter and select below still depend on the names "A", "B" and "C".
    frame = pl.concat([a.to_frame(), b.to_frame(), c.to_frame()], how="horizontal")
    return frame.filter(pl.col("C") > 7.0).select("A", "B")


if __name__ == "__main__":
    frame = pl.from_dict({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
    result = g(frame.get_column("A"), frame.get_column("B"), frame.get_column("C"))
    print(result)

shape: (2, 2)
┌─────┬─────┐
│ A   ┆ B   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 2   ┆ 5   │
│ 3   ┆ 6   │
└─────┴─────┘

This seems cumbersome, as it involves extracting each `pl.Series` from the `pl.DataFrame`, then converting each one back to a `pl.DataFrame` and joining them with `pl.concat`.

Would it be possible to write this as a custom expression instead? I'd like to apply it as follows:

import polars as pl


def h(a: pl.Expr, b: pl.Expr, c: pl.Expr) -> pl.Expr:
    # How to represent f (or g) in terms of only `pl.Expr`?
    pass


if __name__ == "__main__":
    frame = pl.from_dict({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
    result = frame.select(h(pl.col("A"), pl.col("B"), pl.col("C")))
    print(result)

shape: (2, 2)
┌─────┬─────┐
│ A   ┆ B   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 2   ┆ 5   │
│ 3   ┆ 6   │
└─────┴─────┘

I am not sure whether this is the intended use of `pl.Expr`. It seems like a nice way to avoid relying on specific column names, although it implicitly requires that the expressions evaluate to series of identical length and of an appropriate data type.

Accepted answer, by jqurious:

Since you're passing the result to `DataFrame.select()`, you can use `Expr.filter()` to build the expressions:

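import polars as pl


def my_func(a: pl.Expr, b: pl.Expr, c: pl.Expr) -> tuple[pl.Expr, pl.Expr]:
    # `c` is a boolean expression; keep only the rows of `a` and `b` where it is True
    return a.filter(c), b.filter(c)


df = pl.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})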

df.select(my_func(pl.col("A"), pl.col("B"), pl.col("C") > 7))
shape: (2, 2)
┌─────┬─────┐
│ A   ┆ B   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 2   ┆ 5   │
│ 3   ┆ 6   │
└─────┴─────┘