@@ -116,14 +116,17 @@ def dict(self):
116
116
return json .dumps (get_dict_from_nested_dataclasses (self ))
117
117
118
118
119
- def parse_json_if_needed (arguments : Union [str , dict ]) -> Union [str , dict ]:
119
+ def parse_json_if_needed (arguments : Union [str , dict ] | None ) -> Union [str , dict ] | None :
120
+ if arguments is None :
121
+ return None
122
+
120
123
if isinstance (arguments , dict ):
121
124
return arguments
122
- else :
123
- try :
124
- return json .loads (arguments )
125
- except Exception :
126
- return arguments
125
+
126
+ try :
127
+ return json .loads (arguments )
128
+ except Exception :
129
+ return arguments
127
130
128
131
129
132
class MessageRole (str , Enum ):
@@ -228,16 +231,23 @@ def get_clean_message_list(
228
231
return output_message_list
229
232
230
233
231
- def get_tool_call_from_text (text : str , tool_name_key : str , tool_arguments_key : str ) -> ChatMessageToolCall :
234
+ def get_tool_call_from_text (
235
+ text : str , tool_name_key : str , tool_arguments_key : Union [list [str ], str ]
236
+ ) -> ChatMessageToolCall :
237
+ if isinstance (tool_arguments_key , str ):
238
+ tool_arguments_key = [tool_arguments_key ]
239
+
232
240
tool_call_dictionary , _ = parse_json_blob (text )
233
241
try :
234
242
tool_name = tool_call_dictionary [tool_name_key ]
235
- except Exception as e :
243
+ except KeyError as e :
236
244
raise ValueError (
237
245
f"Key { tool_name_key = } not found in the generated tool call. Got keys: { list (tool_call_dictionary .keys ())} instead"
238
246
) from e
239
- tool_arguments = tool_call_dictionary .get (tool_arguments_key , None )
247
+
248
+ tool_arguments = next ((tool_call_dictionary [k ] for k in tool_arguments_key if k in tool_call_dictionary ), None )
240
249
tool_arguments = parse_json_if_needed (tool_arguments )
250
+
241
251
return ChatMessageToolCall (
242
252
id = str (uuid .uuid4 ()),
243
253
type = "function" ,
@@ -250,12 +260,16 @@ def __init__(
250
260
self ,
251
261
flatten_messages_as_text : bool = False ,
252
262
tool_name_key : str = "name" ,
253
- tool_arguments_key : str = "arguments" ,
263
+ tool_arguments_keys : list [ str ] | None = None ,
254
264
** kwargs ,
255
265
):
256
266
self .flatten_messages_as_text = flatten_messages_as_text
257
267
self .tool_name_key = tool_name_key
258
- self .tool_arguments_key = tool_arguments_key
268
+
269
+ self .tool_arguments_keys = tool_arguments_keys
270
+ if tool_arguments_keys is None :
271
+ self .tool_arguments_keys = ["arguments" , "parameters" ]
272
+
259
273
self .kwargs = kwargs
260
274
self .last_input_token_count = None
261
275
self .last_output_token_count = None
@@ -488,7 +502,7 @@ def __call__(
488
502
)
489
503
if tools_to_call_from :
490
504
chat_message .tool_calls = [
491
- get_tool_call_from_text (output_text , self .tool_name_key , self .tool_arguments_key )
505
+ get_tool_call_from_text (output_text , self .tool_name_key , self .tool_arguments_keys )
492
506
]
493
507
return chat_message
494
508
@@ -586,6 +600,8 @@ def __call__(
586
600
for _ in self .stream_generate (self .model , self .tokenizer , prompt = prompt_ids , ** completion_kwargs ):
587
601
self .last_output_token_count += 1
588
602
text += _ .text
603
+
604
+ found_stop_sequence = False
589
605
for stop_sequence in prepared_stop_sequences :
590
606
stop_sequence_start = text .rfind (stop_sequence )
591
607
if stop_sequence_start != - 1 :
@@ -804,7 +820,7 @@ def __call__(
804
820
)
805
821
if tools_to_call_from :
806
822
chat_message .tool_calls = [
807
- get_tool_call_from_text (output_text , self .tool_name_key , self .tool_arguments_key )
823
+ get_tool_call_from_text (output_text , self .tool_name_key , self .tool_arguments_keys )
808
824
]
809
825
return chat_message
810
826
@@ -818,9 +834,11 @@ def postprocess_message(self, message: ChatMessage, tools_to_call_from) -> ChatM
818
834
message .role = MessageRole .ASSISTANT # Overwrite role if needed
819
835
if tools_to_call_from :
820
836
if not message .tool_calls :
821
- message .tool_calls = [
822
- get_tool_call_from_text (message .content , self .tool_name_key , self .tool_arguments_key )
823
- ]
837
+ message .tool_calls = (
838
+ [get_tool_call_from_text (message .content , self .tool_name_key , self .tool_arguments_keys )]
839
+ if message .content
840
+ else []
841
+ )
824
842
for tool_call in message .tool_calls :
825
843
tool_call .function .arguments = parse_json_if_needed (tool_call .function .arguments )
826
844
return message
0 commit comments