add gen params validation #2489

Merged 4 commits on Jul 26, 2023
@@ -27,11 +27,12 @@

# Configure logger
logger = logging.getLogger(__name__)
logger.propagate = False
logger.setLevel(logging.DEBUG)
format_str = "%(asctime)s [%(module)s] %(funcName)s %(lineno)s: %(levelname)-8s [%(process)d] %(message)s"
formatter = logging.Formatter(format_str)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.setLevel(logging.DEBUG)
logger.addHandler(stream_handler)

PORT = 80
@@ -67,11 +68,45 @@ class SupportedTask:
CHAT_COMPLETION = "chat-completion"


SUPPORTED_INFERENCE_PARAMS = {
# Activate logits sampling
"do_sample": {"type": bool, "default": True},
# Maximum number of generated tokens
"max_new_tokens": {"type": int, "default": 256},
# Generate best_of sequences and return the one with the highest token logprobs
"best_of": {"type": int, "optional": True},
# 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
"repetition_penalty": {"type": int, "optional": True},
# Whether to prepend the prompt to the generated text
"return_full_text": {"type": bool, "default": True},
# Random sampling seed
"seed": {"type": int, "optional": True},
# Stop generating tokens if a member of `stop_sequences` is generated
"stop_sequences": {"type": list, "default": []},
# The value used to modulate the logits distribution
"temperature": {"type": float, "optional": True},
# The number of highest probability vocabulary tokens to keep for top-k-filtering.
"top_k": {"type": int, "optional": True},
# If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.
"top_p": {"type": float, "optional": True},
# Truncate inputs tokens to the given size
"truncate": {"type": int, "optional": True},
# Typical Decoding mass. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
"typical_p": {"type": float, "optional": True},
# Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
"watermark": {"type": bool, "default": False},
# Get decoder input token logprobs and ids
"decoder_input_details": {"type": bool, "default": False},
}

# default values
MLMODEL_PATH = "mlflow_model_folder/MLmodel"
DEFAULT_MODEL_ID_PATH = "mlflow_model_folder/data/model"
client = None
task_type = SupportedTask.TEXT_GENERATION
default_generator_configs = {
k: v["default"] for k, v in SUPPORTED_INFERENCE_PARAMS.items() if "default" in v
}


def is_server_healthy():
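
For reference, a quick editorial sketch (not part of the diff) of what the comprehension above produces at import time: only the entries of SUPPORTED_INFERENCE_PARAMS that declare a "default" survive, while the "optional" ones stay absent until a request supplies them.

# Sanity check, illustrative only: the comprehension above reduces to this literal.
assert default_generator_configs == {
    "do_sample": True,
    "max_new_tokens": 256,
    "return_full_text": True,
    "stop_sequences": [],
    "watermark": False,
    "decoder_input_details": False,
}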
@@ -301,6 +336,7 @@ def init():
"""Initialize text-generation-inference server and client."""
global client
global task_type
global default_generator_configs
global aacs_client

try:
@@ -342,12 +378,26 @@ def init():
flavors = mlmodel.get("flavors", {})
if "hftransformersv2" in flavors:
task_type = flavors["hftransformersv2"]["task_type"]
model_generator_configs = flavors["hftransformersv2"].get(
"generator_config", {}
)
logger.info(f"model_generator_configs: {model_generator_configs}")
if task_type not in (
SupportedTask.TEXT_GENERATION,
SupportedTask.CHAT_COMPLETION,
):
raise Exception(f"Unsupported task_type {task_type}")

if task_type == SupportedTask.CHAT_COMPLETION:
default_generator_configs["return_full_text"] = False
# update default gen configs with model configs
default_generator_configs = get_generator_params(
model_generator_configs
)
logger.info(
f"updated default_generator_configs: {default_generator_configs}"
)

logger.info(f"Loading model from path {model_path} for task_type: {task_type}")
logger.info(f"List model_path = {os.listdir(model_path)}")

@@ -485,6 +535,32 @@ def get_request_data(request_string) -> Tuple[Union[str, List[str]], Dict[str, A
)


def add_and_validate_gen_params(params: dict, new_params: dict):
"""Add and validate inference params."""
if not new_params or not isinstance(new_params, dict):
return params
for k, v in new_params.items():
if k not in SUPPORTED_INFERENCE_PARAMS:
logger.warning(f"Ignoring unsupported inference param {k}.")
elif not isinstance(v, SUPPORTED_INFERENCE_PARAMS[k]["type"]):
logger.warning(
f"Ignoring inference param {k} as value passed is of type {type(v)} and not of type {SUPPORTED_INFERENCE_PARAMS[k]['type']}"
)
else:
params[k] = v
return params


def get_generator_params(params: dict):
"""Return accumulated generator params."""
global default_generator_configs

updated_params = {}
updated_params.update(default_generator_configs)
updated_params = add_and_validate_gen_params(updated_params, params)
return updated_params


def run(data):
"""Run for inference data provided."""
global client
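
A small illustrative sketch of the validation above (the parameter names and values here are made up for the example): keys missing from SUPPORTED_INFERENCE_PARAMS and values of the wrong type are warned about and dropped, everything else is copied through.

# Illustrative only; not part of the diff.
filtered = add_and_validate_gen_params(
    {}, {"temperature": 0.6, "top_k": "50", "num_beams": 4}
)
# "temperature": float, supported            -> kept
# "top_k": str passed where int is expected  -> warning logged, dropped
# "num_beams": not a supported param         -> warning logged, dropped
# filtered == {"temperature": 0.6}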
@@ -502,6 +578,7 @@ def run(data):
raise Exception("Client is not initialized")

query, params = get_request_data(data)
params = get_generator_params(params)
logger.info(
f"generating response for input_string: {query}, parameters: {params}"
)
@@ -512,8 +589,7 @@
time_taken = time.time() - time_start
logger.info(f"time_taken: {time_taken}")
result_dict = {"output": f"{response_str}"}
resp = pd.DataFrame([result_dict])
return get_safe_response(resp)
return get_safe_response(result_dict)

assert task_type == SupportedTask.TEXT_GENERATION and isinstance(
query, list
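Tying the scoring script to the sample requests in the notebook below, here is a sketch of the effective generator config for the notebook's "parameters" block, assuming a text-generation deployment with no model-level generator_config overrides (both assumptions are mine, not stated in the diff).

# Illustrative only: notebook "parameters" merged over the built-in defaults by
# get_generator_params() inside run().
effective = get_generator_params(
    {"temperature": 0.6, "top_p": 0.6, "max_new_tokens": 256, "do_sample": True}
)
# effective == {
#     "do_sample": True,              # explicitly sent, same as default
#     "max_new_tokens": 256,          # explicitly sent, same as default
#     "return_full_text": True,       # default (init() flips this to False for chat-completion)
#     "stop_sequences": [],
#     "watermark": False,
#     "decoder_input_details": False,
#     "temperature": 0.6,             # from the request
#     "top_p": 0.6,                   # from the request
# }
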
@@ -58,7 +58,7 @@
"# The severity level that will trigger response be blocked\n",
"# Please reference Azure AI content documentation for more details\n",
"# https://learn.microsoft.com/en-us/azure/cognitive-services/content-safety/concepts/harm-categories\n",
"content_severity_threshold = \"2\"\n",
"content_severity_threshold = \"0\"\n",
"\n",
"# UAI to be used for endpoint if you choose to use UAI as authentication method\n",
"uai_name = \"\" # default to \"llama-uai\" in prepare uai notebook"
@@ -618,6 +618,7 @@
" ],\n",
" \"parameters\": {\n",
" \"temperature\": 0.6,\n",
" \"top_p\": 0.6,\n",
" \"max_new_tokens\": 256,\n",
" \"do_sample\": True,\n",
" },\n",
@@ -647,6 +648,7 @@
" ],\n",
" \"parameters\": {\n",
" \"temperature\": 0.6,\n",
" \"top_p\": 0.6,\n",
" \"max_new_tokens\": 256,\n",
" \"do_sample\": True,\n",
" },\n",