How to use table.map() with a pandas function?


Let's consider the following Table API pseudocode:

table.map(<pandas_udf>).where(<an_expr>).map(<simple_udf>)

The problem is that <simple_udf> does not receive the expected rows: they are not filtered by <an_expr>, and the column names are reset to default fN names (f0, f1, ...).

I expect the received rows to be filtered by <an_expr> and accessible by their original column names returned by <pandas_udf>.

I've played with the problem a little bit and here are some observations.

If I use another simple (non-pandas) UDF instead of <pandas_udf>, everything works as expected: column names are preserved and rows are filtered.

If I insert a conversion to the DataStream API and back to the Table API between the calls to .where and .map, the problem also goes away:

ds = st_env.to_data_stream(table)
table = st_env.from_data_stream(ds)
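
The round-trip above can be wrapped in a small helper so it is easy to drop into a chain. This is only a sketch of the workaround, not a fix for the underlying behavior: the helper name `roundtrip_through_datastream` is hypothetical, and it relies only on the `to_data_stream` and `from_data_stream` methods already used in this question.

```python
def roundtrip_through_datastream(t_env, table):
    """Convert a Table to a DataStream and back again.

    Hypothetical helper: `t_env` is assumed to be a StreamTableEnvironment
    (anything exposing to_data_stream / from_data_stream), and `table` is
    the Table produced by the .map(<pandas_udf>).where(<an_expr>) part.
    """
    # Table API -> DataStream API
    ds = t_env.to_data_stream(table)
    # DataStream API -> Table API
    return t_env.from_data_stream(ds)
```

Assuming the names from the example in this question, it could be used as `roundtrip_through_datastream(st_env, table.map(example_map_a).where(col('country') == 'Germany')).map(example_map_b)`.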

Here is a minimal example, executed in pyflink-shell.sh against a Flink 1.17 cluster:

import pandas as pd

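from pyflink.common import Row
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.table import StreamTableEnvironment
from pyflink.table.expressions import col
from pyflink.table.udf import udf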

env = StreamExecutionEnvironment.get_execution_environment()
st_env = StreamTableEnvironment.create(env)

table = st_env.from_elements(
    elements=[
        (1, 'China'),
        (2, 'Germany'),
        (3, 'China'),
    ],
    schema=['id', 'country'],
)

@udf(
    result_type=(
        'Row<id INT, country STRING>'
    ),
    func_type="pandas",
)
def example_map_a(df: pd.DataFrame):
    columns = sorted(df.columns)
    print(f'example_map_a: {columns=}')
    # prints:
    # example_map_a: columns=['country', 'id']
    assert columns == ['country', 'id'], columns
    return df


@udf(
    result_type=(
        'Row<id INT, country STRING>'
    ),
)
def example_map_b(row: Row):
    assert hasattr(row, 'country'), row
    return row


# The pipeline below fails inside example_map_b with:
# AssertionError: Row(f0=1, f1='China')


flow = (
    table
    .map(example_map_a)
    .where(col('country') == 'Germany')
    .map(example_map_b)
    .execute().print()
)

# This, however, works:

flow = (
    table
    .map(example_map_a)
    .where(col('country') == 'Germany')
)

ds = st_env.to_data_stream(flow)

flow = (
    st_env.from_data_stream(ds)
    .map(example_map_b)
    .execute().print()
)

I also posted this to the dev mailing list, since I suspect it is a bug:

https://lists.apache.org/thread/k3y32gjbjk615v315ymflzt9v8t9yh7z
