Skip to content

ToolCalling actions are no longer being recognised from model response #992

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 32 additions & 16 deletions src/smolagents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,17 @@ def from_hf_api(cls, message: "ChatCompletionOutputMessage", raw) -> "ChatMessag
return cls.from_dict(asdict(message), raw=raw)


def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
def parse_json_if_needed(arguments: Union[str, dict] | None) -> Union[str, dict] | None:
if arguments is None:
return None

if isinstance(arguments, dict):
return arguments
else:
try:
return json.loads(arguments)
except Exception:
return arguments

try:
return json.loads(arguments)
except Exception:
return arguments


class MessageRole(str, Enum):
Expand Down Expand Up @@ -241,16 +244,23 @@ def get_clean_message_list(
return output_message_list


def get_tool_call_from_text(text: str, tool_name_key: str, tool_arguments_key: str) -> ChatMessageToolCall:
def get_tool_call_from_text(
text: str, tool_name_key: str, tool_arguments_key: Union[list[str], str]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understood correctly, the last argument should be `tool_arguments_keys` instead of `tool_arguments_key`.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to keep the name so as not to break anything. As you can see, it accepts either a list or a single value.

) -> ChatMessageToolCall:
if isinstance(tool_arguments_key, str):
tool_arguments_key = [tool_arguments_key]

tool_call_dictionary, _ = parse_json_blob(text)
try:
tool_name = tool_call_dictionary[tool_name_key]
except Exception as e:
except KeyError as e:
raise ValueError(
f"Key {tool_name_key=} not found in the generated tool call. Got keys: {list(tool_call_dictionary.keys())} instead"
) from e
tool_arguments = tool_call_dictionary.get(tool_arguments_key, None)

tool_arguments = next((tool_call_dictionary[k] for k in tool_arguments_key if k in tool_call_dictionary), None)
tool_arguments = parse_json_if_needed(tool_arguments)

return ChatMessageToolCall(
id=str(uuid.uuid4()),
type="function",
Expand All @@ -263,12 +273,16 @@ def __init__(
self,
flatten_messages_as_text: bool = False,
tool_name_key: str = "name",
tool_arguments_key: str = "arguments",
tool_arguments_keys: list[str] | None = None,
**kwargs,
):
self.flatten_messages_as_text = flatten_messages_as_text
self.tool_name_key = tool_name_key
self.tool_arguments_key = tool_arguments_key

self.tool_arguments_keys = tool_arguments_keys
if tool_arguments_keys is None:
self.tool_arguments_keys = ["arguments", "parameters"]

self.kwargs = kwargs
self.last_input_token_count = None
self.last_output_token_count = None
Expand Down Expand Up @@ -511,7 +525,7 @@ def __call__(
)
if tools_to_call_from:
chat_message.tool_calls = [
get_tool_call_from_text(output_text, self.tool_name_key, self.tool_arguments_key)
get_tool_call_from_text(output_text, self.tool_name_key, self.tool_arguments_keys)
]
return chat_message

Expand Down Expand Up @@ -822,7 +836,7 @@ def __call__(
)
if tools_to_call_from:
chat_message.tool_calls = [
get_tool_call_from_text(output_text, self.tool_name_key, self.tool_arguments_key)
get_tool_call_from_text(output_text, self.tool_name_key, self.tool_arguments_keys)
]
return chat_message

Expand Down Expand Up @@ -862,9 +876,11 @@ def postprocess_message(self, message: ChatMessage, tools_to_call_from) -> ChatM
message.role = MessageRole.ASSISTANT # Overwrite role if needed
if tools_to_call_from:
if not message.tool_calls:
message.tool_calls = [
get_tool_call_from_text(message.content, self.tool_name_key, self.tool_arguments_key)
]
message.tool_calls = (
[get_tool_call_from_text(message.content, self.tool_name_key, self.tool_arguments_keys)]
if message.content
else []
)
for tool_call in message.tool_calls:
tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)
return message
Expand Down
102 changes: 84 additions & 18 deletions src/smolagents/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,27 +149,93 @@ def make_json_serializable(obj: Any) -> Any:
return str(obj)


def parse_json_blob(json_blob: str) -> Tuple[Dict[str, str], str]:
"Extracts the JSON blob from the input and returns the JSON data and the rest of the input."
def parse_json_blob(json_blob: str) -> Tuple[Dict[str, Any], str]:
"""Extracts the first valid JSON blob from the input and returns the JSON data and the prefix.

Args:
json_blob: String containing a JSON object, possibly with text before and/or after.

Returns:
Tuple of (parsed JSON dict, text before the JSON object)

Raises:
ValueError: If no valid JSON object is found or JSON parsing fails
"""
# Find the first opening brace
first_accolade_index = json_blob.find("{")
if first_accolade_index == -1:
raise ValueError("No JSON object found in input.")

# Track balanced braces
brace_count = 0
in_string = False
escape_char = False
i = first_accolade_index

try:
first_accolade_index = json_blob.find("{")
last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
json_data = json_blob[first_accolade_index : last_accolade_index + 1]
json_data = json.loads(json_data, strict=False)
return json_data, json_blob[:first_accolade_index]
except IndexError:
raise ValueError("The model output does not contain any JSON blob.")
except json.JSONDecodeError as e:
place = e.pos
if json_blob[place - 1 : place + 2] == "},\n":
# Scan for the balanced closing brace
while i < len(json_blob):
char = json_blob[i]

if in_string:
if escape_char:
escape_char = False
elif char == "\\":
escape_char = True
elif char == '"':
in_string = False
else:
if char == '"':
in_string = True
elif char == "{":
brace_count += 1
elif char == "}":
brace_count -= 1
if brace_count == 0:
break

i += 1

# Check if we reached the end without finding a balanced close
if i >= len(json_blob) or brace_count != 0:
if in_string:
raise ValueError("Incomplete JSON object: unclosed string literal.")
elif brace_count > 0:
raise ValueError(f"Incomplete JSON object: missing {brace_count} closing braces.")
else:
# This shouldn't happen with our algorithm, but just in case
raise ValueError("Malformed JSON: unbalanced braces.")

# Extract the complete JSON string
json_str = json_blob[first_accolade_index : i + 1]

# Handle the case where there's another JSON object immediately after
if i + 1 < len(json_blob) and json_blob[i + 1] == "{":
# This is likely valid - we extracted the first complete JSON object
pass

# Try to parse the JSON
try:
json_data = json.loads(json_str)
return json_data, json_blob[:first_accolade_index]
except json.JSONDecodeError as e:
context_range = 10
context_start = max(0, e.pos - context_range)
context_end = min(len(json_str), e.pos + context_range)
position_marker = " " * (min(e.pos, context_range)) + "^"

raise ValueError(
"JSON is invalid: you probably tried to provide multiple tool calls in one action. PROVIDE ONLY ONE TOOL CALL."
f"Invalid JSON: {e.msg} at position {e.pos}.\n"
f"Context:\n{json_str[context_start:context_end]}\n{position_marker}\n"
f"Make sure your JSON is properly formatted."
)
raise ValueError(
f"The JSON blob you used is invalid due to the following error: {e}.\n"
f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n"
f"'{json_blob[place - 4 : place + 5]}'."
)

except IndexError:
# Handle case where we go out of bounds
if brace_count > 0:
raise ValueError(f"Incomplete JSON object: missing {brace_count} closing braces.")
else:
raise ValueError("Malformed JSON: unexpected end of input.")


def parse_code_blobs(text: str) -> str:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,3 +475,9 @@ def test_get_tool_call_from_text_numeric_args(self):
result = get_tool_call_from_text(text, "name", "arguments")
assert result.function.name == "calculator"
assert result.function.arguments == 42

def test_get_tool_call_from_text_with_mutltiple_argument_keys(self):
text = '{"name": "calculator", "parameters": 42}'
result = get_tool_call_from_text(text, "name", ["arguments", "parameters"])
assert result.function.name == "calculator"
assert result.function.arguments == 42
133 changes: 108 additions & 25 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,58 +423,141 @@ def forward(self, task: str) -> str:
@pytest.mark.parametrize(
"raw_json, expected_data, expected_blob",
[
# Nested objects
(
"""{}""",
{},
"""{"outer": {"inner": "value"}}""",
{"outer": {"inner": "value"}},
"",
),
# Multiple nested levels with arrays
(
"""Text{}""",
{},
"Text",
"""{"level1": {"level2": {"level3": [1, 2, {"key": "value"}]}}}""",
{"level1": {"level2": {"level3": [1, 2, {"key": "value"}]}}},
"",
),
# String containing braces
(
"""{"simple": "json"}""",
{"simple": "json"},
"""{"text": "This is {not} a real brace"}""",
{"text": "This is {not} a real brace"},
"",
),
# Multiple JSON-like strings in content, should get first valid one
(
"""Text before {"first": "json"} {"second": "json"}""",
{"first": "json"},
"Text before ",
),
# Escaped quotes in strings
(
"""{"key": "value with \\"quotes\\""}""",
{"key": 'value with "quotes"'},
"",
),
# JSON with whitespace
(
""" {
"spaced": "content"
} """,
{"spaced": "content"},
" ",
),
# Empty object with prefix/suffix
(
"""With text here{"simple": "json"}""",
{"simple": "json"},
"With text here",
"""Action: {} More text""",
{},
"Action: ",
),
# Special characters in string values
(
"""{"simple": "json"}With text after""",
{"simple": "json"},
"""{"special": "\\n\\t\\r"}""",
{"special": "\n\t\r"},
"",
),
# Unicode characters
(
"""With text before{"simple": "json"}And text after""",
{"simple": "json"},
"With text before",
"""{"unicode": "你好世界"}""",
{"unicode": "你好世界"},
"",
),
# Complex keys with quotes
(
"""{"quoted key \\"in\\" brackets": "value"}""",
{'quoted key "in" brackets': "value"},
"",
),
],
)
def test_parse_json_blob_with_valid_json(raw_json, expected_data, expected_blob):
def test_parse_json_blob_advanced_valid_cases(raw_json, expected_data, expected_blob):
data, blob = parse_json_blob(raw_json)

assert data == expected_data
assert blob == expected_blob


@pytest.mark.parametrize(
"raw_json",
"raw_json, expected_exception_text",
[
"""simple": "json"}""",
"""With text here"simple": "json"}""",
"""{"simple": ""json"}With text after""",
"""{"simple": "json"With text after""",
"}}",
# Unclosed JSON object
(
"""{"unclosed": "object""",
"Incomplete JSON object: unclosed string literal",
),
# Unclosed string literal - gets caught by JSON decoder
(
"""{"unclosed: "string}""",
"Invalid JSON: Expecting",
),
# Unterminated escape sequence
(
"""{"bad_escape": "\\""",
"Incomplete JSON object: unclosed string literal",
),
# No JSON at all
(
"""Just plain text without any JSON""",
"No JSON object found in input",
),
# Invalid JSON syntax - missing colon
(
"""{"missing" "colon"}""",
"Invalid JSON",
),
# Invalid JSON syntax - trailing comma
(
"""{"trailing": "comma", }""",
"Invalid JSON",
),
# Deeply nested unclosed braces
(
"""{"level1": {"level2": {"level3": {"level4": "value"}}""",
"Incomplete JSON object: missing 2 closing braces",
),
# Mixed unclosed string and brace
(
"""{"mixed": "problem{"nested": true}""",
"Incomplete JSON object: unclosed string literal",
),
],
)
def test_parse_json_blob_with_invalid_json(raw_json):
with pytest.raises(Exception):
def test_parse_json_blob_advanced_invalid_cases(raw_json, expected_exception_text):
with pytest.raises(ValueError) as excinfo:
parse_json_blob(raw_json)
assert expected_exception_text in str(excinfo.value)


def test_parse_json_blob_with_unicode_surrogate_pairs():
"""Test correct handling of Unicode surrogate pairs"""
# This tests emoji and other characters that use surrogate pairs
json_str = '{"emoji": "😀🌍🚀"}'
data, _ = parse_json_blob(json_str)
assert data["emoji"] == "😀🌍🚀"


def test_parse_json_blob_with_special_characters_in_keys():
"""Test parsing JSON with special characters in keys"""
json_str = '{"special\\nkey": "value"}'
data, _ = parse_json_blob(json_str)
assert "special\nkey" in data
assert data["special\nkey"] == "value"


@pytest.mark.parametrize(
Expand Down