Consider the following function acting on a `polars.DataFrame`:
```python
import polars as pl


def f(frame: pl.DataFrame) -> pl.DataFrame:
    """Select columns 'A', 'B' from `frame` and return only those rows for which column 'C' is greater than 7."""
    return frame.filter(pl.col("C") > 7.0).select("A", "B")


if __name__ == "__main__":
    frame = pl.from_dict({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
    result = frame.pipe(f)
    print(result)
```
```
>>> shape: (2, 2)
┌─────┬─────┐
│ A   ┆ B   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 2   ┆ 5   │
│ 3   ┆ 6   │
└─────┴─────┘
```
To abstract away from specific column names, I can write this as:
```python
import polars as pl


def g(a: pl.Series, b: pl.Series, c: pl.Series) -> pl.DataFrame:
    frame = pl.concat([a.to_frame(), b.to_frame(), c.to_frame()], how="horizontal")
    return frame.filter(pl.col("C") > 7.0).select("A", "B")


if __name__ == "__main__":
    frame = pl.from_dict({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
    result = g(frame.get_column("A"), frame.get_column("B"), frame.get_column("C"))
    print(result)
```
```
>>> shape: (2, 2)
┌─────┬─────┐
│ A   ┆ B   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 2   ┆ 5   │
│ 3   ┆ 6   │
└─────┴─────┘
```
This seems cumbersome, as it involves extracting `pl.Series` from a `pl.DataFrame`, followed by `pl.concat` and multiple conversions back to `pl.DataFrame`.
Would it be possible to write this as a custom expression instead? I'd like to apply it as outlined below:
```python
import polars as pl


def h(a: pl.Expr, b: pl.Expr, c: pl.Expr) -> pl.Expr:
    # How to represent f (or g) in terms of only `pl.Expr`?
    pass


if __name__ == "__main__":
    frame = pl.from_dict({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
    result = frame.select(h(pl.col("A"), pl.col("B"), pl.col("C")))
    print(result)
```
```
>>> shape: (2, 2)
┌─────┬─────┐
│ A   ┆ B   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 2   ┆ 5   │
│ 3   ┆ 6   │
└─────┴─────┘
```
I am not sure whether that is the intended usage of the `pl.Expr` type. It seems to be a nice way to avoid relying on specific column names (although it implicitly requires that those expressions evaluate to series of identical length and with an appropriate data type for the expression to work).
As you're passing the result to `DataFrame.select()`, you can use `Expr.filter()` to build the expressions.