I have a function vol_buckets() that uses an internal function _vol_buckets_engine(). I want to find a better way to execute the same logic.
The dataframe contains multiple symbols, and I need low_vol and mid_vol to be calculated independently for each symbol. I do not want to use the map_groups() function, as it is known to be slow. The purpose of _vol_buckets_engine() is to calculate the low_vol and mid_vol of each symbol from the realized_volatility column using the .quantile() method, and then to build a Polars expression vol_bucket that creates a column labelling each row with the string "high", "mid", or "low" to represent the volatility bucket that the row falls into relative to other rows for the same symbol.
I would like to achieve this functionality using expressions and the .over("symbol") method. How would you suggest I achieve this per symbol?
Data / Reprex / Functions
Data
import polars as pl

data = pl.read_csv(
    b"""\
date,open,high,low,close,volume,symbol,log_returns,realized_volatility
2024-01-04T00:00:00.000000000,181.92,182.86,180.65,181.68,71983600,AAPL,-0.012797549556911925,5.978806063975678
2024-01-05T00:00:00.000000000,181.76,182.53,179.94,180.95,62303300,AAPL,-0.004026147787551615,7.015258917259547
2024-01-08T00:00:00.000000000,181.86,185.36,181.27,185.32,59144500,AAPL,0.023863310538696503,26.007351763364678
2024-01-09T00:00:00.000000000,183.69,184.91,182.5,184.9,42841800,AAPL,-0.002268922155368891,22.575217646603747
2024-01-10T00:00:00.000000000,184.12,186.16,183.69,185.95,46792900,AAPL,0.00566268197800035,20.588143883893334
2024-01-11T00:00:00.000000000,186.3,186.81,183.39,185.35,49128400,AAPL,-0.003231890774338275,18.926784867553852
2024-01-12T00:00:00.000000000,185.82,186.5,184.95,185.68,40444700,AAPL,0.0017788323694407637,17.552459393090984
2024-01-16T00:00:00.000000000,181.93,184.03,180.7,183.4,65603000,AAPL,-0.012355202143817579,17.709747324646568
2024-01-17T00:00:00.000000000,181.04,182.7,180.07,182.45,47317400,AAPL,-0.005193396939907835,16.816515639093087
2024-01-18T00:00:00.000000000,185.85,188.9,185.59,188.39,78005800,AAPL,0.03203811929482647,22.667117031805294
2024-01-19T00:00:00.000000000,189.09,191.71,188.58,191.32,68741000,AAPL,0.015433136634813494,22.541610676008293
2024-01-22T00:00:00.000000000,192.05,195.08,192.01,193.64,60133900,AAPL,0.012053346259397024,21.978012260973173
2024-01-23T00:00:00.000000000,194.77,195.5,193.58,194.93,42355600,AAPL,0.006639754686560195,21.16200840822893
2024-01-24T00:00:00.000000000,195.17,196.13,194.09,194.25,53631300,AAPL,-0.0034945305102969115,20.597372530591368
2024-01-25T00:00:00.000000000,194.97,196.02,192.86,193.92,54822100,AAPL,-0.001700286366807191,19.990264840035817
2024-01-26T00:00:00.000000000,194.02,194.51,191.7,192.17,44594000,AAPL,-0.009065305936613477,19.8880114712976
2024-01-29T00:00:00.000000000,191.77,191.96,189.34,191.49,47145600,AAPL,-0.0035448090082601524,19.409858957711418
2024-01-30T00:00:00.000000000,190.7,191.56,187.23,187.8,55859400,AAPL,-0.019458021161913308,20.389312616544064
2024-01-31T00:00:00.000000000,186.8,186.86,184.12,184.16,55467800,AAPL,-0.01957262180058983,21.107096155513442
2024-02-01T00:00:00.000000000,183.76,186.71,183.59,186.62,64885400,AAPL,0.01326951883230354,21.105585346483075
2024-01-02T00:00:00.000000000,151.54,152.38,148.39,149.93,47339400,AMZN,-0.21890594529906426,76.99938193873325
2024-01-03T00:00:00.000000000,149.2,151.05,148.33,148.47,49425500,AMZN,-0.009785600874907097,75.22905752992065
2024-01-04T00:00:00.000000000,145.59,147.38,144.05,144.57,56039800,AMZN,-0.026619098311553735,73.78045961160902
2024-01-05T00:00:00.000000000,144.69,146.59,144.53,145.24,45124800,AMZN,0.00462372722578408,72.38360500669039
2024-01-08T00:00:00.000000000,146.74,149.4,146.15,149.1,46757100,AMZN,0.02622967522455788,71.80102256621782
2024-01-09T00:00:00.000000000,148.33,151.71,148.21,151.37,43812600,AMZN,0.015109949003781153,70.77149915802471
2024-01-10T00:00:00.000000000,152.06,154.42,151.88,153.73,44421800,AMZN,0.015470646150033573,69.7899272182577
2024-01-11T00:00:00.000000000,155.04,157.17,153.12,155.18,49072700,AMZN,0.00938791654129556,68.69603496216621
2024-01-12T00:00:00.000000000,155.39,156.2,154.01,154.62,40460300,AMZN,-0.00361523957347476,68.6986911830599
2024-01-16T00:00:00.000000000,153.53,154.99,152.15,153.16,41384600,AMZN,-0.00948736728323496,68.67615043687829
2024-01-17T00:00:00.000000000,151.49,152.15,149.91,151.71,34953400,AMZN,-0.009512322849092314,68.68180297623903
2024-01-18T00:00:00.000000000,152.77,153.78,151.82,153.5,37850200,AMZN,0.01172976326725017,68.29781227101783
2024-01-19T00:00:00.000000000,153.83,155.76,152.74,155.34,51033700,AMZN,0.011915695964249906,68.50450864435001
2024-01-22T00:00:00.000000000,156.89,157.05,153.9,154.78,43687500,AMZN,-0.0036115091491879525,68.41681908291058
2024-01-23T00:00:00.000000000,154.85,156.21,153.93,156.02,37986000,AMZN,0.007979450317853853,68.5421687035556
2024-01-24T00:00:00.000000000,157.8,158.51,156.48,156.87,48547300,AMZN,0.005433232707981794,68.58765769948049
2024-01-25T00:00:00.000000000,156.95,158.51,154.55,157.75,43638600,AMZN,0.005594064553173794,68.63766585670909
2024-01-26T00:00:00.000000000,158.42,160.72,157.91,159.12,51047400,AMZN,0.008647133124394024,68.75884930548898
2024-01-29T00:00:00.000000000,159.34,161.29,158.9,161.26,45270400,AMZN,0.013359334711158688,68.07728760528263
2024-01-30T00:00:00.000000000,160.7,161.73,158.49,159.0,45207400,AMZN,-0.01411376703664402,67.82169299607905
2024-01-31T00:00:00.000000000,157.0,159.01,154.81,155.2,50284400,AMZN,-0.02418959447111302,67.77658657236341
2024-02-01T00:00:00.000000000,155.87,159.76,155.62,159.28,76542400,AMZN,0.025949052006822626,68.37882655863204
"""
)
Functions
def vol_buckets(
    data: pl.DataFrame | pl.LazyFrame,
    lo_quantile: float = 0.4,
    hi_quantile: float = 0.8,
    _column_name_volatility: str = "realized_volatility",
) -> pl.LazyFrame:
    # Group by 'symbol' and apply '_vol_buckets_engine' to each group
    if isinstance(data, pl.LazyFrame):
        data = data.collect()
    result = data.group_by("symbol").map_groups(
        lambda group_df: _vol_buckets_engine(
            group_df,
            lo_quantile,
            hi_quantile,
            _column_name_volatility,
        )
    )
    return result.lazy()
def _vol_buckets_engine(
    grouped_data: pl.DataFrame | pl.LazyFrame,
    lo_quantile: float,
    hi_quantile: float,
    _column_name_volatility: str,
) -> pl.DataFrame:
    # Calculate the low and high quantile cut-offs for this group
    low_vol = (
        grouped_data.lazy()
        .select(pl.col(_column_name_volatility).quantile(lo_quantile))
        .collect()
        .to_series()[0]
    )
    mid_vol = (
        grouped_data.lazy()
        .select(pl.col(_column_name_volatility).quantile(hi_quantile))
        .collect()
        .to_series()[0]
    )
    # Determine the volatility bucket for each row
    vol_bucket = (
        pl.when(pl.col(_column_name_volatility) <= low_vol)
        .then(pl.lit("low"))
        .when(pl.col(_column_name_volatility) <= mid_vol)
        .then(pl.lit("mid"))
        .otherwise(pl.lit("high"))
        .alias("vol_bucket")
    )
    return grouped_data.lazy().with_columns(vol_bucket).collect()
Reprex
result = vol_buckets(data=data, lo_quantile=0.3, hi_quantile=0.65)
Solution
Thanks to @jqurious for suggesting how to use expressions. This is the fastest implementation of the desired function that I could achieve without using `.qcut()`.
Average execution time: 61.2 µs ± 1.33 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Here is a composed function that can use both versions in one function, switched by the `_boundary_group_down` parameter.
import polars as pl


def vol_buckets(
    data: pl.DataFrame | pl.LazyFrame,
    lo_quantile: float = 0.4,
    hi_quantile: float = 0.8,
    _column_name_volatility: str = "realized_volatility",
    *,
    _boundary_group_down: bool = False,
) -> pl.LazyFrame:
    """
    Context: Toolbox || Category: MandelBrot Channel || Sub-Category: Helpers || Command: **vol_buckets**.

    Splits data observations into 3 volatility buckets: low, mid and high.
    The function does this for each `symbol` present in the data.

    Parameters
    ----------
    data : pl.LazyFrame | pl.DataFrame
        The input dataframe or lazy frame.
    lo_quantile : float
        The lower quantile for bucketing. Default is 0.4.
    hi_quantile : float
        The higher quantile for bucketing. Default is 0.8.
    _column_name_volatility : str
        The name of the column to apply volatility bucketing. Default is
        "realized_volatility".
    _boundary_group_down : bool
        If True, group boundary values down into the lower bucket, using
        `vol_buckets_alt()`. If False, group boundary values up into the
        higher bucket, using the Polars `.qcut()` method.
        Default is False.

    Returns
    -------
    pl.LazyFrame
        The `data` with an additional column: `vol_bucket`
    """
    # _check_required_columns(data, _column_name_volatility, "symbol")
    if not _boundary_group_down:
        # Group boundary values into the higher bucket
        out = data.lazy().with_columns(
            pl.col(_column_name_volatility)
            .qcut(
                [lo_quantile, hi_quantile],
                labels=["low", "mid", "high"],
                left_closed=False,
            )
            .over("symbol")
            .alias("vol_bucket")
            .cast(pl.Utf8)
        )
    else:
        out = vol_buckets_alt(
            data, lo_quantile, hi_quantile, _column_name_volatility
        )
    return out
def vol_buckets_alt(
    data: pl.DataFrame | pl.LazyFrame,
    lo_quantile: float = 0.4,
    hi_quantile: float = 0.8,
    _column_name_volatility: str = "realized_volatility",
) -> pl.LazyFrame:
    """
    Context: Toolbox || Category: MandelBrot Channel || Sub-Category: Helpers || Command: **vol_buckets_alt**.

    An alternative implementation of `vol_buckets()` that uses expressions
    rather than `.qcut()`.

    Splits data observations into 3 volatility buckets: low, mid and high.
    The function does this for each `symbol` present in the data.

    Parameters
    ----------
    data : pl.LazyFrame | pl.DataFrame
        The input dataframe or lazy frame.
    lo_quantile : float
        The lower quantile for bucketing. Default is 0.4.
    hi_quantile : float
        The higher quantile for bucketing. Default is 0.8.
    _column_name_volatility : str
        The name of the column to apply volatility bucketing. Default is
        "realized_volatility".

    Returns
    -------
    pl.LazyFrame
        The `data` with an additional column: `vol_bucket`

    Notes
    -----
    The biggest difference is how the function treats values on the quantile
    boundaries: this function __groups boundary values down__ into the lower
    bucket. So, if a value lies exactly on the mid/low border, this function
    groups it with `low`, whereas `vol_buckets()` groups it with `mid`.
    This function is also slightly less performant.
    """
    # Calculate the low and high quantile cut-offs, evaluated per symbol
    low_vol = pl.col(_column_name_volatility).quantile(lo_quantile)
    high_vol = pl.col(_column_name_volatility).quantile(hi_quantile)

    # Determine the volatility bucket for each row using expressions
    vol_bucket = (
        pl.when(pl.col(_column_name_volatility) <= low_vol)
        .then(pl.lit("low"))
        .when(pl.col(_column_name_volatility) <= high_vol)
        .then(pl.lit("mid"))
        .otherwise(pl.lit("high"))
        .alias("vol_bucket")
    )

    # Add the volatility bucket column to the data, windowed per symbol
    out = data.lazy().with_columns(vol_bucket.over("symbol"))
    return out
Bucket Mismatch:
On the left, `_boundary_group_down=False`; on the right, `_boundary_group_down=True`.
[side-by-side screenshot of the two outputs]
You can achieve the same functionality using the Polars expression API with `pl.Expr.qcut`.