Skip to content

Commit

Permalink
Fix: do not erase existing values on update_inference_endpoint (#2476)
Browse files — browse the repository at this point in the history
  • Loading branch information
Wauplin authored Aug 21, 2024
1 parent ec5812d commit 7d98441
Showing 1 changed file with 25 additions and 24 deletions.
49 changes: 25 additions & 24 deletions src/huggingface_hub/hf_api.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,6 +19,7 @@
import re
import struct
import warnings
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import asdict, dataclass, field
from datetime import datetime
Expand Down Expand Up @@ -7677,30 +7678,30 @@ def update_inference_endpoint(
"""
namespace = namespace or self._get_namespace(token=token)

payload: Dict = {}
if any(
value is not None
for value in (accelerator, instance_size, instance_type, min_replica, max_replica, scale_to_zero_timeout)
):
payload["compute"] = {
"accelerator": accelerator,
"instanceSize": instance_size,
"instanceType": instance_type,
"scaling": {
"maxReplica": max_replica,
"minReplica": min_replica,
"scaleToZeroTimeout": scale_to_zero_timeout,
},
}
if any(value is not None for value in (repository, framework, revision, task, custom_image)):
image = {"custom": custom_image} if custom_image is not None else {"huggingface": {}}
payload["model"] = {
"framework": framework,
"repository": repository,
"revision": revision,
"task": task,
"image": image,
}
# Populate only the fields that are not None
payload: Dict = defaultdict(lambda: defaultdict(dict))
if accelerator is not None:
payload["compute"]["accelerator"] = accelerator
if instance_size is not None:
payload["compute"]["instanceSize"] = instance_size
if instance_type is not None:
payload["compute"]["instanceType"] = instance_type
if max_replica is not None:
payload["compute"]["scaling"]["maxReplica"] = max_replica
if min_replica is not None:
payload["compute"]["scaling"]["minReplica"] = min_replica
if scale_to_zero_timeout is not None:
payload["compute"]["scaling"]["scaleToZeroTimeout"] = scale_to_zero_timeout
if repository is not None:
payload["model"]["repository"] = repository
if framework is not None:
payload["model"]["framework"] = framework
if revision is not None:
payload["model"]["revision"] = revision
if task is not None:
payload["model"]["task"] = task
if custom_image is not None:
payload["model"]["image"] = {"custom": custom_image}

response = get_session().put(
f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}",
Expand Down

0 comments on commit 7d98441

Please sign in to comment.