Skip to content

Commit a789575

Browse files
committed
Improve tool call parsing, allow argument extraction from multiple keys
1 parent afa2d78 commit a789575

File tree

4 files changed: +72 -26 lines changed

src/smolagents/agents.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -602,7 +602,7 @@ def execute_tool_call(self, tool_name: str, arguments: Union[Dict[str, str], str
602602
elif tool_name in self.managed_agents:
603603
error_msg = (
604604
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
605-
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
605+
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name].description}"
606606
)
607607
raise AgentExecutionError(error_msg, self.logger)
608608

src/smolagents/models.py

+34 -16
Original file line number | Diff line number | Diff line change
@@ -116,14 +116,17 @@ def dict(self):
116116
return json.dumps(get_dict_from_nested_dataclasses(self))
117117

118118

119-
def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
119+
def parse_json_if_needed(arguments: Union[str, dict] | None) -> Union[str, dict] | None:
120+
if arguments is None:
121+
return None
122+
120123
if isinstance(arguments, dict):
121124
return arguments
122-
else:
123-
try:
124-
return json.loads(arguments)
125-
except Exception:
126-
return arguments
125+
126+
try:
127+
return json.loads(arguments)
128+
except Exception:
129+
return arguments
127130

128131

129132
class MessageRole(str, Enum):
@@ -228,16 +231,23 @@ def get_clean_message_list(
228231
return output_message_list
229232

230233

231-
def get_tool_call_from_text(text: str, tool_name_key: str, tool_arguments_key: str) -> ChatMessageToolCall:
234+
def get_tool_call_from_text(
235+
text: str, tool_name_key: str, tool_arguments_key: Union[list[str], str]
236+
) -> ChatMessageToolCall:
237+
if isinstance(tool_arguments_key, str):
238+
tool_arguments_key = [tool_arguments_key]
239+
232240
tool_call_dictionary, _ = parse_json_blob(text)
233241
try:
234242
tool_name = tool_call_dictionary[tool_name_key]
235-
except Exception as e:
243+
except KeyError as e:
236244
raise ValueError(
237245
f"Key {tool_name_key=} not found in the generated tool call. Got keys: {list(tool_call_dictionary.keys())} instead"
238246
) from e
239-
tool_arguments = tool_call_dictionary.get(tool_arguments_key, None)
247+
248+
tool_arguments = next((tool_call_dictionary[k] for k in tool_arguments_key if k in tool_call_dictionary), None)
240249
tool_arguments = parse_json_if_needed(tool_arguments)
250+
241251
return ChatMessageToolCall(
242252
id=str(uuid.uuid4()),
243253
type="function",
@@ -250,12 +260,16 @@ def __init__(
250260
self,
251261
flatten_messages_as_text: bool = False,
252262
tool_name_key: str = "name",
253-
tool_arguments_key: str = "arguments",
263+
tool_arguments_keys: list[str] | None = None,
254264
**kwargs,
255265
):
256266
self.flatten_messages_as_text = flatten_messages_as_text
257267
self.tool_name_key = tool_name_key
258-
self.tool_arguments_key = tool_arguments_key
268+
269+
self.tool_arguments_keys = tool_arguments_keys
270+
if tool_arguments_keys is None:
271+
self.tool_arguments_keys = ["arguments", "parameters"]
272+
259273
self.kwargs = kwargs
260274
self.last_input_token_count = None
261275
self.last_output_token_count = None
@@ -488,7 +502,7 @@ def __call__(
488502
)
489503
if tools_to_call_from:
490504
chat_message.tool_calls = [
491-
get_tool_call_from_text(output_text, self.tool_name_key, self.tool_arguments_key)
505+
get_tool_call_from_text(output_text, self.tool_name_key, self.tool_arguments_keys)
492506
]
493507
return chat_message
494508

@@ -586,6 +600,8 @@ def __call__(
586600
for _ in self.stream_generate(self.model, self.tokenizer, prompt=prompt_ids, **completion_kwargs):
587601
self.last_output_token_count += 1
588602
text += _.text
603+
604+
found_stop_sequence = False
589605
for stop_sequence in prepared_stop_sequences:
590606
stop_sequence_start = text.rfind(stop_sequence)
591607
if stop_sequence_start != -1:
@@ -804,7 +820,7 @@ def __call__(
804820
)
805821
if tools_to_call_from:
806822
chat_message.tool_calls = [
807-
get_tool_call_from_text(output_text, self.tool_name_key, self.tool_arguments_key)
823+
get_tool_call_from_text(output_text, self.tool_name_key, self.tool_arguments_keys)
808824
]
809825
return chat_message
810826

@@ -818,9 +834,11 @@ def postprocess_message(self, message: ChatMessage, tools_to_call_from) -> ChatM
818834
message.role = MessageRole.ASSISTANT # Overwrite role if needed
819835
if tools_to_call_from:
820836
if not message.tool_calls:
821-
message.tool_calls = [
822-
get_tool_call_from_text(message.content, self.tool_name_key, self.tool_arguments_key)
823-
]
837+
message.tool_calls = (
838+
[get_tool_call_from_text(message.content, self.tool_name_key, self.tool_arguments_keys)]
839+
if message.content
840+
else []
841+
)
824842
for tool_call in message.tool_calls:
825843
tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)
826844
return message

src/smolagents/utils.py

+31 -9
Original file line number | Diff line number | Diff line change
@@ -136,25 +136,47 @@ def make_json_serializable(obj: Any) -> Any:
136136

137137

138138
def parse_json_blob(json_blob: str) -> Tuple[Dict[str, str], str]:
139-
"Extracts the JSON blob from the input and returns the JSON data and the rest of the input."
139+
"Extracts the first valid JSON blob from the input and returns the JSON data and the rest of the input."
140+
first_accolade_index = json_blob.find("{")
141+
if first_accolade_index == -1:
142+
raise ValueError("No JSON object found in input.")
143+
144+
brace_count = 0
145+
end_index = None
146+
147+
# Iterate from first '{' to find balanced closing '}'
148+
for i in range(first_accolade_index, len(json_blob)):
149+
char = json_blob[i]
150+
if char == "{":
151+
brace_count += 1
152+
elif char == "}":
153+
brace_count -= 1
154+
if brace_count == 0:
155+
end_index = i
156+
break
157+
158+
if end_index is None:
159+
raise ValueError("No complete JSON object found: braces are not balanced.")
160+
161+
json_str = json_blob[first_accolade_index : end_index + 1]
140162
try:
141-
first_accolade_index = json_blob.find("{")
142-
last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
143-
json_data = json_blob[first_accolade_index : last_accolade_index + 1]
144-
json_data = json.loads(json_data, strict=False)
145-
return json_data, json_blob[:first_accolade_index]
163+
json_data = json.loads(json_str)
146164
except json.JSONDecodeError as e:
147165
place = e.pos
148-
if json_blob[place - 1 : place + 2] == "},\n":
166+
if json_str[place - 1 : place + 2] == "},\n":
149167
raise ValueError(
150168
"JSON is invalid: you probably tried to provide multiple tool calls in one action. PROVIDE ONLY ONE TOOL CALL."
151169
)
152170
raise ValueError(
153171
f"The JSON blob you used is invalid due to the following error: {e}.\n"
154-
f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n"
155-
f"'{json_blob[place - 4 : place + 5]}'."
172+
f"JSON blob was: {json_str}, decoding failed on that specific part of the blob:\n"
173+
f"'{json_str[place - 4 : place + 5]}'."
156174
)
157175

176+
# The prefix BEFORE JSON (if needed)
177+
prefix = json_blob[:first_accolade_index]
178+
return json_data, prefix
179+
158180

159181
def parse_code_blobs(text: str) -> str:
160182
"""Extract code blocs from the LLM's output.

tests/test_models.py

+6
Original file line number | Diff line number | Diff line change
@@ -457,3 +457,9 @@ def test_get_tool_call_from_text_numeric_args(self):
457457
result = get_tool_call_from_text(text, "name", "arguments")
458458
assert result.function.name == "calculator"
459459
assert result.function.arguments == 42
460+
461+
def test_get_tool_call_from_text_with_mutltiple_argument_keys(self):
462+
text = '{"name": "calculator", "parameters": 42}'
463+
result = get_tool_call_from_text(text, "name", ["arguments", "parameters"])
464+
assert result.function.name == "calculator"
465+
assert result.function.arguments == 42

0 commit comments

Comments (0)