How do I replace `map_groups` with Vectorized solution in Polars?


I have a function vol_buckets() that has an internal function _vol_buckets_engine(). I want to find a better way to execute the same logic.

The data that the function uses has multiple symbols in the dataframe, and I need low_vol and mid_vol to be calculated independently for each symbol. I do not want to use map_groups(), as it is known to be slow. The purpose of _vol_buckets_engine() is to calculate low_vol and mid_vol for each symbol from the realized_volatility column using the .quantile() method, and then to build a Polars expression vol_bucket that adds a column labelling each row "low", "mid", or "high" to indicate which volatility bucket the row falls into relative to the other rows for the same symbol.

I would like to achieve this functionality using expressions and the .over("symbol") method. How would you suggest I do this per symbol?
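For illustration, this is the general shape of a per-symbol window expression in Polars: an aggregation such as .quantile() combined with .over("symbol") is computed per symbol and broadcast back onto every row of that symbol (a generic sketch, not the final solution):

import polars as pl

# 40th-percentile realized volatility, computed per symbol and broadcast to each row
low_vol_per_symbol = pl.col("realized_volatility").quantile(0.4).over("symbol")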

Data / Reprex / Functions

Data

import io

import polars as pl

# Wrap the CSV text in io.StringIO so read_csv does not interpret the raw string as a file path
data = pl.read_csv(
    io.StringIO(
        """date,open,high,low,close,volume,symbol,log_returns,realized_volatility
2024-01-04T00:00:00.000000000,181.92,182.86,180.65,181.68,71983600,AAPL,-0.012797549556911925,5.978806063975678
2024-01-05T00:00:00.000000000,181.76,182.53,179.94,180.95,62303300,AAPL,-0.004026147787551615,7.015258917259547
2024-01-08T00:00:00.000000000,181.86,185.36,181.27,185.32,59144500,AAPL,0.023863310538696503,26.007351763364678
2024-01-09T00:00:00.000000000,183.69,184.91,182.5,184.9,42841800,AAPL,-0.002268922155368891,22.575217646603747
2024-01-10T00:00:00.000000000,184.12,186.16,183.69,185.95,46792900,AAPL,0.00566268197800035,20.588143883893334
2024-01-11T00:00:00.000000000,186.3,186.81,183.39,185.35,49128400,AAPL,-0.003231890774338275,18.926784867553852
2024-01-12T00:00:00.000000000,185.82,186.5,184.95,185.68,40444700,AAPL,0.0017788323694407637,17.552459393090984
2024-01-16T00:00:00.000000000,181.93,184.03,180.7,183.4,65603000,AAPL,-0.012355202143817579,17.709747324646568
2024-01-17T00:00:00.000000000,181.04,182.7,180.07,182.45,47317400,AAPL,-0.005193396939907835,16.816515639093087
2024-01-18T00:00:00.000000000,185.85,188.9,185.59,188.39,78005800,AAPL,0.03203811929482647,22.667117031805294
2024-01-19T00:00:00.000000000,189.09,191.71,188.58,191.32,68741000,AAPL,0.015433136634813494,22.541610676008293
2024-01-22T00:00:00.000000000,192.05,195.08,192.01,193.64,60133900,AAPL,0.012053346259397024,21.978012260973173
2024-01-23T00:00:00.000000000,194.77,195.5,193.58,194.93,42355600,AAPL,0.006639754686560195,21.16200840822893
2024-01-24T00:00:00.000000000,195.17,196.13,194.09,194.25,53631300,AAPL,-0.0034945305102969115,20.597372530591368
2024-01-25T00:00:00.000000000,194.97,196.02,192.86,193.92,54822100,AAPL,-0.001700286366807191,19.990264840035817
2024-01-26T00:00:00.000000000,194.02,194.51,191.7,192.17,44594000,AAPL,-0.009065305936613477,19.8880114712976
2024-01-29T00:00:00.000000000,191.77,191.96,189.34,191.49,47145600,AAPL,-0.0035448090082601524,19.409858957711418
2024-01-30T00:00:00.000000000,190.7,191.56,187.23,187.8,55859400,AAPL,-0.019458021161913308,20.389312616544064
2024-01-31T00:00:00.000000000,186.8,186.86,184.12,184.16,55467800,AAPL,-0.01957262180058983,21.107096155513442
2024-02-01T00:00:00.000000000,183.76,186.71,183.59,186.62,64885400,AAPL,0.01326951883230354,21.105585346483075
2024-01-02T00:00:00.000000000,151.54,152.38,148.39,149.93,47339400,AMZN,-0.21890594529906426,76.99938193873325
2024-01-03T00:00:00.000000000,149.2,151.05,148.33,148.47,49425500,AMZN,-0.009785600874907097,75.22905752992065
2024-01-04T00:00:00.000000000,145.59,147.38,144.05,144.57,56039800,AMZN,-0.026619098311553735,73.78045961160902
2024-01-05T00:00:00.000000000,144.69,146.59,144.53,145.24,45124800,AMZN,0.00462372722578408,72.38360500669039
2024-01-08T00:00:00.000000000,146.74,149.4,146.15,149.1,46757100,AMZN,0.02622967522455788,71.80102256621782
2024-01-09T00:00:00.000000000,148.33,151.71,148.21,151.37,43812600,AMZN,0.015109949003781153,70.77149915802471
2024-01-10T00:00:00.000000000,152.06,154.42,151.88,153.73,44421800,AMZN,0.015470646150033573,69.7899272182577
2024-01-11T00:00:00.000000000,155.04,157.17,153.12,155.18,49072700,AMZN,0.00938791654129556,68.69603496216621
2024-01-12T00:00:00.000000000,155.39,156.2,154.01,154.62,40460300,AMZN,-0.00361523957347476,68.6986911830599
2024-01-16T00:00:00.000000000,153.53,154.99,152.15,153.16,41384600,AMZN,-0.00948736728323496,68.67615043687829
2024-01-17T00:00:00.000000000,151.49,152.15,149.91,151.71,34953400,AMZN,-0.009512322849092314,68.68180297623903
2024-01-18T00:00:00.000000000,152.77,153.78,151.82,153.5,37850200,AMZN,0.01172976326725017,68.29781227101783
2024-01-19T00:00:00.000000000,153.83,155.76,152.74,155.34,51033700,AMZN,0.011915695964249906,68.50450864435001
2024-01-22T00:00:00.000000000,156.89,157.05,153.9,154.78,43687500,AMZN,-0.0036115091491879525,68.41681908291058
2024-01-23T00:00:00.000000000,154.85,156.21,153.93,156.02,37986000,AMZN,0.007979450317853853,68.5421687035556
2024-01-24T00:00:00.000000000,157.8,158.51,156.48,156.87,48547300,AMZN,0.005433232707981794,68.58765769948049
2024-01-25T00:00:00.000000000,156.95,158.51,154.55,157.75,43638600,AMZN,0.005594064553173794,68.63766585670909
2024-01-26T00:00:00.000000000,158.42,160.72,157.91,159.12,51047400,AMZN,0.008647133124394024,68.75884930548898
2024-01-29T00:00:00.000000000,159.34,161.29,158.9,161.26,45270400,AMZN,0.013359334711158688,68.07728760528263
2024-01-30T00:00:00.000000000,160.7,161.73,158.49,159.0,45207400,AMZN,-0.01411376703664402,67.82169299607905
2024-01-31T00:00:00.000000000,157.0,159.01,154.81,155.2,50284400,AMZN,-0.02418959447111302,67.77658657236341
2024-02-01T00:00:00.000000000,155.87,159.76,155.62,159.28,76542400,AMZN,0.025949052006822626,68.37882655863204
"""
    )
)

Functions

def vol_buckets(
    data: pl.DataFrame | pl.LazyFrame,
    lo_quantile: float = 0.4,
    hi_quantile: float = 0.8,
    _column_name_volatility: str = "realized_volatility",
) -> pl.LazyFrame:
  
    # Collect eagerly if needed, then group by "symbol" and apply _vol_buckets_engine to each group
    if isinstance(data, pl.LazyFrame):
        data = data.collect()

    result = data.group_by("symbol").map_groups(
        lambda group_df: _vol_buckets_engine(
            group_df,
            lo_quantile,
            hi_quantile,
            _column_name_volatility,
        )
    )

    return result.lazy()

def _vol_buckets_engine(
    grouped_data: pl.DataFrame | pl.LazyFrame,
    lo_quantile: float,
    hi_quantile: float,
    _column_name_volatility: str,
) -> pl.DataFrame:

    # Calculate low and high quantiles for the group
    low_vol = (
        grouped_data.lazy()
        .select(pl.col(_column_name_volatility).quantile(lo_quantile))
        .collect()
        .to_series()[0]
    )
    mid_vol = (
        grouped_data.lazy()
        .select(pl.col(_column_name_volatility).quantile(hi_quantile))
        .collect()
        .to_series()[0]
    )

    # Determine the volatility bucket for each row
    vol_bucket = (
        pl.when(pl.col(_column_name_volatility) <= low_vol)
        .then(pl.lit("low"))
        .when(pl.col(_column_name_volatility) <= mid_vol)
        .then(pl.lit("mid"))
        .otherwise(pl.lit("high"))
        .alias("vol_bucket")
    )

    return grouped_data.lazy().with_columns(vol_bucket).collect()

Reprex

result = vol_buckets(data=data, lo_quantile=0.3, hi_quantile=0.65)
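Since vol_buckets() returns a LazyFrame, a minimal sketch of materializing the reprex result looks like this:

buckets_df = result.collect()  # vol_buckets() returns result.lazy(), so collect() materializes it
print(buckets_df)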

Solution

Thanks to @jqurious for suggesting how to use expressions. This is the fastest implementation of the desired function that I could achieve without using .qcut().

Average Execution Time: 61.2 µs ± 1.33 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
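The quoted numbers have the format of IPython's %timeit output. A rough sketch of reproducing such a timing with the standard-library timeit module (assuming the data frame from the reprex and the composed vol_buckets() below are defined):

import timeit

# Rough benchmark sketch; assumes `data` and the composed vol_buckets() below are defined
n_runs = 10_000
total = timeit.timeit(
    lambda: vol_buckets(data, lo_quantile=0.3, hi_quantile=0.65).collect(),
    number=n_runs,
)
print(f"{total / n_runs * 1e6:.1f} µs per call on average over {n_runs:,} runs")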

Here is a composed function that supports both versions, switched by the _boundary_group_down parameter.

import polars as pl


def vol_buckets(
    data: pl.DataFrame | pl.LazyFrame,
    lo_quantile: float = 0.4,
    hi_quantile: float = 0.8,
    _column_name_volatility: str = "realized_volatility",
    *,
    _boundary_group_down: bool = False,
) -> pl.LazyFrame:
    """
    Context: Toolbox || Category: MandelBrot Channel || Sub-Category: Helpers || Command: **vol_buckets**.

    Splitting data observations into 3 volatility buckets: low, mid and high.
    The function does this for each `symbol` present in the data.

    Parameters
    ----------
    data : pl.LazyFrame | pl.DataFrame
        The input dataframe or lazy frame.
    lo_quantile : float
        The lower quantile for bucketing. Default is 0.4.
    hi_quantile : float
        The higher quantile for bucketing. Default is 0.8.
    _column_name_volatility : str
        The name of the column to apply volatility bucketing. Default is
        "realized_volatility".
    _boundary_group_down : bool
        If True, group boundary values down into the lower bucket using
        `vol_buckets_alt()`. If False, group boundary values up into the
        higher bucket using the Polars `.qcut()` method.
        Default is False.

    Returns
    -------
    pl.LazyFrame
        The `data` with an additional column: `vol_bucket`
    """

    # _check_required_columns(data, _column_name_volatility, "symbol")

    if not _boundary_group_down:
        # Grouping Boundary Values in Higher Bucket
        out = data.lazy().with_columns(
            pl.col(_column_name_volatility)
            .qcut(
                [lo_quantile, hi_quantile],
                labels=["low", "mid", "high"],
                left_closed=False,
            )
            .over("symbol")
            .alias("vol_bucket")
            .cast(pl.Utf8)
        )
    else:
        out = vol_buckets_alt(
            data, lo_quantile, hi_quantile, _column_name_volatility
        )

    return out


def vol_buckets_alt(
    data: pl.DataFrame | pl.LazyFrame,
    lo_quantile: float = 0.4,
    hi_quantile: float = 0.8,
    _column_name_volatility: str = "realized_volatility",
) -> pl.LazyFrame:
    """
    Context: Toolbox || Category: MandelBrot Channel || Sub-Category: Helpers || Command: **vol_buckets_alt**.

    This is an alternative implementation of `vol_buckets()` that uses
    expressions instead of `.qcut()`.
    The biggest difference is how values that lie exactly on a quantile
    boundary are handled: this function groups boundary values down.
    It splits the data observations into 3 volatility buckets (low, mid and
    high) and does this for each `symbol` present in the data.

    Parameters
    ----------
    data : pl.LazyFrame | pl.DataFrame
        The input dataframe or lazy frame.
    lo_quantile : float
        The lower quantile for bucketing. Default is 0.4.
    hi_quantile : float
        The higher quantile for bucketing. Default is 0.8.
    _column_name_volatility : str
        The name of the column to apply volatility bucketing. Default is "realized_volatility".

    Returns
    -------
    pl.LazyFrame
        The `data` with an additional column: `vol_bucket`

    Notes
    -----
    The biggest difference is how the function groups values on the boundaries
    of quantiles. This function __groups boundary values down__ to the lower bucket.
    So, if there is a value that lies on the mid/low border, this function will
    group it with `low`, whereas `vol_buckets()` will group it with `mid`.

    This function is also slightly less performant.
    """
    # Calculate low and high quantiles for each symbol
    low_vol = pl.col(_column_name_volatility).quantile(lo_quantile)
    high_vol = pl.col(_column_name_volatility).quantile(hi_quantile)

    # Determine the volatility bucket for each row using expressions
    vol_bucket = (
        pl.when(pl.col(_column_name_volatility) <= low_vol)
        .then(pl.lit("low"))
        .when(pl.col(_column_name_volatility) <= high_vol)
        .then(pl.lit("mid"))
        .otherwise(pl.lit("high"))
        .alias("vol_bucket")
    )

    # Add the volatility bucket column to the data
    out = data.lazy().with_columns(vol_bucket.over("symbol"))

    return out

Bucket Mismatch:

On the left, _boundary_group_down=False; on the right, _boundary_group_down=True.

(Image: "Bucket Boundary Mismatch", showing the two outputs side by side.)
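As an illustrative sketch (assuming the data frame from the reprex and the functions above), both boundary conventions can be run side by side and the disagreeing rows inspected:

# Run both boundary conventions on the same data
up = vol_buckets(data, 0.3, 0.65, _boundary_group_down=False).collect()
down = vol_buckets(data, 0.3, 0.65, _boundary_group_down=True).collect()

# Rows where the qcut version (boundaries grouped up) and the expression version
# (boundaries grouped down) assign different buckets
mismatch = up.with_columns(down["vol_bucket"].alias("vol_bucket_down")).filter(
    pl.col("vol_bucket") != pl.col("vol_bucket_down")
)
print(mismatch)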

1 Answer

Hericks (accepted answer)

You can achieve the same functionality using Polars' expression API with pl.Expr.qcut.

def vol_buckets_new(
    data: pl.DataFrame | pl.LazyFrame,
    lo_quantile: float = 0.4,
    hi_quantile: float = 0.8,
    _column_name_volatility: str = "realized_volatility",
) -> pl.LazyFrame:

    if isinstance(data, pl.LazyFrame):
        data = data.collect()

    result = data.with_columns(
        pl.col(_column_name_volatility)
        .qcut([lo_quantile, hi_quantile], labels=["low", "mid", "high"])
        .over("symbol")
        .alias("vol_bucket")
    )

    return result.lazy()
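A minimal usage sketch of this function (assuming the data frame from the question):

buckets = vol_buckets_new(data).collect()
print(buckets.select("date", "symbol", "realized_volatility", "vol_bucket"))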