Skip to content

Commit bac939e

Browse files
committed
For TooCallingAction fix extraction of function arguments from response
1 parent ed51b6b commit bac939e

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

src/smolagents/models.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,27 @@ def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
126126
return arguments
127127

128128

129-
def parse_tool_args_if_needed(message: ChatMessage) -> ChatMessage:
129+
def parse_tool_args_if_needed(
130+
message: ChatMessage,
131+
tool_name_key: str = "name",
132+
tool_arguments_key: str = "arguments",
133+
) -> ChatMessage:
130134
if message.tool_calls is not None:
131135
for tool_call in message.tool_calls:
132136
tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)
137+
elif isinstance(message.content, str) and message.content.startswith("Action:"):
138+
output = message.content.split("Action:", 1)[1].strip()
139+
start_index = output.index("{")
140+
end_index = output.rindex("}")
141+
output = output[start_index : end_index + 1]
142+
143+
return get_tool_call_chat_message_from_text(
144+
output,
145+
tool_name_key,
146+
tool_arguments_key,
147+
)
148+
149+
__import__("pdb").set_trace()
133150
return message
134151

135152

@@ -490,7 +507,7 @@ def __call__(
490507
self.last_output_token_count = response.usage.completion_tokens
491508
message = ChatMessage.from_hf_api(response.choices[0].message, raw=response)
492509
if tools_to_call_from is not None:
493-
return parse_tool_args_if_needed(message)
510+
return parse_tool_args_if_needed(message, self.tool_name_key, self.tool_arguments_key)
494511
return message
495512

496513

@@ -685,6 +702,8 @@ def __call__(
685702
for _ in self.stream_generate(self.model, self.tokenizer, prompt=prompt_ids, **completion_kwargs):
686703
self.last_output_token_count += 1
687704
text += _.text
705+
706+
found_stop_sequence = False
688707
for stop_sequence in prepared_stop_sequences:
689708
stop_sequence_start = text.rfind(stop_sequence)
690709
if stop_sequence_start != -1:
@@ -990,7 +1009,7 @@ def __call__(
9901009
message.raw = response
9911010

9921011
if tools_to_call_from is not None:
993-
return parse_tool_args_if_needed(message)
1012+
return parse_tool_args_if_needed(message, self.tool_name_key, self.tool_arguments_key)
9941013
return message
9951014

9961015

@@ -1076,7 +1095,7 @@ def __call__(
10761095
)
10771096
message.raw = response
10781097
if tools_to_call_from is not None:
1079-
return parse_tool_args_if_needed(message)
1098+
return parse_tool_args_if_needed(message, self.tool_name_key, self.tool_arguments_key)
10801099
return message
10811100

10821101

tests/test_models.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_transformers_message_vl_no_tool(self):
109109
assert output == "Hello! How can"
110110

111111
def test_parse_tool_args_if_needed(self):
112-
original_message = ChatMessage(role="user", content=[{"type": "text", "text": "Hello!"}])
112+
original_message = ChatMessage(role="user", content={"type": "text", "text": "Hello!"})
113113
parsed_message = parse_tool_args_if_needed(original_message)
114114
assert parsed_message == original_message
115115

@@ -365,3 +365,16 @@ def test_flatten_messages_as_text_for_all_models(
365365

366366
model = model_class(**{"model_id": "test-model", **model_kwargs})
367367
assert model.flatten_messages_as_text is expected_flatten_messages_as_text, f"{model_class.__name__} failed"
368+
369+
370+
@pytest.mark.parametrize(
371+
"content",
372+
[
373+
'Action:\n{"name": "tool", "arguments": {}}',
374+
'Action: {"name": "tool", "arguments": {}}',
375+
],
376+
)
377+
def test_parse_tool_args_if_needed_returns_message_with_tools_set_for_tool_calling_agent(content):
378+
original_message = ChatMessage(role="user", content=content)
379+
parsed_message = parse_tool_args_if_needed(original_message)
380+
assert parsed_message.tool_calls

0 commit comments

Comments
 (0)