Skip to content

Commit

Permalink
feat: Add eot_token property to ControlModel and derived classes.
Browse files (browse the repository at this point in the history)
Task: PHS-492
  • Loading branch information
SebastianNiehusAA committed Jun 11, 2024
1 parent 5c82bd4 commit a87e5cc
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
...

### New Features
...
- Add an `eot_token` property to `ControlModel` and its derived classes (`LuminousControlModel`, `Llama2InstructModel` and `Llama3InstructModel`). `PromptBasedClassify` now uses this property instead of a hardcoded string.

### Fixes
...
Expand Down
24 changes: 20 additions & 4 deletions src/intelligence_layer/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,11 @@ def __init__(
)
super().__init__(name, client)

@property
@abstractmethod
def eot_token(self) -> str:
pass

@abstractmethod
def to_instruct_prompt(
self,
Expand Down Expand Up @@ -256,6 +261,10 @@ def __init__(
) -> None:
super().__init__(name, client)

@property
def eot_token(self) -> str:
return "<|endoftext|>"

def to_instruct_prompt(
self,
instruction: str,
Expand Down Expand Up @@ -300,6 +309,10 @@ def __init__(
) -> None:
super().__init__(name, client)

@property
def eot_token(self) -> str:
return "<|endoftext|>"

def to_instruct_prompt(
self,
instruction: str,
Expand Down Expand Up @@ -330,7 +343,6 @@ class Llama3InstructModel(ControlModel):
{{response_prefix}}{% endif %}"""
)
EOT_TOKEN = "<|eot_id|>"

RECOMMENDED_MODELS = [
"llama-3-8b-instruct",
Expand All @@ -344,14 +356,18 @@ def __init__(
) -> None:
super().__init__(name, client)

@property
def eot_token(self) -> str:
return "<|eot_id|>"

def _add_eot_token_to_stop_sequences(self, input: CompleteInput) -> CompleteInput:
# TODO: remove this workaround once the API natively supports the llama-3 end-of-turn token (`<|eot_id|>`)
params = input.__dict__
if isinstance(params["stop_sequences"], list):
if self.EOT_TOKEN not in params["stop_sequences"]:
params["stop_sequences"].append(self.EOT_TOKEN)
if self.eot_token not in params["stop_sequences"]:
params["stop_sequences"].append(self.eot_token)
else:
params["stop_sequences"] = [self.EOT_TOKEN]
params["stop_sequences"] = [self.eot_token]
return CompleteInput(**params)

def complete(self, input: CompleteInput, tracer: Tracer) -> CompleteOutput:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _log_probs_per_label(

def _prepare_label_for_echo_task(self, label: str) -> str:
label = label if re.match(r"^\s+", label) else f" {label}"
return label + "<|endoftext|>"
return label + self._model.eot_token

def _compute_scores(
self,
Expand Down

0 comments on commit a87e5cc

Please sign in to comment.