From a059a2928a25ee314139de5cf5f1cf3d6ee3aa1c Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Mon, 7 Oct 2024 15:03:14 -0700 Subject: [PATCH] sanitize path --- torchtune/_cli/run.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/torchtune/_cli/run.py b/torchtune/_cli/run.py index 15a5082895..3f09f4f2ba 100644 --- a/torchtune/_cli/run.py +++ b/torchtune/_cli/run.py @@ -123,6 +123,18 @@ def _get_recipe(self, recipe_str: str) -> Optional[Recipe]: if recipe.name == recipe_str: return recipe + def _convert_to_dotpath(self, recipe_path: str) -> str: + """Convert a custom recipe path to a dot path that can be run as a module. + + Args: + recipe_path (str): The path of the recipe. + + Returns: + The dot path of the recipe. + """ + filepath, _ = os.path.splitext(recipe_path) + return filepath.replace("/", ".") + def _get_config( self, config_str: str, specific_recipe: Optional[Recipe] ) -> Optional[Config]: @@ -163,7 +175,7 @@ def _run_cmd(self, args: argparse.Namespace): # Get recipe path recipe = self._get_recipe(args.recipe) if recipe is None: - recipe_path = args.recipe + recipe_path = self._convert_to_dotpath(args.recipe) is_builtin = False else: recipe_path = str(ROOT / "recipes" / recipe.file_path)