Creating a MultiIndex pd.DataFrame using the hypothesis library

147 views Asked by At

I need to create a pd.DataFrame with a MultiIndex. The first index level is a simple range from 1 to n; the second level is a datetime index. All columns contain floats. Here's my example for n=2:

from datetime import date

import pandas as pd
from hypothesis import given
from hypothesis import strategies as st
from hypothesis.extra.pandas import columns, data_frames, indexes


@given(
    # Both frames use the very same strategy definition, so build it once
    # per keyword name instead of spelling it out twice.
    **{
        name: data_frames(
            columns=columns(
                ["asset1", "asset2", "asset3", "cash_asset"],
                elements=st.floats(allow_nan=False, allow_infinity=False),
            ),
            index=indexes(
                elements=st.dates(
                    date.fromisoformat("2000-01-01"),
                    date.fromisoformat("2020-12-31"),
                ),
                min_size=10,
                unique=True,
            ).map(sorted),
        )
        for name in ("df1", "df2")
    }
)
def test_index_level(df1, df2):
    """Check that stacking two keyed dataframes produces a two-level index."""
    stacked = pd.concat([df1, df2], keys=["df1", "df2"])
    assert stacked.index.nlevels == 2

How can I create the MultiIndex directly with the hypothesis library? Defining df1, df2, etc. manually, as in my toy example, clearly doesn't scale.
An additional constraint is that the second-level index must be identical for every first-level key.

1

There is 1 answer

1
Answer by MrBean Bremen (accepted as best answer)

You can use the `lists` strategy to generate the dataframes instead of defining each one separately.
To make the second index level (e.g. the dates) identical for every first-level key, first draw the index once and then feed it to every generated dataframe. There may be a simpler way, but I did it using a composite strategy:

@st.composite
def df_lists(
    draw,
    elements=indexes(
        elements=st.dates(
            date.fromisoformat("2000-01-01"),
            date.fromisoformat("2020-12-31"),
        ),
        min_size=10,
        unique=True,
    ),
    min_size=1,
    max_size=5,
):
    """Draw a list of dataframes that all share one sorted date index.

    The result is meant to be combined with ``pd.concat(..., keys=...)``
    into a two-level index whose second level is the same for every
    first-level key.

    Args:
        draw: Hypothesis draw function, supplied by ``st.composite``.
        elements: Strategy producing the shared index values. It is drawn
            once, sorted, and reused for every dataframe in the list.
        min_size: Minimum number of dataframes (first-level length n).
        max_size: Maximum number of dataframes.

    Returns:
        A list of between ``min_size`` and ``max_size`` dataframes with
        float columns ``asset1``, ``asset2``, ``asset3`` and ``cash_asset``.
    """
    # Draw the index a single time so every dataframe reuses the same values.
    index = draw(elements.map(sorted))
    frame_strategy = data_frames(
        columns=columns(
            ["asset1", "asset2", "asset3", "cash_asset"],
            elements=st.floats(allow_nan=False, allow_infinity=False),
        ),
        # data_frames expects a strategy, so wrap the fixed index in st.just.
        index=st.just(index),
    )
    return draw(st.lists(frame_strategy, min_size=min_size, max_size=max_size))


@given(df_lists())
def test_index_level(df_list):
    """Stack the drawn dataframes under keys df1..dfN; expect two index levels."""
    labels = [f"df{position}" for position in range(1, len(df_list) + 1)]
    stacked = pd.concat(df_list, keys=labels)
    assert stacked.index.nlevels == 2