Let's consider the following Table API pseudocode:
table.map(<pandas_udf>).where(<an_expr>).map(<simple_udf>)
The problem is that <simple_udf> does not receive the expected rows. The rows are not filtered by <an_expr>, and the column names are reset to default fN names.
I expect the received rows to be filtered by <an_expr> and accessible by the original column names returned by <pandas_udf>.
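To make the expectation concrete, here is a plain-Python model of the intended semantics. It involves no Flink at all. The names Record, expected_pipeline and simple_udf are hypothetical and exist only for this illustration, and the sample rows are borrowed from the minimal example further down.

```python
from collections import namedtuple

# Hypothetical stand-in for the named rows returned by <pandas_udf>.
Record = namedtuple("Record", ["id", "country"])


def expected_pipeline(rows, predicate, simple_udf):
    """Model of table.map(<pandas_udf>).where(<an_expr>).map(<simple_udf>)."""
    # Step 1: an identity <pandas_udf> keeps the values and the column names.
    mapped = [Record(*r) for r in rows]
    # Step 2: <an_expr> drops every row that does not match.
    filtered = [r for r in mapped if predicate(r)]
    # Step 3: <simple_udf> sees only the filtered, named rows.
    return [simple_udf(r) for r in filtered]


rows = [(1, "China"), (2, "Germany"), (3, "China")]
result = expected_pipeline(
    rows,
    predicate=lambda r: r.country == "Germany",
    simple_udf=lambda r: r,
)
print(result)  # [Record(id=2, country='Germany')]
```

In this model the last step receives exactly one row, and its fields are reachable as id and country. That is the behaviour expected from the Table API pipeline as well.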
I've experimented with the problem a bit, and here are some observations.
If I use another simple UDF instead of <pandas_udf>, everything works as expected: column names are preserved and rows are filtered.
And if I insert a conversion to the DataStream API and back to the Table API between the calls to .where and .map, the problem goes away:
ds = st_env.to_data_stream(table)
table = st_env.from_data_stream(ds)
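The round trip can be wrapped in a small helper so it is easy to drop into a pipeline and to remove again later. This is only a sketch: the name roundtrip_through_datastream is hypothetical, and it assumes t_env is a StreamTableEnvironment, relying on nothing beyond the two calls shown above.

```python
def roundtrip_through_datastream(t_env, table):
    """Convert a Table to a DataStream and straight back to a Table.

    Workaround only: it reproduces the two conversion calls shown above
    and does not address whatever causes the problem.
    """
    ds = t_env.to_data_stream(table)
    return t_env.from_data_stream(ds)
```

With the minimal example below, the usage would be roundtrip_through_datastream(st_env, table.map(example_map_a).where(col('country') == 'Germany')).map(example_map_b). Why the extra conversion avoids the problem is not established here.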
And here is a minimal example, executed in pyflink-shell.sh with a Flink 1.17 cluster:
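import pandas as pd
# Explicit imports so the example also runs outside pyflink-shell.sh.
from pyflink.common import Row
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.table import StreamTableEnvironment
from pyflink.table.expressions import col
from pyflink.table.udf import udf

env = StreamExecutionEnvironment.get_execution_environment()
st_env = StreamTableEnvironment.create(env)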
table = st_env.from_elements(
    elements=[
        (1, 'China'),
        (2, 'Germany'),
        (3, 'China'),
    ],
    schema=['id', 'country'],
)
@udf(
    result_type=(
        'Row<id INT, country STRING>'
    ),
    func_type="pandas",
)
def example_map_a(df: pd.DataFrame):
    columns = sorted(df.columns)
    print(f'example_map_a: {columns=}')
    # prints:
    # example_map_a: columns=['country', 'id']
    assert columns == ['country', 'id'], columns
    return df
@udf(
    result_type=(
        'Row<id INT, country STRING>'
    ),
)
def example_map_b(row: Row):
    assert hasattr(row, 'country'), row
    return row
# Will raise with
# AssertionError: Row(f0=1, f1='China')
flow = (
    table
    .map(example_map_a)
    .where(col('country') == 'Germany')
    .map(example_map_b)
    .execute().print()
)
# This, however, works:
flow = (
    table
    .map(example_map_a)
    .where(col('country') == 'Germany')
)
ds = st_env.to_data_stream(flow)
flow = (
    st_env.from_data_stream(ds)
    .map(example_map_b)
    .execute().print()
)
I also posted this to the dev mailing list, since I suspect it is a bug:
https://lists.apache.org/thread/k3y32gjbjk615v315ymflzt9v8t9yh7z