Coverage for promplate/prompt/chat.py: 80%
78 statements
« prev ^ index » next coverage.py v7.6.2, created at 2024-10-09 22:54 +0800
« prev ^ index » next coverage.py v7.6.2, created at 2024-10-09 22:54 +0800
1from sys import version_info
2from typing import Literal
4from .utils import is_message_start
6Role = Literal["user", "assistant", "system"]
8if version_info >= (3, 12): 8 ↛ 17line 8 didn't jump to line 17 because the condition on line 8 was always true
9 from typing import NotRequired, TypedDict
11 class Message(TypedDict): # type: ignore
12 role: Role
13 content: str
14 name: NotRequired[str]
16else:
17 from typing_extensions import NotRequired, TypedDict
19 class Message(TypedDict):
20 role: Role
21 content: str
22 name: NotRequired[str]
25class MessageBuilder:
26 _initializing = True
27 __slots__ = ("role", "content", "name")
29 def __init__(self, role: Role, /, content: str = "", name: str | None = None):
30 self.role: Role = role
31 self.content = content
32 self.name = name
34 def __repr__(self):
35 if self.name is not None:
36 return f"<| {self.role} {self.name} |>"
37 return f"<| {self.role} |>"
39 def __getitem__(self, key):
40 return getattr(self, key)
42 def __setitem__(self, key, value):
43 return setattr(self, key, value)
45 def __setattr__(self, key, value):
46 if not self._initializing:
47 assert self is not U and self is not A and self is not S
48 assert isinstance(value, str)
49 return super().__setattr__(key, value)
51 def __matmul__(self, name: str):
52 assert isinstance(name, str) and name
53 return self.__class__(self.role, self.content, name)
55 def dict(self) -> Message:
56 if self.name:
57 return {"role": self.role, "content": self.content, "name": self.name}
58 return {"role": self.role, "content": self.content}
60 def __gt__(self, content: str) -> Message:
61 assert isinstance(content, str)
62 if self.name:
63 return {"role": self.role, "content": content, "name": self.name}
64 return {"role": self.role, "content": content}
67U = user = MessageBuilder("user")
68A = assistant = MessageBuilder("assistant")
69S = system = MessageBuilder("system")
70MessageBuilder._initializing = False
def ensure(text_or_list: list[Message] | str) -> list[Message]:
    """Normalize *text_or_list* to a list of messages.

    A string is parsed with ``parse_chat_markup``. Anything else is returned
    as-is (the same object, not a copy).
    """
    if not isinstance(text_or_list, str):
        return text_or_list
    return parse_chat_markup(text_or_list)
def parse_chat_markup(text: str) -> list[Message]:
    """Split chat markup *text* into a list of ``Message`` dicts.

    Each line matching ``is_message_start`` opens a new message. Group 1 of the
    match is the role and group 2 (if truthy) the name. The header syntax itself
    is defined in ``.utils`` and is not visible here. The lines that follow, up
    to the next header, become that message's content, joined with ``"\\n"``.

    * Lines before the first header are dropped when at least one header exists.
    * If no header is found, non-empty *text* other than a lone ``"\\n"`` is
      returned as a single user message, with one trailing newline removed.
    * ``""`` and ``"\\n"`` yield ``[]``.
    """
    messages: list[Message] = []
    pending: Message | None = None
    body: list[str] = []

    for line in text.splitlines():
        header = is_message_start.match(line)
        if not header:
            # Plain content line: attach it to the open message, if any.
            if pending is not None:
                body.append(line)
            continue

        # A new header closes the message that is currently open.
        if pending is not None:
            pending["content"] = "\n".join(body)
            messages.append(pending)
            body = []

        role, name = header.group(1), header.group(2)
        pending = {"role": role, "content": ""}
        if name:
            pending["name"] = name

    # Close the final message (its content runs to the end of the text).
    if pending is not None:
        pending["content"] = "\n".join(body)
        messages.append(pending)

    if messages:
        return messages
    if not text or text == "\n":
        return []
    return [{"role": "user", "content": text.removesuffix("\n")}]