feat: Add LLaMA-3 instruct prompt strategies for fine-tuning #1553

Merged
28 changes: 20 additions & 8 deletions src/axolotl/cli/preprocess.py
@@ -19,7 +19,10 @@
)
from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
-from axolotl.prompt_strategies.sharegpt import register_chatml_template
+from axolotl.prompt_strategies.sharegpt import (
+    register_chatml_template,
+    register_llama3_template,
+)

LOG = logging.getLogger("axolotl.cli.preprocess")

@@ -36,13 +39,22 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
        return_remaining_strings=True
    )

-    if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
-        LOG.info(
-            f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
-        )
-        register_chatml_template(parsed_cfg.default_system_message)
-    else:
-        register_chatml_template()
+    if parsed_cfg.chat_template == "chatml":
+        if parsed_cfg.default_system_message:
+            LOG.info(
+                f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
+            )
+            register_chatml_template(parsed_cfg.default_system_message)
+        else:
+            register_chatml_template()
+    elif parsed_cfg.chat_template == "llama3":
+        if parsed_cfg.default_system_message:
+            LOG.info(
+                f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
+            )
+            register_llama3_template(parsed_cfg.default_system_message)
+        else:
+            register_llama3_template()

    if not parsed_cfg.dataset_prepared_path:
        msg = (
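For context, the branch above is driven by two config keys, `chat_template` and `default_system_message`. A minimal sketch (not part of this diff) of a parsed config that would take the llama-3 path, assuming axolotl's DictDefault config wrapper:

# A minimal sketch, assuming axolotl's DictDefault wrapper; the
# equivalent YAML sets `chat_template: llama3` and, optionally,
# `default_system_message`.
from axolotl.utils.dict import DictDefault

parsed_cfg = DictDefault(
    {
        "chat_template": "llama3",
        "default_system_message": "You are a concise assistant.",  # optional
    }
)
# With this cfg, do_cli() logs the message and calls
# register_llama3_template("You are a concise assistant.").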
13 changes: 12 additions & 1 deletion src/axolotl/cli/train.py
@@ -19,7 +19,10 @@
    print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
-from axolotl.prompt_strategies.sharegpt import register_chatml_template
+from axolotl.prompt_strategies.sharegpt import (
+    register_chatml_template,
+    register_llama3_template,
+)
from axolotl.train import train

LOG = logging.getLogger("axolotl.cli.train")
@@ -47,6 +50,14 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    else:
        register_chatml_template()

+    if cfg.chat_template == "llama3" and cfg.default_system_message:
+        LOG.info(
+            f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
+        )
+        register_llama3_template(cfg.default_system_message)
+    else:
+        register_llama3_template()
+
    if cfg.rl:  # and cfg.rl != "orpo":
        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
    else:
20 changes: 18 additions & 2 deletions src/axolotl/prompt_strategies/sharegpt.py
@@ -22,7 +22,7 @@ def register_chatml_template(system_message=None):
            name="chatml",
            system_template="<|im_start|>system\n{system_message}",
            system_message=system_message,
-            roles=["<|im_start|>user", "<|im_start|>assistant"],
+            roles=("<|im_start|>user", "<|im_start|>assistant"),
            sep_style=SeparatorStyle.CHATML,
            sep="<|im_end|>",
        )
@@ -32,13 +32,29 @@ def register_chatml_template(system_message=None):
            name="chatml_glaive",
            system_template="<|im_start|>system\n{system_message}",
            system_message=system_message,
-            roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
+            roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
            sep_style=SeparatorStyle.CHATML,
            sep="<|im_end|>",
        )
    )


+def register_llama3_template(system_message=None):
+    system_message = system_message or "You are a helpful assistant."
+    register_conv_template(
+        Conversation(
+            name="llama3",
+            system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
+            system_message=system_message,
+            roles=("user", "assistant"),
+            sep_style=SeparatorStyle.LLAMA3,
+            sep="",
+            stop_str="<|eot_id|>",
+            stop_token_ids=[128001, 128009],
+        )
+    )
+
+
def build_loader(
    tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
    prompter_cls: Type["ShareGPTPrompterV2"],
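To see what the newly registered template produces, a small illustrative sketch (not part of this diff), assuming FastChat's get_conv_template helper:

# Illustrative sketch: render one exchange with the registered "llama3"
# template; the exact string comes from FastChat's SeparatorStyle.LLAMA3.
from fastchat.conversation import get_conv_template

from axolotl.prompt_strategies.sharegpt import register_llama3_template

register_llama3_template()
conv = get_conv_template("llama3")
conv.append_message(conv.roles[0], "hello")
conv.append_message(conv.roles[1], None)  # None leaves an open assistant header
print(conv.get_prompt())
# Roughly: a system block ending in <|eot_id|>, then
# <|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>, then an
# open <|start_header_id|>assistant<|end_header_id|> awaiting generation.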
1 change: 1 addition & 0 deletions src/axolotl/prompters.py
@@ -263,6 +263,7 @@ def __repr__(self) -> str:
    "chatml": "<|im_start|>{ROLE}",
    "zephyr": "<|{ROLE}|>",
    "vicuna_v1.1": "{ROLE}",
+    "llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
}
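The {ROLE} placeholder in this mapping is substituted per turn; a short illustrative check (not part of the diff):

# Illustrative only: the llama3 role template expands to the same header
# tokens used by the Jinja chat template added below.
role_header = "<|start_header_id|>{ROLE}<|end_header_id|>".format(ROLE="user")
assert role_header == "<|start_header_id|>user<|end_header_id|>"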
1 change: 1 addition & 0 deletions src/axolotl/utils/chat_templates.py
@@ -24,6 +24,7 @@ def chat_templates(user_choice: str):
        "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
        "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
        "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
+        "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% else %}{{ eos_token }}{% endif %}",
    }

    if user_choice in templates:
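The new Jinja string can be sanity-checked with a Hugging Face tokenizer; a sketch (not part of this diff), reusing the tokenizer checkpoint from the tests below:

# Sketch only: apply the "llama3" Jinja template via a HF tokenizer.
from transformers import AutoTokenizer

from axolotl.utils.chat_templates import chat_templates

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
tokenizer.chat_template = chat_templates("llama3")
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "hello"},
]
print(
    tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
)
# Expect bos, a system block, a user block, and an open assistant header.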
1 change: 1 addition & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -143,6 +143,7 @@ class ChatTemplate(str, Enum):
    inst = "inst"  # pylint: disable=invalid-name
    gemma = "gemma"  # pylint: disable=invalid-name
    cohere = "cohere"  # pylint: disable=invalid-name
+    llama3 = "llama3"  # pylint: disable=invalid-name


class LoftQConfig(BaseModel):
50 changes: 49 additions & 1 deletion tests/prompt_strategies/test_sharegpt.py
@@ -12,10 +12,12 @@
    GlaiveShareGPTPromptTokenizingStrategy,
    SimpleShareGPTPromptTokenizingStrategy,
    register_chatml_template,
+    register_llama3_template,
)
from axolotl.prompters import ShareGPTPrompterV2

register_chatml_template()
+register_llama3_template()


@pytest.fixture(name="sharegpt_dataset")
@@ -115,7 +117,53 @@ def fixture_tokenizer():
    return tokenizer


-class TestSharegpt:
+@pytest.fixture(name="llama3_tokenizer")
+def fixture_llama3_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
+    tokenizer.eos_token = "<|eot_id|>"
+
+    return tokenizer
+
+
+class TestSharegptLlama3:
+    """Test class for ShareGPT-style datasets with llama-3 prompts"""
+
+    def test_tokenization(self, sharegpt_dataset, llama3_tokenizer):
+        strategy = SimpleShareGPTPromptTokenizingStrategy(
+            ShareGPTPrompterV2(
+                conversation="llama3",
+                role_key_model=None,
+                role_key_human=None,
+            ),
+            llama3_tokenizer,
+            False,  # train_on_inputs
+            2048,  # sequence_len
+        )
+
+        dataset_wrapper = TokenizedPromptDataset(
+            strategy, sharegpt_dataset, process_count=1
+        )
+
+        input_ids = dataset_wrapper[0]["input_ids"]
+
+        # fmt: off
+        assert input_ids == [
+            128000,  # bos
+            128006, 9125, 128007,  # system header
+            271, 31724, 128009,  # system prompt, eot
+            128006, 882, 128007,  # user header
+            271, 15339, 128009,  # user prompt, eot
+            128006, 78191, 128007,  # assistant header
+            271, 15339, 128009,  # assistant response, eot
+            128006, 882, 128007,  # user header, second turn
+            271, 19045, 29474, 128009,  # user prompt, eot
+            128006, 78191, 128007,  # assistant header, second turn
+            271, 19045, 29474, 128009,  # assistant response, eot
+        ]
+        # fmt: on
+
+
+class TestSharegptChatML:
    """
    Test class for sharegpt prompter
    """
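As a closing note on the llama-3 test above: decoding the asserted ids back to text is a quick way to eyeball the rendered prompt (illustrative only, within the test's scope):

# Illustrative: decode the asserted ids with the fixture's
# NousResearch/Meta-Llama-3-8B tokenizer to inspect the formatting.
text = llama3_tokenizer.decode(input_ids)
assert text.startswith(
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>"
)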