Skip to content

Commit

Permalink
chore: cleanup
Browse files: browse the repository at this point in the history
  • Loading branch information
NJordan72 committed Jan 14, 2025
1 parent d4ac2c8 commit ab270e9
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 45 deletions.
18 changes: 7 additions & 11 deletions src/axolotl/prompt_strategies/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,10 @@ def __init__(
}

# Default optional fields, kept for compatibility
self._default_optional_fields = [
"tool_calls",
"name",
"tool_call_id"
]

self.optional_message_fields = (
self._default_optional_fields + (optional_message_fields or [])
self._default_optional_fields = ["tool_calls", "name", "tool_call_id"]

self.optional_message_fields = self._default_optional_fields + (
optional_message_fields or []
)

self.message_field_role = message_field_role
Expand Down Expand Up @@ -232,7 +228,7 @@ def messages(self):
def messages(self, messages):
self._messages = messages

def tokenize_prompt(self, prompt: dict):
def tokenize_prompt(self, prompt):
# Old simple legacy behavior that works reliably.
if (
not self.roles_to_train
Expand Down Expand Up @@ -434,7 +430,7 @@ def find_turn(self, turns: list[dict], turn_idx: int):

return start_idx, end_idx

def get_conversation_thread(self, prompt: dict):
def get_conversation_thread(self, prompt):
turns = []

for message in prompt[self.messages]:
Expand Down Expand Up @@ -464,7 +460,7 @@ def get_conversation_thread(self, prompt: dict):

return turns

def get_images(self, prompt: dict):
def get_images(self, prompt):
return prompt.get(self.images, None)


Expand Down
1 change: 1 addition & 0 deletions tests/prompt_strategies/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def fixture_llama3_2_vision_with_hardcoded_date() -> str:

return modified_template


@pytest.fixture(name="chat_template_jinja_with_optional_fields")
def fixture_chat_template_jinja_with_optional_fields() -> str:
return """{% for message in messages %}
Expand Down
53 changes: 19 additions & 34 deletions tests/prompt_strategies/test_chat_templates_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,13 +911,11 @@ def verify_labels(labels_span, should_train, context_message):
LOG.debug(f"Final labels: {labels}")
LOG.debug(f"Final input_ids: {input_ids}")


class TestOptionalMessageFields:
"""Test class for optional message fields functionality in chat templates."""

def test_combined_optional_fields_with_template(
self,
request
):
def test_combined_optional_fields_with_template(self, request):
"""Test both default and custom optional fields with template rendering."""
LOG.info("Testing combined optional fields with template")

Expand All @@ -932,19 +930,13 @@ def test_combined_optional_fields_with_template(

test_data = {
"messages": [
{
"role": "system",
"content": "You are an AI assistant."
},
{
"role": "user",
"content": "What is the temperature in Paris?"
},
{"role": "system", "content": "You are an AI assistant."},
{"role": "user", "content": "What is the temperature in Paris?"},
{
"role": "assistant",
"content": "Let me help with that calculation",
"thoughts": "We should take care to convert the temperature to Fahrenheit",
}
},
]
}

Expand All @@ -959,51 +951,44 @@ def test_combined_optional_fields_with_template(
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"]
roles_to_train=["assistant"],
)

res = strategy.tokenize_prompt(test_data)
turns = strategy.get_conversation_thread(test_data)
labels = res["labels"]
input_ids = res["input_ids"]

tokens = []
for _, (input_id, label_id) in enumerate(zip(input_ids, labels)):
decoded_input_token = tokenizer.decode(input_id)
# Choose the color based on whether the label has the ignore value or not
token = f"({label_id}, {decoded_input_token})"
tokens.append(token)

LOG.info("\n".join(tokens))
LOG.info("\n\n\n")

LOG.info(f"Labels: {labels}")
LOG.info(f"Input IDs: {input_ids}")

# Verify both optional fields are in the tokenized output
decoded_output = tokenizer.decode(input_ids)

LOG.info(f"Decoded output: {decoded_output}")
assert "[Thoughts: We should take care to convert the temperature to Fahrenheit]" in decoded_output, "Thoughts not found in output"
assert (
"[Thoughts: We should take care to convert the temperature to Fahrenheit]"
in decoded_output
), "Thoughts not found in output"

start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=0)
turn_labels = labels[start_idx:end_idx]
assert all(label == IGNORE_TOKEN_ID for label in turn_labels), (
"Expected system message content to be not to be trained on"
)
assert all(
label == IGNORE_TOKEN_ID for label in turn_labels
), "Expected system message content to be not to be trained on"

start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=1)
turn_labels = labels[start_idx:end_idx]
assert all(label == IGNORE_TOKEN_ID for label in turn_labels), (
"Expected user message content to be not to be trained on"
)
assert all(
label == IGNORE_TOKEN_ID for label in turn_labels
), "Expected user message content to be not to be trained on"

# Verify all content is properly labeled for assistant message
start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=2)
turn_labels = labels[start_idx:end_idx]
assert not all(label == IGNORE_TOKEN_ID for label in turn_labels), (
"Expected assistant message content to be trained on"
)
assert not all(
label == IGNORE_TOKEN_ID for label in turn_labels
), "Expected assistant message content to be trained on"


if __name__ == "__main__":
Expand Down

0 comments on commit ab270e9

Please sign in to comment.