diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 78619f1707..17fa4889a7 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -19,6 +19,7 @@ import re import struct import warnings +from collections import defaultdict from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import asdict, dataclass, field from datetime import datetime @@ -7677,30 +7678,30 @@ def update_inference_endpoint( """ namespace = namespace or self._get_namespace(token=token) - payload: Dict = {} - if any( - value is not None - for value in (accelerator, instance_size, instance_type, min_replica, max_replica, scale_to_zero_timeout) - ): - payload["compute"] = { - "accelerator": accelerator, - "instanceSize": instance_size, - "instanceType": instance_type, - "scaling": { - "maxReplica": max_replica, - "minReplica": min_replica, - "scaleToZeroTimeout": scale_to_zero_timeout, - }, - } - if any(value is not None for value in (repository, framework, revision, task, custom_image)): - image = {"custom": custom_image} if custom_image is not None else {"huggingface": {}} - payload["model"] = { - "framework": framework, - "repository": repository, - "revision": revision, - "task": task, - "image": image, - } + # Populate only the fields that are not None + payload: Dict = defaultdict(lambda: defaultdict(dict)) + if accelerator is not None: + payload["compute"]["accelerator"] = accelerator + if instance_size is not None: + payload["compute"]["instanceSize"] = instance_size + if instance_type is not None: + payload["compute"]["instanceType"] = instance_type + if max_replica is not None: + payload["compute"]["scaling"]["maxReplica"] = max_replica + if min_replica is not None: + payload["compute"]["scaling"]["minReplica"] = min_replica + if scale_to_zero_timeout is not None: + payload["compute"]["scaling"]["scaleToZeroTimeout"] = scale_to_zero_timeout + if repository is not None: + payload["model"]["repository"] = repository + if framework is not None: + payload["model"]["framework"] = framework + if revision is not None: + payload["model"]["revision"] = revision + if task is not None: + payload["model"]["task"] = task + if custom_image is not None: + payload["model"]["image"] = {"custom": custom_image} response = get_session().put( f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}",