Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: do not erase existing values on update_inference_endpoint #2476

Merged
merged 1 commit into from
Aug 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 25 additions & 24 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import re
import struct
import warnings
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import asdict, dataclass, field
from datetime import datetime
Expand Down Expand Up @@ -7674,30 +7675,30 @@ def update_inference_endpoint(
"""
namespace = namespace or self._get_namespace(token=token)

payload: Dict = {}
if any(
value is not None
for value in (accelerator, instance_size, instance_type, min_replica, max_replica, scale_to_zero_timeout)
):
payload["compute"] = {
"accelerator": accelerator,
"instanceSize": instance_size,
"instanceType": instance_type,
"scaling": {
"maxReplica": max_replica,
"minReplica": min_replica,
"scaleToZeroTimeout": scale_to_zero_timeout,
},
}
if any(value is not None for value in (repository, framework, revision, task, custom_image)):
image = {"custom": custom_image} if custom_image is not None else {"huggingface": {}}
payload["model"] = {
"framework": framework,
"repository": repository,
"revision": revision,
"task": task,
"image": image,
}
# Populate only the fields that are not None
payload: Dict = defaultdict(lambda: defaultdict(dict))
if accelerator is not None:
payload["compute"]["accelerator"] = accelerator
if instance_size is not None:
payload["compute"]["instanceSize"] = instance_size
if instance_type is not None:
payload["compute"]["instanceType"] = instance_type
if max_replica is not None:
payload["compute"]["scaling"]["maxReplica"] = max_replica
if min_replica is not None:
payload["compute"]["scaling"]["minReplica"] = min_replica
if scale_to_zero_timeout is not None:
payload["compute"]["scaling"]["scaleToZeroTimeout"] = scale_to_zero_timeout
if repository is not None:
payload["model"]["repository"] = repository
if framework is not None:
payload["model"]["framework"] = framework
if revision is not None:
payload["model"]["revision"] = revision
if task is not None:
payload["model"]["task"] = task
if custom_image is not None:
payload["model"]["image"] = {"custom": custom_image}

response = get_session().put(
f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}",
Expand Down
Loading