From b85b5bb3f3f3e89c02bea7ea534f98d62964b76c Mon Sep 17 00:00:00 2001 From: Boris Tvaroska Date: Mon, 9 Dec 2024 18:22:33 -0500 Subject: [PATCH] feat: chat prompt - yaml --- promptgit/prompt.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/promptgit/prompt.py b/promptgit/prompt.py index 6eb1f6f..f2438ec 100644 --- a/promptgit/prompt.py +++ b/promptgit/prompt.py @@ -3,12 +3,10 @@ """ -import json import string import yaml -from enum import Enum -from typing import Dict, List, Optional, Union, NamedTuple +from typing import Dict, List, Optional, Union, NamedTuple, Literal from pathlib import Path from pydantic import BaseModel, field_validator, model_validator @@ -101,13 +99,8 @@ def standardize_sections(key: str) -> str: return content -class PromptRoles(Enum): - system = 'system' - human = 'human' - ai = 'ai' - class PromptTurn(BaseModel): - role: PromptRoles + role: Literal['system', 'user', 'human', 'model', 'ai'] content: str class Prompt(BaseModel): @@ -167,7 +160,7 @@ def from_json(cls, content: str): @classmethod def from_yaml(cls, content: str): - return cls(**yaml.load(content)) + return cls.model_validate(yaml.safe_load(content)) @classmethod def from_md(cls, content: str):