From b0da4ff2d87f084c4a9098c554cc30b2817a48c2 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 9 Sep 2024 11:46:19 -0700 Subject: [PATCH] no head this request is unnecessary since the POST will short circuit the request if the blob already exists --- ollama/_client.py | 34 +++++++++++----------------------- tests/test_client.py | 26 ++++++++++++-------------- 2 files changed, 23 insertions(+), 37 deletions(-) diff --git a/ollama/_client.py b/ollama/_client.py index c1f5f95..723eb65 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -523,14 +523,8 @@ def _create_blob(self, path: Union[str, Path]) -> str: digest = f'sha256:{sha256sum.hexdigest()}' - try: - self._request_raw('HEAD', f'/api/blobs/{digest}') - except ResponseError as e: - if e.status_code != 404: - raise - - with open(path, 'rb') as r: - self._request_raw('POST', f'/api/blobs/{digest}', content=r) + with open(path, 'rb') as r: + self._request_raw('POST', f'/api/blobs/sha256:{digest}', content=r) return digest @@ -1007,21 +1001,15 @@ async def _create_blob(self, path: Union[str, Path]) -> str: digest = f'sha256:{sha256sum.hexdigest()}' - try: - await self._request_raw('HEAD', f'/api/blobs/{digest}') - except ResponseError as e: - if e.status_code != 404: - raise - - async def upload_bytes(): - with open(path, 'rb') as r: - while True: - chunk = r.read(32 * 1024) - if not chunk: - break - yield chunk - - await self._request_raw('POST', f'/api/blobs/{digest}', content=upload_bytes()) + async def upload_bytes(): + with open(path, 'rb') as r: + while True: + chunk = r.read(32 * 1024) + if not chunk: + break + yield chunk + + await self._request_raw('POST', f'/api/blobs/{digest}', content=upload_bytes()) return digest diff --git a/tests/test_client.py b/tests/test_client.py index 3bb451c..4983610 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -293,7 +293,7 @@ def generate(): def test_client_create_path(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200)) httpserver.expect_ordered_request( '/api/create', method='POST', @@ -316,7 +316,7 @@ def test_client_create_path(httpserver: HTTPServer): def test_client_create_path_relative(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200)) httpserver.expect_ordered_request( '/api/create', method='POST', @@ -348,7 +348,7 @@ def userhomedir(): def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200)) httpserver.expect_ordered_request( '/api/create', method='POST', @@ -371,7 +371,7 @@ def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir): def test_client_create_modelfile(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200)) httpserver.expect_ordered_request( '/api/create', method='POST', @@ -390,7 +390,7 @@ def test_client_create_modelfile(httpserver: HTTPServer): def test_client_create_modelfile_roundtrip(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200)) httpserver.expect_ordered_request( '/api/create', method='POST', @@ -455,7 +455,6 @@ def test_client_create_from_library(httpserver: HTTPServer): def test_client_create_blob(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404)) httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=201)) client = Client(httpserver.url_for('/')) @@ -466,7 +465,7 @@ def test_client_create_blob(httpserver: HTTPServer): def test_client_create_blob_exists(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200)) client = Client(httpserver.url_for('/')) @@ -774,7 +773,7 @@ def generate(): @pytest.mark.asyncio async def test_async_client_create_path(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200)) httpserver.expect_ordered_request( '/api/create', method='POST', @@ -798,7 +797,7 @@ async def test_async_client_create_path(httpserver: HTTPServer): @pytest.mark.asyncio async def test_async_client_create_path_relative(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200)) httpserver.expect_ordered_request( '/api/create', method='POST', @@ -822,7 +821,7 @@ async def test_async_client_create_path_relative(httpserver: HTTPServer): @pytest.mark.asyncio async def test_async_client_create_path_user_home(httpserver: HTTPServer, userhomedir): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200)) httpserver.expect_ordered_request( '/api/create', method='POST', @@ -846,7 +845,7 @@ async def test_async_client_create_path_user_home(httpserver: HTTPServer, userho @pytest.mark.asyncio async def test_async_client_create_modelfile(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200)) httpserver.expect_ordered_request( '/api/create', method='POST', @@ -866,7 +865,7 @@ async def test_async_client_create_modelfile(httpserver: HTTPServer): @pytest.mark.asyncio async def test_async_client_create_modelfile_roundtrip(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200)) httpserver.expect_ordered_request( '/api/create', method='POST', @@ -933,7 +932,6 @@ async def test_async_client_create_from_library(httpserver: HTTPServer): @pytest.mark.asyncio async def test_async_client_create_blob(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404)) httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=201)) client = AsyncClient(httpserver.url_for('/')) @@ -945,7 +943,7 @@ async def test_async_client_create_blob(httpserver: HTTPServer): @pytest.mark.asyncio async def test_async_client_create_blob_exists(httpserver: HTTPServer): - httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) + httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200)) client = AsyncClient(httpserver.url_for('/'))