
feat: LLM - Added support for stop_sequences in inference
PiperOrigin-RevId: 558618706
Ark-kun authored and copybara-github committed Aug 20, 2023
1 parent 226ab8b commit 6f7ea84
Showing 3 changed files with 46 additions and 0 deletions.
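
With this commit, stop_sequences can be passed when running inference against the text, chat, and code models. A minimal usage sketch for text generation, assuming vertexai.init() has already been called; the model version and prompt below are illustrative:

from vertexai.language_models import TextGenerationModel

model = TextGenerationModel.from_pretrained("text-bison@001")  # illustrative model version
response = model.predict(
    "List three uses of stop sequences.",
    max_output_tokens=128,
    temperature=0,
    top_p=1,
    top_k=5,
    stop_sequences=["# %%"],  # decoding stops once any of these strings is produced
)
print(response.text)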
2 changes: 2 additions & 0 deletions tests/system/aiplatform/test_language_models.py
@@ -51,6 +51,7 @@ def test_text_generation(self):
temperature=0,
top_p=1,
top_k=5,
stop_sequences=["# %%"],
).text

def test_text_generation_streaming(self):
@@ -84,6 +85,7 @@ def test_chat_on_chat_model(self):
),
],
temperature=0.0,
stop_sequences=["# %%"],
)

message1 = "Are my favorite movies based on a book series?"
11 changes: 11 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
@@ -1237,13 +1237,15 @@ def test_text_generation_ga(self):
temperature=0,
top_p=1,
top_k=5,
stop_sequences=["\n"],
)

prediction_parameters = mock_predict.call_args[1]["parameters"]
assert prediction_parameters["maxDecodeSteps"] == 128
assert prediction_parameters["temperature"] == 0
assert prediction_parameters["topP"] == 1
assert prediction_parameters["topK"] == 5
assert prediction_parameters["stopSequences"] == ["\n"]
assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]

# Validating that unspecified parameters are not passed to the model
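
The assertions above pin down how the Pythonic keyword arguments map onto the prediction-service parameter keys. A small illustrative helper (not the SDK's actual code) that reproduces the mapping the test expects:

from typing import Any, Dict, List, Optional


def _text_generation_parameters(
    max_output_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    stop_sequences: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Maps keyword arguments to the camelCase keys asserted in the test."""
    parameters: Dict[str, Any] = {}
    if max_output_tokens:
        parameters["maxDecodeSteps"] = max_output_tokens
    if temperature is not None:  # temperature=0 must still be sent
        parameters["temperature"] = temperature
    if top_p:
        parameters["topP"] = top_p
    if top_k:
        parameters["topK"] = top_k
    if stop_sequences:
        parameters["stopSequences"] = stop_sequences
    return parameters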
@@ -1798,16 +1800,19 @@ def test_chat_ga(self):
chat_max_output_tokens = 100
chat_top_k = 1
chat_top_p = 0.1
stop_sequences = ["\n"]
message_temperature = 0.2
message_max_output_tokens = 200
message_top_k = 2
message_top_p = 0.2
message_stop_sequences = ["# %%"]

chat2 = model.start_chat(
temperature=chat_temperature,
max_output_tokens=chat_max_output_tokens,
top_k=chat_top_k,
top_p=chat_top_p,
stop_sequences=stop_sequences,
)

gca_predict_response3 = gca_prediction_service.PredictResponse()
@@ -1824,19 +1829,22 @@ def test_chat_ga(self):
assert prediction_parameters["maxDecodeSteps"] == chat_max_output_tokens
assert prediction_parameters["topK"] == chat_top_k
assert prediction_parameters["topP"] == chat_top_p
assert prediction_parameters["stopSequences"] == stop_sequences

chat2.send_message(
"Are my favorite movies based on a book series?",
temperature=message_temperature,
max_output_tokens=message_max_output_tokens,
top_k=message_top_k,
top_p=message_top_p,
stop_sequences=message_stop_sequences,
)
prediction_parameters = mock_predict3.call_args[1]["parameters"]
assert prediction_parameters["temperature"] == message_temperature
assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens
assert prediction_parameters["topK"] == message_top_k
assert prediction_parameters["topP"] == message_top_p
assert prediction_parameters["stopSequences"] == message_stop_sequences

def test_chat_model_send_message_streaming(self):
"""Tests the chat generation model."""
@@ -2102,6 +2110,7 @@ def test_code_generation(self):
default_max_output_tokens = (
language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
)
stop_sequences = ["\n"]

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
@@ -2112,10 +2121,12 @@ def test_code_generation(self):
prefix="Write a function that checks if a year is a leap year.",
max_output_tokens=predict_max_output_tokens,
temperature=predict_temperature,
stop_sequences=stop_sequences,
)
prediction_parameters = mock_predict.call_args[1]["parameters"]
assert prediction_parameters["temperature"] == predict_temperature
assert prediction_parameters["maxOutputTokens"] == predict_max_output_tokens
assert prediction_parameters["stopSequences"] == stop_sequences

model.predict(
prefix="Write a function that checks if a year is a leap year.",
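
Code generation accepts the same parameter. A minimal sketch; the model version and prefix are illustrative:

from vertexai.language_models import CodeGenerationModel

model = CodeGenerationModel.from_pretrained("code-bison@001")  # illustrative model version
response = model.predict(
    prefix="Write a function that checks if a year is a leap year.",
    max_output_tokens=128,
    temperature=0.2,
    stop_sequences=["\n"],  # stop decoding at the first newline
)
print(response.text)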
33 changes: 33 additions & 0 deletions vertexai/language_models/_language_models.py
@@ -636,6 +636,7 @@ def predict(
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
) -> "TextGenerationResponse":
"""Gets model response for a single prompt.
@@ -645,6 +646,7 @@ def predict(
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
stop_sequences: Customized stop sequences to stop the decoding process.
Returns:
A `TextGenerationResponse` object that contains the text produced by the model.
@@ -656,6 +658,7 @@ def predict(
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_sequences=stop_sequences,
)[0]

def _batch_predict(
@@ -665,6 +668,7 @@ def _batch_predict(
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
) -> List["TextGenerationResponse"]:
"""Gets model response for a single prompt.
@@ -674,6 +678,7 @@ def _batch_predict(
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
stop_sequences: Customized stop sequences to stop the decoding process.
Returns:
A list of `TextGenerationResponse` objects that contain the texts produced by the model.
@@ -693,6 +698,9 @@ def _batch_predict(
if top_k:
prediction_parameters["topK"] = top_k

if stop_sequences:
prediction_parameters["stopSequences"] = stop_sequences

prediction_response = self._endpoint.predict(
instances=instances,
parameters=prediction_parameters,
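
Note the truthiness guard above: stop_sequences=None and stop_sequences=[] both leave stopSequences out of the request parameters, so the service default applies. A tiny self-contained illustration of the same pattern (not SDK code):

from typing import Any, Dict, List, Optional


def _maybe_add_stop_sequences(
    parameters: Dict[str, Any], stop_sequences: Optional[List[str]]
) -> Dict[str, Any]:
    # Mirrors the guard in _batch_predict: only a non-empty list is forwarded.
    if stop_sequences:
        parameters["stopSequences"] = stop_sequences
    return parameters


assert "stopSequences" not in _maybe_add_stop_sequences({}, None)
assert "stopSequences" not in _maybe_add_stop_sequences({}, [])
assert _maybe_add_stop_sequences({}, ["\n\n"]) == {"stopSequences": ["\n\n"]}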
@@ -1165,6 +1173,7 @@ def start_chat(
top_k: Optional[int] = None,
top_p: Optional[float] = None,
message_history: Optional[List[ChatMessage]] = None,
stop_sequences: Optional[List[str]] = None,
) -> "ChatSession":
"""Starts a chat session with the model.
@@ -1178,6 +1187,7 @@ def start_chat(
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
message_history: A list of previously sent and received messages.
stop_sequences: Customized stop sequences to stop the decoding process.
Returns:
A `ChatSession` object.
@@ -1191,6 +1201,7 @@ def start_chat(
top_k=top_k,
top_p=top_p,
message_history=message_history,
stop_sequences=stop_sequences,
)


@@ -1291,6 +1302,7 @@ def __init__(
top_k: Optional[int] = None,
top_p: Optional[float] = None,
message_history: Optional[List[ChatMessage]] = None,
stop_sequences: Optional[List[str]] = None,
):
self._model = model
self._context = context
@@ -1300,6 +1312,7 @@ def __init__(
self._top_k = top_k
self._top_p = top_p
self._message_history: List[ChatMessage] = message_history or []
self._stop_sequences = stop_sequences

@property
def message_history(self) -> List[ChatMessage]:
@@ -1314,6 +1327,7 @@ def _prepare_request(
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
) -> _PredictionRequest:
"""Prepares a request for the language model.
@@ -1327,6 +1341,7 @@ def _prepare_request(
Uses the value specified when calling `ChatModel.start_chat` by default.
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
Uses the value specified when calling `ChatModel.start_chat` by default.
stop_sequences: Customized stop sequences to stop the decoding process.
Returns:
A `_PredictionRequest` object.
@@ -1350,6 +1365,10 @@ def _prepare_request(
if top_k:
prediction_parameters["topK"] = top_k

stop_sequences = stop_sequences or self._stop_sequences
if stop_sequences:
prediction_parameters["stopSequences"] = stop_sequences

message_structs = []
for past_message in self._message_history:
message_structs.append(
@@ -1426,6 +1445,7 @@ def send_message(
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
) -> "TextGenerationResponse":
"""Sends message to the language model and gets a response.
@@ -1439,6 +1459,7 @@ def send_message(
Uses the value specified when calling `ChatModel.start_chat` by default.
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
Uses the value specified when calling `ChatModel.start_chat` by default.
stop_sequences: Customized stop sequences to stop the decoding process.
Returns:
A `TextGenerationResponse` object that contains the text produced by the model.
@@ -1449,6 +1470,7 @@ def send_message(
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_sequences=stop_sequences,
)

prediction_response = self._model._endpoint.predict(
@@ -1553,6 +1575,7 @@ def __init__(
top_k: Optional[int] = None,
top_p: Optional[float] = None,
message_history: Optional[List[ChatMessage]] = None,
stop_sequences: Optional[List[str]] = None,
):
super().__init__(
model=model,
@@ -1563,6 +1586,7 @@ def __init__(
top_k=top_k,
top_p=top_p,
message_history=message_history,
stop_sequences=stop_sequences,
)


@@ -1669,6 +1693,7 @@ def _create_prediction_request(
*,
max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS,
temperature: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
) -> _PredictionRequest:
"""Creates a code generation prediction request.
@@ -1677,6 +1702,8 @@ def _create_prediction_request(
suffix: Code after the current point.
max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
temperature: Controls the randomness of predictions. Range: [0, 1].
stop_sequences: Customized stop sequences to stop the decoding process.
Returns:
A `TextGenerationResponse` object that contains the text produced by the model.
@@ -1693,6 +1720,9 @@ def _create_prediction_request(
if max_output_tokens:
prediction_parameters["maxOutputTokens"] = max_output_tokens

if stop_sequences:
prediction_parameters["stopSequences"] = stop_sequences

return _PredictionRequest(instance=instance, parameters=prediction_parameters)
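
One detail worth noting when comparing the two code paths: the code models send "maxOutputTokens" while the text model path above sends "maxDecodeSteps", but "stopSequences" is spelled the same in both. An illustrative parameters payload built by _create_prediction_request (values are placeholders):

prediction_parameters = {
    "temperature": 0.2,
    "maxOutputTokens": 128,
    "stopSequences": ["\n"],
}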

def predict(
Expand All @@ -1702,6 +1732,7 @@ def predict(
*,
max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS,
temperature: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
) -> "TextGenerationResponse":
"""Gets model response for a single prompt.
Expand All @@ -1710,6 +1741,7 @@ def predict(
suffix: Code after the current point.
max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
temperature: Controls the randomness of predictions. Range: [0, 1].
stop_sequences: Customized stop sequences to stop the decoding process.
Returns:
A `TextGenerationResponse` object that contains the text produced by the model.
@@ -1719,6 +1751,7 @@ def predict(
suffix=suffix,
max_output_tokens=max_output_tokens,
temperature=temperature,
stop_sequences=stop_sequences,
)

prediction_response = self._endpoint.predict(
