Skip to content

Commit

Permalink
Code cleanup and organization
Browse files Browse the repository at this point in the history
  • Loading branch information
ParthSareen committed Nov 14, 2024
1 parent f452fab commit ca16670
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 150 deletions.
111 changes: 111 additions & 0 deletions ollama/_json_type_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import sys
from typing import Any, List, Mapping, Optional, Sequence, Union, get_origin, get_args
from collections.abc import Set
from typing import Dict, Set as TypeSet, TypeVar

T = TypeVar('T')
if sys.version_info >= (3, 10):
from types import UnionType
else:
UnionType = Union[T]

# Python doesn't have a type serializer, so we need to map types to JSON types
TYPE_MAP = {
# Basic types
int: 'integer',
'int': 'integer',
'integer': 'integer',
str: 'string',
'str': 'string',
'string': 'string',
float: 'number',
'float': 'number',
'number': 'number',
bool: 'boolean',
'bool': 'boolean',
'boolean': 'boolean',
type(None): 'null',
None: 'null',
'None': 'null',
'null': 'null',
# Collection types
list: 'array',
'list': 'array',
List: 'array',
'List': 'array',
Sequence: 'array',
'Sequence': 'array',
tuple: 'array',
'tuple': 'array',
set: 'array',
'set': 'array',
Set: 'array',
TypeSet: 'array',
'Set': 'array',
'array': 'array',
# Mapping types
dict: 'object',
'dict': 'object',
Dict: 'object',
'Dict': 'object',
Mapping: 'object',
'Mapping': 'object',
'object': 'object',
Any: 'string',
'Any': 'string',
}

if sys.version_info >= (3, 10):
from types import UnionType

def is_union(tp: Any) -> bool:
return get_origin(tp) in (Union, UnionType)
else:

def is_union(tp: Any) -> bool:
return get_origin(tp) is Union


def _map_type(python_type: Any) -> str:
# Handle generic types (List[int], Dict[str, int], etc.)
origin = get_origin(python_type)
if origin is not None:
# Get the base type (List, Dict, etc.)
base_type = TYPE_MAP.get(origin, None)
if base_type:
return base_type
# If it's a subclass of known abstract base classes, map to appropriate type
if isinstance(origin, type):
if issubclass(origin, (list, Sequence, tuple, set, Set)):
return 'array'
if issubclass(origin, (dict, Mapping)):
return 'object'

# Handle both type objects and type references (older Python versions)
type_key = python_type
if isinstance(python_type, type):
type_key = python_type
elif isinstance(python_type, str):
type_key = python_type

# If type not found in map, try to get the type name
if type_key not in TYPE_MAP and hasattr(python_type, '__name__'):
type_key = python_type.__name__

if type_key in TYPE_MAP:
return TYPE_MAP[type_key]

raise ValueError(f'Could not map Python type {python_type} to a valid JSON type')


def get_json_type(python_type: Union[type, UnionType, Optional[T]]) -> Union[str, List[str]]:
# Handle Optional types (Union[type, None] and type | None)
if is_union(python_type):
args = get_args(python_type)
# Filter out None/NoneType from union args
if non_none_args := [arg for arg in args if arg not in (None, type(None))]:
if len(non_none_args) == 1:
return _map_type(non_none_args[0])
# For multiple return types (e.g., int | str | None), return stringified array of types -> "['integer', 'string', 'null']"
return str([_map_type(arg) for arg in non_none_args]).replace(' ', '')
return _map_type(python_type)
124 changes: 4 additions & 120 deletions ollama/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from base64 import b64encode
from pathlib import Path
from datetime import datetime
from typing import Any, List, Mapping, Optional, TypeVar, Union, Sequence, get_args, get_origin
from collections.abc import Set
from typing import Dict, Set as TypeSet
from typing import Any, Mapping, Optional, Union, Sequence

from ollama._json_type_map import get_json_type, T, UnionType

import sys
from typing_extensions import Annotated, Literal

from pydantic import (
Expand All @@ -19,12 +18,6 @@
model_serializer,
)

T = TypeVar('T')
if sys.version_info >= (3, 10):
from types import UnionType
else:
UnionType = Union[T]


class SubscriptableBaseModel(BaseModel):
def __getitem__(self, key: str) -> Any:
Expand Down Expand Up @@ -242,7 +235,7 @@ class Property(SubscriptableBaseModel):

@model_serializer
def serialize_model(self) -> dict:
return {'type': _get_json_type(self.type), 'description': self.description}
return {'type': get_json_type(self.type), 'description': self.description}

properties: Optional[Mapping[str, Property]] = None

Expand Down Expand Up @@ -451,112 +444,3 @@ def __init__(self, error: str, status_code: int = -1):

self.status_code = status_code
'HTTP status code of the response.'


# Python doesn't have a type serializer, so we need to map types to JSON types
TYPE_MAP = {
# Basic types
int: 'integer',
'int': 'integer',
'integer': 'integer',
str: 'string',
'str': 'string',
'string': 'string',
float: 'number',
'float': 'number',
'number': 'number',
bool: 'boolean',
'bool': 'boolean',
'boolean': 'boolean',
type(None): 'null',
None: 'null',
'None': 'null',
'null': 'null',
# Collection types
list: 'array',
'list': 'array',
List: 'array',
'List': 'array',
Sequence: 'array',
'Sequence': 'array',
tuple: 'array',
'tuple': 'array',
set: 'array',
'set': 'array',
Set: 'array',
TypeSet: 'array',
'Set': 'array',
'array': 'array',
# Mapping types
dict: 'object',
'dict': 'object',
Dict: 'object',
'Dict': 'object',
Mapping: 'object',
'Mapping': 'object',
'object': 'object',
Any: 'string',
'Any': 'string',
}

if sys.version_info >= (3, 10):
from types import UnionType

def is_union(tp: Any) -> bool:
return get_origin(tp) in (Union, UnionType)
else:

def is_union(tp: Any) -> bool:
return get_origin(tp) is Union


def map_type(python_type: Any) -> str:
# Handle generic types (List[int], Dict[str, int], etc.)
origin = get_origin(python_type)
if origin is not None:
# Get the base type (List, Dict, etc.)
base_type = TYPE_MAP.get(origin, None)
if base_type:
return base_type
# If it's a subclass of known abstract base classes, map to appropriate type
if isinstance(origin, type):
if issubclass(origin, (list, Sequence, tuple, set, Set)):
return 'array'
if issubclass(origin, (dict, Mapping)):
return 'object'

# Handle both type objects and type references (older Python versions)
type_key = python_type
if isinstance(python_type, type):
type_key = python_type
elif isinstance(python_type, str):
type_key = python_type

# If type not found in map, try to get the type name
if type_key not in TYPE_MAP and hasattr(python_type, '__name__'):
type_key = python_type.__name__

if type_key in TYPE_MAP:
return TYPE_MAP[type_key]

raise ValueError(f'Could not map Python type {python_type} to a valid JSON type')


def _get_json_type(python_type: Union[type, UnionType, Optional[T]]) -> Union[str, List[str]]:
# Handle Optional types (Union[type, None] and type | None)
if is_union(python_type):
args = get_args(python_type)
# Filter out None/NoneType from union args
if non_none_args := [arg for arg in args if arg not in (None, type(None))]:
if len(non_none_args) == 1:
return map_type(non_none_args[0])
# For multiple return types (e.g., int | str | None), return stringified array of types -> "['integer', 'string', 'null']"
return str([map_type(arg) for arg in non_none_args]).replace(' ', '')
return map_type(python_type)


def _is_optional_type(python_type: Any) -> bool:
if is_union(python_type):
args = get_args(python_type)
return any(arg in (None, type(None)) for arg in args)
return False
27 changes: 19 additions & 8 deletions ollama/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import Callable
from ollama._types import Tool, _is_optional_type
from typing import Any, Callable, get_args
from ollama._json_type_map import is_union
from ollama._types import Tool
from typing import Dict


Expand Down Expand Up @@ -76,6 +77,13 @@ def _parse_docstring(func: Callable, doc_string: str) -> tuple[str, Dict[str, st
return description, param_descriptions


def is_optional_type(python_type: Any) -> bool:
if is_union(python_type):
args = get_args(python_type)
return any(arg in (None, type(None)) for arg in args)
return False


def convert_function_to_tool(func: Callable) -> Tool:
doc_string = func.__doc__
if not doc_string:
Expand All @@ -92,11 +100,14 @@ def convert_function_to_tool(func: Callable) -> Tool:
parameters.properties[param_name] = Tool.Function.Parameters.Property(type=param_type, description=param_descriptions[param_name])

# Only add to required if not optional
if not _is_optional_type(param_type):
if not is_optional_type(param_type):
parameters.required.append(param_name)

function = Tool.Function(name=func.__name__, description=description, parameters=parameters, return_type=None)

tool = Tool(function=function)

return tool
return Tool(
function=Tool.Function(
name=func.__name__,
description=description,
parameters=parameters,
return_type=None,
)
)
44 changes: 22 additions & 22 deletions tests/test_type_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from ollama._types import Image, _get_json_type
from ollama._types import Image, get_json_type


def test_image_serialization():
Expand All @@ -21,49 +21,49 @@ def test_image_serialization():

def test_json_type_conversion():
# Test basic types
assert _get_json_type(List) == 'array'
assert _get_json_type(Dict) == 'object'
assert get_json_type(List) == 'array'
assert get_json_type(Dict) == 'object'


def test_advanced_json_type_conversion():
from typing import Optional, Union, List, Dict, Sequence, Mapping, Set, Tuple, Any

# Test nested collections
assert _get_json_type(List[List[int]]) == 'array'
assert _get_json_type(Dict[str, List[int]]) == 'object'
assert get_json_type(List[List[int]]) == 'array'
assert get_json_type(Dict[str, List[int]]) == 'object'

# Test multiple unions
result = _get_json_type(Union[int, str, float])
result = get_json_type(Union[int, str, float])
# Remove brackets from start/end
result = result[1:-1] if result.startswith('[') else result
assert set(x.strip().strip("'") for x in result.split(',')) == {'integer', 'string', 'number'}

# Test collections.abc types
assert _get_json_type(Sequence[int]) == 'array'
assert _get_json_type(Mapping[str, int]) == 'object'
assert _get_json_type(Set[int]) == 'array'
assert _get_json_type(Tuple[int, str]) == 'array'
assert get_json_type(Sequence[int]) == 'array'
assert get_json_type(Mapping[str, int]) == 'object'
assert get_json_type(Set[int]) == 'array'
assert get_json_type(Tuple[int, str]) == 'array'

# Test nested optionals
assert _get_json_type(Optional[List[Optional[int]]]) == 'array'
assert get_json_type(Optional[List[Optional[int]]]) == 'array'

# Test edge cases
assert _get_json_type(Any) == 'string' # or however you want to handle Any
assert _get_json_type(None) == 'null'
assert _get_json_type(type(None)) == 'null'
assert get_json_type(Any) == 'string' # or however you want to handle Any
assert get_json_type(None) == 'null'
assert get_json_type(type(None)) == 'null'

# Test complex nested types
complex_type = Dict[str, Union[List[int], Optional[str], Dict[str, bool]]]
assert _get_json_type(complex_type) == 'object'
assert get_json_type(complex_type) == 'object'


def test_invalid_types():
# Test that invalid types raise appropriate errors
with pytest.raises(ValueError):
_get_json_type(lambda x: x) # Function type
get_json_type(lambda x: x) # Function type

with pytest.raises(ValueError):
_get_json_type(type) # metaclass
get_json_type(type) # metaclass


if sys.version_info >= (3, 10):
Expand All @@ -72,10 +72,10 @@ def test_json_type_conversion_with_optional():
from typing import Optional

# Test basic types
assert _get_json_type(str) == 'string'
assert _get_json_type(int) == 'integer'
assert _get_json_type(list) == 'array'
assert _get_json_type(dict) == 'object'
assert get_json_type(str) == 'string'
assert get_json_type(int) == 'integer'
assert get_json_type(list) == 'array'
assert get_json_type(dict) == 'object'

# Test Optional
assert _get_json_type(Optional[str]) == 'string'
assert get_json_type(Optional[str]) == 'string'

0 comments on commit ca16670

Please sign in to comment.