Linting
Ram authored and committed Apr 26, 2024
1 parent 8633857 commit 2470724
Showing 4 changed files with 7 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/axolotl/cli/preprocess.py
@@ -20,7 +20,7 @@
from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.sharegpt import (
    register_chatml_template,
    register_llama3_template,
)

@@ -46,7 +46,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
        register_chatml_template(parsed_cfg.default_system_message)
    else:
        register_chatml_template()

    if parsed_cfg.chat_template == "llama3" and parsed_cfg.default_system_message:
        LOG.info(
            f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
4 changes: 2 additions & 2 deletions src/axolotl/cli/train.py
@@ -20,7 +20,7 @@
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.prompt_strategies.sharegpt import (
    register_chatml_template,
    register_llama3_template,
)
from axolotl.train import train
@@ -49,7 +49,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
        register_chatml_template(cfg.default_system_message)
    else:
        register_chatml_template()

    if cfg.chat_template == "llama3" and cfg.default_system_message:
        LOG.info(
            f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
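
Both CLI entry points run the same template-registration block before handing off to preprocessing or training. The following is a minimal sketch of that shared flow; the enclosing ChatML condition and the llama3 else-branch are not visible in the hunks above and are marked as assumptions.

```python
import logging

LOG = logging.getLogger(__name__)


def register_templates(cfg, register_chatml_template, register_llama3_template):
    """Sketch of the registration flow shared by do_cli (preprocess) and do_train.

    The two register_* callables stand in for the functions imported from
    axolotl.prompt_strategies.sharegpt; `cfg` is the parsed config object
    (parsed_cfg in preprocess.py, cfg in train.py).
    """
    # ChatML: forward the configured default system message when one is set
    # (this branch's opening condition is assumed; the hunks only show its tail).
    if cfg.default_system_message:
        register_chatml_template(cfg.default_system_message)
    else:
        register_chatml_template()

    # LLaMA-3: log the default system message, then register (the registration
    # call and the else-branch are assumptions; only the condition and LOG.info
    # appear in the hunks).
    if cfg.chat_template == "llama3" and cfg.default_system_message:
        LOG.info(
            f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
        )
        register_llama3_template(cfg.default_system_message)
    else:
        register_llama3_template()
```

Passing the register_* functions in as parameters only keeps the sketch self-contained; the real CLI modules import them directly, as the import hunks above show.
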
2 changes: 2 additions & 0 deletions src/axolotl/prompt_strategies/sharegpt.py
@@ -38,6 +38,7 @@ def register_chatml_template(system_message=None):
        )
    )


def register_llama3_template(system_message=None):
    system_message = system_message or "You are a helpful assistant."
    register_conv_template(
@@ -52,6 +53,7 @@
        )
    )


def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    conversation = (
        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
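
register_conv_template and Conversation come from FastChat's conversation registry; once register_llama3_template has run, the template can be looked up by name. Below is a small usage sketch, assuming the template is registered under the name "llama3" (the Conversation fields themselves are not shown in these hunks).

```python
# Assumes register_llama3_template() registers a FastChat conversation template
# named "llama3"; the name is inferred from the function name and the config
# value, not from these hunks.
from fastchat.conversation import get_conv_template

from axolotl.prompt_strategies.sharegpt import register_llama3_template

register_llama3_template("You are a helpful assistant.")

conv = get_conv_template("llama3")         # look up the registered template by name
conv.append_message(conv.roles[0], "Hi!")  # user turn
conv.append_message(conv.roles[1], None)   # empty assistant slot -> generation prompt
print(conv.get_prompt())                   # rendered prompt including the system message
```
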
2 changes: 1 addition & 1 deletion src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -143,7 +143,7 @@ class ChatTemplate(str, Enum):
    inst = "inst"  # pylint: disable=invalid-name
    gemma = "gemma"  # pylint: disable=invalid-name
    cohere = "cohere"  # pylint: disable=invalid-name
    llama3 = "llama3"  # pylint: disable=invalid-name


class LoftQConfig(BaseModel):
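
For context, ChatTemplate mixes in str, which is what lets the CLI code above compare cfg.chat_template against plain strings like "llama3". A minimal, self-contained illustration follows; ExampleConfig and its single field are hypothetical stand-ins, not the real v0_4_1 input model.

```python
from enum import Enum

from pydantic import BaseModel


class ChatTemplate(str, Enum):
    """Abridged copy of the enum above; the real enum has more members."""

    chatml = "chatml"  # pylint: disable=invalid-name
    llama3 = "llama3"  # pylint: disable=invalid-name


class ExampleConfig(BaseModel):
    # Hypothetical stand-in for the real config model that uses ChatTemplate.
    chat_template: ChatTemplate = ChatTemplate.chatml


cfg = ExampleConfig(chat_template="llama3")   # pydantic coerces the string to the enum
assert cfg.chat_template == ChatTemplate.llama3
assert cfg.chat_template == "llama3"          # str mix-in keeps string comparisons working
```
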
