Coverage for promplate/prompt/chat.py: 83%

74 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-02-07 15:14 +0800

1from sys import version_info 

2from typing import Literal 

3 

4from .utils import is_message_start 

5 

# The roles a chat message may carry.
Role = Literal["user", "assistant", "system"]

# TypedDict / NotRequired come from the standard library on 3.12+ and from the
# typing_extensions backport on older interpreters.
# NOTE(coverage): in the measured run the 3.12+ branch was always taken, so the
# typing_extensions fallback was never executed.
if version_info < (3, 12):
    from typing_extensions import NotRequired, TypedDict
else:
    from typing import NotRequired, TypedDict  # type: ignore


class Message(TypedDict):
    """One chat message: a role, its text, and an optional speaker name."""

    role: Role  # one of the values permitted by ``Role``
    content: str  # the message text
    name: NotRequired[str]  # optional: the key may be absent entirely

18 

19 

class MessageBuilder:
    """Build chat message dicts with a small operator DSL.

    * ``builder @ "name"`` returns a new builder with the same role/content
      and the given (non-empty) name.
    * ``builder > "text"`` returns a message dict whose content is ``text``.
    * ``builder.dict()`` returns a message dict using the builder's own
      ``content``.

    In both dict forms the ``"name"`` key is included only when ``name`` is
    truthy.

    The module-level templates ``U`` / ``A`` / ``S`` (aliases ``user`` /
    ``assistant`` / ``system``) are shared. Once module setup has finished,
    any attribute write on them fails an assertion. Attribute writes on other
    builders must be ``str``, except that ``name`` may be ``None``.

    NOTE: validation uses ``assert`` and is therefore skipped under
    ``python -O``.
    """

    # Class-level flag. It is True only while the U/A/S templates below are
    # being created; it is flipped to False right after them, and from then on
    # __setattr__ validates every attribute write.
    _initializing = True
    __slots__ = ("role", "content", "name")

    def __init__(self, role: Role, /, content: str = "", name: str | None = None):
        self.role: Role = role
        self.content = content
        self.name = name  # None means "no name"

    def __repr__(self):
        if self.name is not None:
            return f"<| {self.role} {self.name} |>"
        return f"<| {self.role} |>"

    def __getitem__(self, key):
        # Dict-style read access: builder["role"] is builder.role.
        return getattr(self, key)

    def __setitem__(self, key, value):
        # Dict-style write access; routed through the validating __setattr__.
        return setattr(self, key, value)

    def __setattr__(self, key, value):
        if not self._initializing:
            # The shared U/A/S templates are read-only.
            assert self is not U and self is not A and self is not S
            # __init__ stores name=None by default, so None must be accepted
            # for ``name``. Otherwise a fresh builder constructed after module
            # setup, e.g. MessageBuilder("user"), fails inside __init__.
            assert isinstance(value, str) or (key == "name" and value is None)
        return super().__setattr__(key, value)

    def __matmul__(self, name: str):
        """Return a copy of this builder carrying the non-empty *name*."""
        assert isinstance(name, str) and name
        return self.__class__(self.role, self.content, name)

    def dict(self) -> Message:
        """Return this builder's role/content (and name, if truthy) as a dict."""
        if self.name:
            return {"role": self.role, "content": self.content, "name": self.name}
        return {"role": self.role, "content": self.content}

    def __gt__(self, content: str) -> Message:
        """Return a message dict for this role (and name) with *content* as text."""
        assert isinstance(content, str)
        # Reuse dict() so both forms agree on key order and name handling;
        # overwriting "content" keeps its position in the dict.
        message = self.dict()
        message["content"] = content
        return message


U = user = MessageBuilder("user")
A = assistant = MessageBuilder("assistant")
S = system = MessageBuilder("system")
# From here on, attribute writes on builders are validated by __setattr__.
MessageBuilder._initializing = False

66 

67 

def ensure(text_or_list: list[Message] | str) -> list[Message]:
    """Return *text_or_list* as a list of messages.

    A ``str`` is parsed with ``parse_chat_markup``. Any other value is
    returned unchanged (the same object, not a copy).
    """
    if not isinstance(text_or_list, str):
        return text_or_list
    return parse_chat_markup(text_or_list)

70 

71 

def parse_chat_markup(text: str) -> list[Message]:
    """Split chat markup into a list of message dicts.

    *text* is split with ``str.splitlines``. Each line matched by
    ``is_message_start`` opens a new message:

    * group 1 of the match is the role;
    * group 2, when truthy, becomes the ``"name"`` key.

    The lines that follow, up to the next start line, are joined with newlines
    to form that message's content. Lines before the first start line are
    discarded.

    If no start line is found at all, the whole *text* becomes a single
    ``"user"`` message with at most one trailing newline removed. An empty
    string, or a string that is exactly one newline, yields an empty list.
    """
    collected = []
    opened = None  # the message currently being filled, if any
    body_lines = []

    for line in text.splitlines():
        start = is_message_start.match(line)
        if not start:
            # Lines outside any message (i.e. before the first start line)
            # are dropped.
            if opened is not None:
                body_lines.append(line)
            continue

        role, speaker = start.group(1), start.group(2)

        # A new start line closes the message collected so far.
        if opened is not None:
            opened["content"] = "\n".join(body_lines)
            collected.append(opened)
            body_lines = []

        opened = {"role": role, "content": ""}
        if speaker:
            opened["name"] = speaker

    # Close the last message, if any start line was seen.
    if opened is not None:
        opened["content"] = "\n".join(body_lines)
        collected.append(opened)

    if collected:
        return collected

    # NOTE(coverage): in the measured run this condition was always true, so
    # the final `return []` was never executed.
    if text and text != "\n":
        return [{"role": "user", "content": text.removesuffix("\n")}]
    return []