How do I use a multi-input StructuredTool with a Pandas Dataframe Agent?

339 views Asked by At

I have created a Pandas Dataframe Agent with the LLM being a HuggingFace pipeline. I want the agent to use a multi-input StructuredTool, but I always get a ValidationError. I tried to follow the multi-input tool guide here and pass the Action Input as a JSON string but that doesn't work either.

The tool is defined like so:

from langchain.tools import tool

@tool()
def plot_graph(x: str, y: str):
    r"""Plots a scatterplot of the x and y column values of the Dataframe. Provide the x and y arguments as a JSON object. Example: {{"x": "Amount", "y": "Date"}}"""
    print(x, y) # placeholder

tools = [plot_graph]

REPL output:

[StructuredTool(name='plot_graph', description='plot_graph(x: str, y: str) - Plots a scatterplot of the x and y column values of the Dataframe. Provide the x and y arguments as a JSON object. Example: {{"x": "Amount", "y": "Date"}}', args_schema=<class 'pydantic.v1.main.plot_graphSchemaSchema'>, return_direct=False, verbose=False, callbacks=None, callback_manager=None, tags=None, metadata=None, handle_tool_error=False, func=<function plot_graph at 0x0000018294D30680>, coroutine=None)]

The agent is defined and run like this:

agent = create_pandas_dataframe_agent(hf, df, verbose=True, extra_tools=tools, agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION)

query = "Plot a graph of Credit Amount vs Date."
agent.run(query)

Note: I tried running the above without agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION as well.

Which leads the model to generate this:

Thought: I should now plot a graph of Credit Amount vs Date.
Action: plot_graph
Action Input: {"x": "Date", "y": "Credit Amount"}

But then results in this error:

---------------------------------------------------------------------------
ValidationError                           Traceback (most recent call last)
Cell In[34], line 6
      5 query = "Plot a graph of Credit Amount vs Date."
----> 6 agent.run(query)

File ~\Documents\Python\langchain\.venv\Lib\site-packages\langchain\chains\base.py:507, in Chain.run(self, callbacks, tags, metadata, *args, **kwargs)
    505     if len(args) != 1:
    506         raise ValueError("`run` supports only one positional argument.")
--> 507     return self(args[0], callbacks=callbacks, tags=tags, metadata=metadata)[
    508         _output_key
    509     ]
    511 if kwargs and not args:
    512     return self(kwargs, callbacks=callbacks, tags=tags, metadata=metadata)[
    513         _output_key
    514     ]

File ~\Documents\Python\langchain\.venv\Lib\site-packages\langchain\chains\base.py:312, in Chain.__call__(self, inputs, return_only_outputs, callbacks, tags, metadata, run_name, include_run_info)
    310 except BaseException as e:
    311     run_manager.on_chain_error(e)
--> 312     raise e
    313 run_manager.on_chain_end(outputs)
    314 final_outputs: Dict[str, Any] = self.prep_outputs(
    315     inputs, outputs, return_only_outputs
    316 )

File ~\Documents\Python\langchain\.venv\Lib\site-packages\langchain\chains\base.py:306, in Chain.__call__(self, inputs, return_only_outputs, callbacks, tags, metadata, run_name, include_run_info)
    299 run_manager = callback_manager.on_chain_start(
    300     dumpd(self),
    301     inputs,
    302     name=run_name,
    303 )
    304 try:
    305     outputs = (
--> 306         self._call(inputs, run_manager=run_manager)
    307         if new_arg_supported
    308         else self._call(inputs)
    309     )
    310 except BaseException as e:
    311     run_manager.on_chain_error(e)

File ~\Documents\Python\langchain\.venv\Lib\site-packages\langchain\agents\agent.py:1126, in AgentExecutor._call(self, inputs, run_manager)
   1124 # We now enter the agent loop (until it returns something).
   1125 while self._should_continue(iterations, time_elapsed):
-> 1126     next_step_output = self._take_next_step(
   1127         name_to_tool_map,
   1128         color_mapping,
   1129         inputs,
   1130         intermediate_steps,
   1131         run_manager=run_manager,
   1132     )
   1133     if isinstance(next_step_output, AgentFinish):
   1134         return self._return(
   1135             next_step_output, intermediate_steps, run_manager=run_manager
   1136         )

File ~\Documents\Python\langchain\.venv\Lib\site-packages\langchain\agents\agent.py:981, in AgentExecutor._take_next_step(self, name_to_tool_map, color_mapping, inputs, intermediate_steps, run_manager)
    979         tool_run_kwargs["llm_prefix"] = ""
    980     # We then call the tool on the tool input to get an observation
--> 981     observation = tool.run(
    982         agent_action.tool_input,
    983         verbose=self.verbose,
    984         color=color,
    985         callbacks=run_manager.get_child() if run_manager else None,
    986         **tool_run_kwargs,
    987     )
    988 else:
    989     tool_run_kwargs = self.agent.tool_run_logging_kwargs()

File ~\Documents\Python\langchain\.venv\Lib\site-packages\langchain\tools\base.py:314, in BaseTool.run(self, tool_input, verbose, start_color, color, callbacks, tags, metadata, run_name, **kwargs)
    300 def run(
    301     self,
    302     tool_input: Union[str, Dict],
   (...)
    311     **kwargs: Any,
    312 ) -> Any:
    313     """Run the tool."""
--> 314     parsed_input = self._parse_input(tool_input)
    315     if not self.verbose and verbose is not None:
    316         verbose_ = verbose

File ~\Documents\Python\langchain\.venv\Lib\site-packages\langchain\tools\base.py:245, in BaseTool._parse_input(self, tool_input)
    243     if input_args is not None:
    244         key_ = next(iter(input_args.__fields__.keys()))
--> 245         input_args.validate({key_: tool_input})
    246     return tool_input
    247 else:

File ~\Documents\Python\langchain\.venv\Lib\site-packages\pydantic\v1\main.py:711, in BaseModel.validate(cls, value)
    708 value = cls._enforce_dict_if_root(value)
    710 if isinstance(value, dict):
--> 711     return cls(**value)
    712 elif cls.__config__.orm_mode:
    713     return cls.from_orm(value)

File ~\Documents\Python\langchain\.venv\Lib\site-packages\pydantic\v1\main.py:341, in BaseModel.__init__(__pydantic_self__, **data)
    339 values, fields_set, validation_error = validate_model(__pydantic_self__.__class__, data)
    340 if validation_error:
--> 341     raise validation_error
    342 try:
    343     object_setattr(__pydantic_self__, '__dict__', values)

ValidationError: 1 validation error for plot_graphSchemaSchema
y
  field required (type=value_error.missing)

How do I make this work?

0

There are 0 answers