Skip to content

Commit

Permalink
feat: chat prompt for md
Browse files Browse the repository at this point in the history
  • Loading branch information
tvaroska committed Dec 10, 2024
1 parent b85b5bb commit 4c127f9
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 96 deletions.
13 changes: 12 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

127 changes: 32 additions & 95 deletions promptgit/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from pathlib import Path
from pydantic import BaseModel, field_validator, model_validator


MULTI_LINE = ["prmpt", "description"]

import mistletoe
from mistletoe.block_token import Heading, Paragraph
from mistletoe.span_token import LineBreak

class PromptLocation(NamedTuple):
"""
Expand Down Expand Up @@ -51,54 +51,6 @@ def __str__(self):
else:
return f"{self.name}"


def parse_md(text: str) -> Dict:
current_section = None
content = {}

def safe_add(data, key, line):
if key in data:
if isinstance(data[key], str):
data[key] = data[key] + "\n" + line
elif isinstance(data[key], list):
data[key].append(line)
else:
raise NotImplementedError
else:
if key in "models":
data[key] = [line]
else:
data[key] = line
return data

def standardize_sections(key: str) -> str:
return (
key.replace("#", "")
.strip()
.lower()
.replace(" ", "_")
.replace("-", "_")
)

for line in text.splitlines():
if line.strip() == "":
if current_section and current_section in MULTI_LINE:
content = safe_add(content, current_section, line)
continue
else:
continue
if line[0] == "#":
current_section = standardize_sections(line)
continue
if not current_section:
current_section = "prompt"
if current_section in content:
content[current_section] = content[current_section] + "\n" + line
else:
content[current_section] = line

return content

class PromptTurn(BaseModel):
role: Literal['system', 'user', 'human', 'model', 'ai']
content: str
Expand Down Expand Up @@ -164,51 +116,36 @@ def from_yaml(cls, content: str):

@classmethod
def from_md(cls, content: str):
current_section = None
fields = {}

def safe_add(data, key, line):
if key in data:
if isinstance(data[key], str):
data[key] = data[key] + "\n" + line
elif isinstance(data[key], list):
data[key].append(line)
else:
raise NotImplementedError
else:
if key in "models":
data[key] = [line]
else:
data[key] = line
return data

def standardize_sections(key: str) -> str:
return (
key.replace("#", "")
.strip()
.lower()
.replace(" ", "_")
.replace("-", "_")
)

for line in content.splitlines():
if line.strip() == "":
if current_section and current_section in MULTI_LINE:
fields = safe_add(fields, current_section, line)
continue

def parse_children(children):

max_level = min([item.level for item in children if isinstance(item, Heading)])

top_level = [(item, idx) for (idx, item) in enumerate(children) if isinstance(item, Heading) and item.level == max_level]

response = {}

for i, (heading, idx) in enumerate(top_level):
if isinstance(children[idx+1], Paragraph):
content = ''
for child in children[idx+1].children:
if isinstance(child, LineBreak):
content += '\n'
else:
try:
content += child.content
except AttributeError:
pass
else:
continue
if line[0] == "#":
current_section = standardize_sections(line)
continue
if not current_section:
current_section = "prompt"
if current_section in fields:
fields[current_section] = fields[current_section] + "\n" + line
else:
fields[current_section] = line

return cls(**fields)
print(f'{idx+1} - {top_level[i+1][1]}')
content = parse_children(children[idx+1:top_level[i+1][1]])
response[heading.children[0].content.strip().lower().replace(" ", "_").replace("-", "_")] = content

return response

fields = parse_children(mistletoe.Document(content).children)

return cls.model_validate(fields)


def as_langchain(self):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ GitPython = ">=3.0.0,<4.0.0"
urllib3 = ">=2.2.2,<3.0.0"
langchain-core = { version = ">=0.3.0,<0.4.0", optional = true }
pyyaml = ">=6.0.0"
mistletoe = "^1.4.0"

[tool.poetry.extras]
langchain = [ "langchain-core" ]
Expand Down

0 comments on commit 4c127f9

Please sign in to comment.