Skip to content

Commit

Permalink
Add str -> str valid json mapping and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
ParthSareen committed Nov 14, 2024
1 parent e68700c commit f452fab
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ async def chat(
model: str = '',
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[True] = True,
format: Optional[Literal['', 'json']] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
Expand Down
9 changes: 8 additions & 1 deletion ollama/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ class Parameters(SubscriptableBaseModel):

class Property(SubscriptableBaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
type: Union[type, UnionType, Optional[T]]
type: Union[type, UnionType, Optional[T], str]
description: str

@model_serializer
Expand Down Expand Up @@ -458,15 +458,20 @@ def __init__(self, error: str, status_code: int = -1):
# Basic types
int: 'integer',
'int': 'integer',
'integer': 'integer',
str: 'string',
'str': 'string',
'string': 'string',
float: 'number',
'float': 'number',
'number': 'number',
bool: 'boolean',
'bool': 'boolean',
'boolean': 'boolean',
type(None): 'null',
None: 'null',
'None': 'null',
'null': 'null',
# Collection types
list: 'array',
'list': 'array',
Expand All @@ -481,13 +486,15 @@ def __init__(self, error: str, status_code: int = -1):
Set: 'array',
TypeSet: 'array',
'Set': 'array',
'array': 'array',
# Mapping types
dict: 'object',
'dict': 'object',
Dict: 'object',
'Dict': 'object',
Mapping: 'object',
'Mapping': 'object',
'object': 'object',
Any: 'string',
'Any': 'string',
}
Expand Down
20 changes: 17 additions & 3 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,11 +1012,25 @@ def func2(y: str) -> int:
assert list(_copy_tools([])) == []

# Test with mix of functions and tool dicts
tool_dict = {'type': 'function', 'function': {'name': 'test', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'string', 'description': 'A string'}}, 'required': ['x']}}}
tools = list(_copy_tools([func1, tool_dict]))
assert len(tools) == 2
tool_dict = {
'type': 'function',
'function': {
'name': 'test',
'description': 'Test function',
'parameters': {
'type': 'object',
'properties': {'x': {'type': 'string', 'description': 'A string'}},
'required': ['x'],
},
},
}

tool_json = json.loads(json.dumps(tool_dict))
tools = list(_copy_tools([func1, tool_dict, tool_json]))
assert len(tools) == 3
assert tools[0].function.name == 'func1'
assert tools[1].function.name == 'test'
assert tools[2].function.name == 'test'


def test_tool_validation():
Expand Down

0 comments on commit f452fab

Please sign in to comment.