Coverage for promplate/prompt/chat.py: 83%
74 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-02-07 15:14 +0800
1from sys import version_info
2from typing import Literal
4from .utils import is_message_start
# The three chat roles this module accepts.
Role = Literal["user", "assistant", "system"]

# On interpreters older than 3.12 the typing_extensions versions of
# NotRequired/TypedDict are used; on 3.12+ the stdlib ones are used.
# NOTE(coverage report): the fallback branch was never executed in the measured
# run (the version check was always true there).
if version_info < (3, 12):
    from typing_extensions import NotRequired, TypedDict
else:
    from typing import NotRequired, TypedDict  # type: ignore
class Message(TypedDict):
    """A single chat message represented as a plain dict.

    ``role`` and ``content`` are required keys; ``name`` is an optional key
    that may be absent from the dict.
    """

    role: Role
    content: str
    name: NotRequired[str]  # optional key
class MessageBuilder:
    """Fluent builder for chat ``Message`` dicts with a fixed role.

    Supported operations:

    * ``builder @ "name"`` returns a new builder with the same role and content
      plus a participant name (which must be a non-empty string);
    * ``builder > "text"`` returns a ``Message`` dict whose content is *text*;
    * ``builder.dict()`` returns a ``Message`` dict with the stored content;
    * ``builder["role"]`` / ``builder["role"] = ...`` mirror attribute access.

    While the class-level ``_initializing`` flag is false, every attribute
    assignment is checked: the shared module-level builders ``U``, ``A`` and
    ``S`` may not be modified, ``role``/``content`` must be strings, and
    ``name`` must be a string or ``None``. The checks are ``assert`` statements,
    so they are skipped when Python runs with ``-O``.
    """

    # Class-level switch read by __setattr__; the module sets it to False right
    # after creating U/A/S. It is not a slot, so instances always see the class value.
    _initializing = True
    __slots__ = ("role", "content", "name")

    def __init__(self, role: Role, /, content: str = "", name: str | None = None):
        self.role: Role = role
        self.content = content
        self.name = name

    def __repr__(self):
        # Same truthiness rule as _message(): an empty name counts as "no name".
        if self.name:
            return f"<| {self.role} {self.name} |>"
        return f"<| {self.role} |>"

    def __getitem__(self, key):
        """Mapping-style read: ``builder[key]`` is ``getattr(builder, key)``."""
        return getattr(self, key)

    def __setitem__(self, key, value):
        """Mapping-style write; goes through ``__setattr__`` and its checks."""
        return setattr(self, key, value)

    def __setattr__(self, key, value):
        if not self._initializing:
            # The shared U/A/S builders must stay unmodified.
            assert self is not U and self is not A and self is not S
            # ``name`` defaults to None in __init__, so None has to remain legal
            # for it; otherwise MessageBuilder(role) could not be constructed
            # once module setup has finished.
            assert isinstance(value, str) or (key == "name" and value is None)
        return super().__setattr__(key, value)

    def __matmul__(self, name: str):
        """Return a copy of this builder that carries the participant *name*."""
        assert isinstance(name, str) and name
        return self.__class__(self.role, self.content, name)

    def _message(self, content: str) -> Message:
        """Build a ``Message`` with this builder's role, its name (if set) and *content*."""
        if self.name:
            return {"role": self.role, "content": content, "name": self.name}
        return {"role": self.role, "content": content}

    def dict(self) -> Message:
        """Return the builder's stored content as a ``Message`` dict."""
        return self._message(self.content)

    def __gt__(self, content: str) -> Message:
        """``builder > "text"``: return a ``Message`` dict with *content*."""
        assert isinstance(content, str)
        return self._message(content)
# Shared, ready-made builders: short aliases (U/A/S) plus descriptive names
# bound to the very same objects.
U = MessageBuilder("user")
A = MessageBuilder("assistant")
S = MessageBuilder("system")
user, assistant, system = U, A, S

# From here on MessageBuilder.__setattr__ runs its checks, so assignments on
# U/A/S fail an assertion. This must stay after the three builders are created.
MessageBuilder._initializing = False
def ensure(text_or_list: list[Message] | str) -> list[Message]:
    """Normalize chat input to a list of messages.

    A ``str`` is parsed with ``parse_chat_markup``; any other value is assumed
    to already be a list of messages and is returned unchanged (same object,
    not a copy).
    """
    if not isinstance(text_or_list, str):
        return text_or_list
    return parse_chat_markup(text_or_list)
def parse_chat_markup(text: str) -> list[Message]:
    """Split chat markup text into a list of messages.

    Every line matched by ``is_message_start`` opens a new message. Group 1 of
    the match is the role and group 2, when non-empty, is the name. (The
    header syntax is presumably like ``<| user alice |>``, mirroring
    ``MessageBuilder.__repr__``; confirm against ``.utils``.) The lines after a
    header, up to the next header, are joined with newlines to form that
    message's content. If at least one header exists, lines before the first
    header are discarded.

    If no header is found, the whole text, minus one trailing newline, becomes
    a single user message. An empty string or a lone newline yields ``[]``.
    """
    parsed: list[Message] = []
    pending = None  # message currently being collected, if any
    body: list[str] = []

    for line in text.splitlines():
        header = is_message_start.match(line)
        if not header:
            # Body line: only kept once some message has been opened.
            if pending is not None:
                body.append(line)
            continue

        role, name = header.group(1, 2)

        # A new header closes the message collected so far.
        if pending is not None:
            pending["content"] = "\n".join(body)
            parsed.append(pending)
        body = []

        pending = {"role": role, "content": ""}
        if name:
            pending["name"] = name

    if pending is not None:
        pending["content"] = "\n".join(body)
        parsed.append(pending)

    if parsed:
        return parsed

    # No headers at all: treat the raw text as one user message.
    # NOTE(coverage report): the final `return []` was never reached in the
    # measured run.
    if text not in ("", "\n"):
        return [{"role": "user", "content": text.removesuffix("\n")}]
    return []