From 0983791b343d5adaedfe6cac5484e9f67a9a135a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Sun, 2 Jun 2024 22:34:40 +0200
Subject: [PATCH 1/3] Support parallel function calls with tool_choice

---
 llama_cpp/llama_chat_format.py | 150 ++++++++++++++++++++-------------
 1 file changed, 91 insertions(+), 59 deletions(-)

diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 8f3b1de11..6d63029d2 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -330,10 +330,49 @@ def _convert_completion_to_chat_function(
     ],
     stream: bool,
 ):
+    def _completion_text_to_tool_calls(
+        tool_name: str,
+        completion_text: str,
+        completion_id: str
+        stream: bool,
+    ) -> Union[
+        llama_types.ChatCompletionMessageToolCalls, List[llama_types.ChatCompletionMessageToolCallChunk]
+    ]:
+        try:
+            function_calls = json.loads(completion_text)
+            assert isinstance(function_calls, list)
+        except Exception as e:
+            function_calls = []
+
+        i = 0
+        tool_calls = []
+        for function_call in function_calls:
+            function_name = function_call.get("name")
+            function_arguments = function_call.get("arguments")
+            if function_name == tool_name and function_arguments:
+                tool_id = f'call__{i}_{tool_name}_{completion_id}'
+                tool_call = {
+                    "id": tool_id,
+                    "type": "function",
+                    "function": {
+                        "name": tool_name,
+                        "arguments": json.dumps(function_arguments),
+                    },
+                }
+                if stream:
+                    tool_call["index"] = i
+                    typed_call: llama_types.ChatCompletionMessageToolCallChunk = tool_call
+                else:
+                    typed_call: llama_types.ChatCompletionMessageToolCall = tool_call
+                tool_calls.append(typed_call)
+                i += 1
+
+        return tool_calls
+
     if not stream:
         completion: llama_types.CreateCompletionResponse = completion_or_chunks  # type: ignore
         assert "usage" in completion
-        tool_id = "call_" + "_0_" + tool_name + "_" + completion["id"]
+        tool_calls: llama_types.ChatCompletionMessageToolCalls = _completion_text_to_tool_calls(tool_name, completion["choices"][0]["text"], completion["id"], stream)  # type: ignore  # TODO: Fix for legacy function calls
         chat_completion: llama_types.CreateChatCompletionResponse = {
             "id": "chat" + completion["id"],
@@ -345,24 +384,12 @@ def _convert_completion_to_chat_function(
                     "index": 0,
                     "message": {
                         "role": "assistant",
-                        "content": None,
-                        "function_call": {
-                            "name": tool_name,
-                            "arguments": completion["choices"][0]["text"],
-                        },
-                        "tool_calls": [
-                            {
-                                "id": tool_id,
-                                "type": "function",
-                                "function": {
-                                    "name": tool_name,
-                                    "arguments": completion["choices"][0]["text"],
-                                },
-                            }
-                        ],
+                        "content": None if tool_calls else completion["choices"][0]["text"],
+                        "function_call": tool_calls[0]["function"] if tool_calls else None,
+                        "tool_calls": tool_calls or None,
                     },
                     "logprobs": completion["choices"][0]["logprobs"],
-                    "finish_reason": "tool_calls",
+                    "finish_reason": "tool_calls" if tool_calls else completion["choices"][0]["finish_reason"],
                 }
             ],
             "usage": completion["usage"],
@@ -379,13 +406,15 @@ def _stream_response_to_function_stream(
         id_ = None
         created = None
         model = None
-        tool_id = None
+        finish = None
+        tools_called = ""
         for chunk in chunks:
+            tools_called += chunk["choices"][0]["text"]
+            finish = chunk["choices"][0]["finish_reason"]
             if first:
                 id_ = "chat" + chunk["id"]
                 created = chunk["created"]
                 model = chunk["model"]
-                tool_id = "call_" + "_0_" + tool_name + "_" + chunk["id"]
                 yield {
                     "id": id_,
                     "object": "chat.completion.chunk",
@@ -417,29 +446,15 @@ def _stream_response_to_function_stream(
                             "logprobs": chunk["choices"][0]["logprobs"],
                             "delta": {
                                 "role": None,
-                                "content": None,
-                                "function_call": {
-                                    "name": tool_name,
-                                    "arguments": chunk["choices"][0]["text"],
-                                },
-                                "tool_calls": [
-                                    {
-                                        "index": 0,
-                                        "id": tool_id,
-                                        "type": "function",
-                                        "function": {
-                                            "name": tool_name,
-                                            "arguments": chunk["choices"][0]["text"],
-                                        },
-                                    }
-                                ],
+                                "content": chunk["choices"][0]["text"],
+                                "function_call": None,
+                                "tool_calls": None,
                             },
                         }
                     ],
                 }
                 first = False
                 continue
-            assert tool_id is not None
             yield {
                 "id": "chat" + chunk["id"],
                 "object": "chat.completion.chunk",
@@ -452,30 +467,16 @@ def _stream_response_to_function_stream(
                         "logprobs": chunk["choices"][0]["logprobs"],
                         "delta": {
                             "role": None,
-                            "content": None,
-                            "function_call": {
-                                "name": tool_name,
-                                "arguments": chunk["choices"][0]["text"],
-                            },
-                            "tool_calls": [
-                                {
-                                    "index": 0,
-                                    "id": tool_id,
-                                    "type": "function",
-                                    "function": {
-                                        "name": tool_name,
-                                        "arguments": chunk["choices"][0][
-                                            "text"
-                                        ],
-                                    },
-                                }
-                            ],
+                            "content": chunk["choices"][0]["text"],
+                            "function_call": None,
+                            "tool_calls": None,
                         },
                     }
                 ],
             }

         if id_ is not None and created is not None and model is not None:
+            tool_calls: List[llama_types.ChatCompletionMessageToolCallChunk] = _completion_text_to_tool_calls(tool_name, tools_called, id_, stream)  # type: ignore
             yield {
                 "id": id_,
                 "object": "chat.completion.chunk",
@@ -484,13 +485,13 @@ def _stream_response_to_function_stream(
                 "choices": [
                     {
                         "index": 0,
-                        "finish_reason": "tool_calls",
+                        "finish_reason": "tool_calls" if tool_calls else finish,
                         "logprobs": None,
                         "delta": {
                             "role": None,
                             "content": None,
-                            "function_call": None,
-                            "tool_calls": None,
+                            "function_call": tool_calls[0]["function"] if tool_calls else None,
+                            "tool_calls": tool_calls or None,
                         },
                     }
                 ],
@@ -591,7 +592,22 @@ def chat_completion_handler(
             tool = next((t for t in tools if t["function"]["name"] == name), None)
             if tool is None:
                 raise ValueError(f"Tool choice '{name}' not found in tools.")
-            schema = tool["function"]["parameters"]
+            schema = {
+                "type": "array",
+                "items": {
+                    "type": "object",
+                    "properties": {
+                        "name": {
+                            "type": "string"
+                        },
+                        "arguments": tool["function"]["parameters"]
+                    },
+                    "required": [
+                        "name",
+                        "arguments"
+                    ]
+                }
+            }
             try:
                 # create grammar from json schema
                 grammar = llama_grammar.LlamaGrammar.from_json_schema(
@@ -3337,9 +3353,25 @@ def chatml_function_calling(
             add_generation_prompt=True,
         )
         prompt += f"functions.{tool_name}:\n"
+        schema = {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "name": {
+                        "type": "string"
+                    },
+                    "arguments": tool["function"]["parameters"]
+                },
+                "required": [
+                    "name",
+                    "arguments"
+                ]
+            }
+        }
         try:
             grammar = llama_grammar.LlamaGrammar.from_json_schema(
-                json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
+                json.dumps(schema), verbose=llama.verbose
             )
         except Exception as e:
             grammar = llama_grammar.LlamaGrammar.from_string(

From cc59985394f01e724938109640f482eedfad8f1d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Sun, 2 Jun 2024 22:58:46 +0200
Subject: [PATCH 2/3] typo--

---
 llama_cpp/llama_chat_format.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 6d63029d2..4019e0972 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -333,7 +333,7 @@ def _convert_completion_to_chat_function(
     def _completion_text_to_tool_calls(
         tool_name: str,
         completion_text: str,
-        completion_id: str
+        completion_id: str,
         stream: bool,
     ) -> Union[
         llama_types.ChatCompletionMessageToolCalls, List[llama_types.ChatCompletionMessageToolCallChunk]

From 7ed205eac4c10af3684957749885c008d09ad1a9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Sun, 2 Jun 2024 23:39:55 +0200
Subject: [PATCH 3/3] dump arguments as unicode

---
 llama_cpp/llama_chat_format.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 4019e0972..d046379c6 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -356,7 +356,7 @@ def _completion_text_to_tool_calls(
                     "type": "function",
                     "function": {
                         "name": tool_name,
-                        "arguments": json.dumps(function_arguments),
+                        "arguments": json.dumps(function_arguments, ensure_ascii=False),
                     },
                 }
                 if stream:
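
The schema introduced in both handlers constrains the model to emit a JSON array of {"name": ..., "arguments": ...} objects, and _completion_text_to_tool_calls maps each array element to its own OpenAI-style tool_calls entry. Below is a minimal usage sketch of the behaviour this series enables, assuming a local GGUF model and a hypothetical get_weather tool; the model path, tool definition, and prompt are illustrative only, the tools/tool_choice plumbing is what the patches add.

    # Usage sketch; model file, tool, and prompt are assumptions, not part of the patches.
    from llama_cpp import Llama

    llm = Llama(
        model_path="model.gguf",                 # assumed local model file
        chat_format="chatml-function-calling",   # handler patched above
    )

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",           # hypothetical tool
                "description": "Look up the current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ]

    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": "What is the weather in Oslo and Bergen?"}],
        tools=tools,
        tool_choice={"type": "function", "function": {"name": "get_weather"}},
    )

    # The grammar forces the raw completion into a JSON array such as
    #   [{"name": "get_weather", "arguments": {"city": "Oslo"}},
    #    {"name": "get_weather", "arguments": {"city": "Bergen"}}]
    # and _completion_text_to_tool_calls yields one tool_calls entry per element.
    for tool_call in response["choices"][0]["message"]["tool_calls"] or []:
        print(tool_call["id"], tool_call["function"]["name"], tool_call["function"]["arguments"])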