diff --git a/poetry.lock b/poetry.lock index dae735e..8046aac 100644 --- a/poetry.lock +++ b/poetry.lock @@ -531,6 +531,17 @@ files = [ {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, ] +[[package]] +name = "mistletoe" +version = "1.4.0" +description = "A fast, extensible Markdown parser in pure Python." +optional = false +python-versions = "~=3.5" +files = [ + {file = "mistletoe-1.4.0-py3-none-any.whl", hash = "sha256:44a477803861de1237ba22e375c6b617690a31d2902b47279d1f8f7ed498a794"}, + {file = "mistletoe-1.4.0.tar.gz", hash = "sha256:1630f906e5e4bbe66fdeb4d29d277e2ea515d642bb18a9b49b136361a9818c9d"}, +] + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -1149,4 +1160,4 @@ langchain = ["langchain-core"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.14" -content-hash = "c037d2de682d78f035780ca79931f8267026d659c48d10bcbacc87cc16b91101" +content-hash = "939f562e03e07e01ef7d0de0bda236525830c552a443b2a1b7522264a89240ff" diff --git a/promptgit/prompt.py b/promptgit/prompt.py index f2438ec..1412dd5 100644 --- a/promptgit/prompt.py +++ b/promptgit/prompt.py @@ -10,9 +10,9 @@ from pathlib import Path from pydantic import BaseModel, field_validator, model_validator - -MULTI_LINE = ["prmpt", "description"] - +import mistletoe +from mistletoe.block_token import Heading, Paragraph +from mistletoe.span_token import LineBreak class PromptLocation(NamedTuple): """ @@ -51,54 +51,6 @@ def __str__(self): else: return f"{self.name}" - -def parse_md(text: str) -> Dict: - current_section = None - content = {} - - def safe_add(data, key, line): - if key in data: - if isinstance(data[key], str): - data[key] = data[key] + "\n" + line - elif isinstance(data[key], list): - data[key].append(line) - else: - raise NotImplementedError - else: - if key in "models": - data[key] = [line] - else: - data[key] = line - return data - - def standardize_sections(key: str) -> str: - return ( - key.replace("#", "") - .strip() - .lower() - .replace(" ", "_") - .replace("-", "_") - ) - - for line in text.splitlines(): - if line.strip() == "": - if current_section and current_section in MULTI_LINE: - content = safe_add(content, current_section, line) - continue - else: - continue - if line[0] == "#": - current_section = standardize_sections(line) - continue - if not current_section: - current_section = "prompt" - if current_section in content: - content[current_section] = content[current_section] + "\n" + line - else: - content[current_section] = line - - return content - class PromptTurn(BaseModel): role: Literal['system', 'user', 'human', 'model', 'ai'] content: str @@ -164,51 +116,36 @@ def from_yaml(cls, content: str): @classmethod def from_md(cls, content: str): - current_section = None - fields = {} - - def safe_add(data, key, line): - if key in data: - if isinstance(data[key], str): - data[key] = data[key] + "\n" + line - elif isinstance(data[key], list): - data[key].append(line) - else: - raise NotImplementedError - else: - if key in "models": - data[key] = [line] - else: - data[key] = line - return data - - def standardize_sections(key: str) -> str: - return ( - key.replace("#", "") - .strip() - .lower() - .replace(" ", "_") - .replace("-", "_") - ) - - for line in content.splitlines(): - if line.strip() == "": - if current_section and current_section in MULTI_LINE: - fields = safe_add(fields, current_section, line) - continue + + def parse_children(children): + + max_level = min([item.level for item in children if isinstance(item, Heading)]) + + top_level = [(item, idx) for (idx, item) in enumerate(children) if isinstance(item, Heading) and item.level == max_level] + + response = {} + + for i, (heading, idx) in enumerate(top_level): + if isinstance(children[idx+1], Paragraph): + content = '' + for child in children[idx+1].children: + if isinstance(child, LineBreak): + content += '\n' + else: + try: + content += child.content + except AttributeError: + pass else: - continue - if line[0] == "#": - current_section = standardize_sections(line) - continue - if not current_section: - current_section = "prompt" - if current_section in fields: - fields[current_section] = fields[current_section] + "\n" + line - else: - fields[current_section] = line - - return cls(**fields) + print(f'{idx+1} - {top_level[i+1][1]}') + content = parse_children(children[idx+1:top_level[i+1][1]]) + response[heading.children[0].content.strip().lower().replace(" ", "_").replace("-", "_")] = content + + return response + + fields = parse_children(mistletoe.Document(content).children) + + return cls.model_validate(fields) def as_langchain(self): diff --git a/pyproject.toml b/pyproject.toml index 3962ec7..d2bad6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ GitPython = ">=3.0.0,<4.0.0" urllib3 = ">=2.2.2,<3.0.0" langchain-core = { version = ">=0.3.0,<0.4.0", optional = true } pyyaml = ">=6.0.0" +mistletoe = "^1.4.0" [tool.poetry.extras] langchain = [ "langchain-core" ]