@@ -126,10 +126,27 @@ def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
126
126
return arguments
127
127
128
128
129
- def parse_tool_args_if_needed (message : ChatMessage ) -> ChatMessage :
129
+ def parse_tool_args_if_needed (
130
+ message : ChatMessage ,
131
+ tool_name_key : str = "name" ,
132
+ tool_arguments_key : str = "arguments" ,
133
+ ) -> ChatMessage :
130
134
if message .tool_calls is not None :
131
135
for tool_call in message .tool_calls :
132
136
tool_call .function .arguments = parse_json_if_needed (tool_call .function .arguments )
137
+ elif isinstance (message .content , str ) and message .content .startswith ("Action:" ):
138
+ output = message .content .split ("Action:" , 1 )[1 ].strip ()
139
+ start_index = output .index ("{" )
140
+ end_index = output .rindex ("}" )
141
+ output = output [start_index : end_index + 1 ]
142
+
143
+ return get_tool_call_chat_message_from_text (
144
+ output ,
145
+ tool_name_key ,
146
+ tool_arguments_key ,
147
+ )
148
+
149
+ __import__ ("pdb" ).set_trace ()
133
150
return message
134
151
135
152
@@ -490,7 +507,7 @@ def __call__(
490
507
self .last_output_token_count = response .usage .completion_tokens
491
508
message = ChatMessage .from_hf_api (response .choices [0 ].message , raw = response )
492
509
if tools_to_call_from is not None :
493
- return parse_tool_args_if_needed (message )
510
+ return parse_tool_args_if_needed (message , self . tool_name_key , self . tool_arguments_key )
494
511
return message
495
512
496
513
@@ -685,6 +702,8 @@ def __call__(
685
702
for _ in self .stream_generate (self .model , self .tokenizer , prompt = prompt_ids , ** completion_kwargs ):
686
703
self .last_output_token_count += 1
687
704
text += _ .text
705
+
706
+ found_stop_sequence = False
688
707
for stop_sequence in prepared_stop_sequences :
689
708
stop_sequence_start = text .rfind (stop_sequence )
690
709
if stop_sequence_start != - 1 :
@@ -990,7 +1009,7 @@ def __call__(
990
1009
message .raw = response
991
1010
992
1011
if tools_to_call_from is not None :
993
- return parse_tool_args_if_needed (message )
1012
+ return parse_tool_args_if_needed (message , self . tool_name_key , self . tool_arguments_key )
994
1013
return message
995
1014
996
1015
@@ -1076,7 +1095,7 @@ def __call__(
1076
1095
)
1077
1096
message .raw = response
1078
1097
if tools_to_call_from is not None :
1079
- return parse_tool_args_if_needed (message )
1098
+ return parse_tool_args_if_needed (message , self . tool_name_key , self . tool_arguments_key )
1080
1099
return message
1081
1100
1082
1101
0 commit comments