Coverage for promplate/prompt/chat.py: 80%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-08 04:34 +0800

1from sys import version_info 

2from typing import Literal 

3 

4from .utils import is_message_start 

5 

# The only role names a chat message may carry.
Role = Literal["user", "assistant", "system"]

# `Message` is declared once per branch so that each declaration is built from
# the `TypedDict` / `NotRequired` imported in that same branch: the standard
# library ones on Python 3.12+, the `typing_extensions` ones otherwise.
# NOTE(review): `typing.NotRequired` exists since Python 3.11, so the 3.12
# cutoff is presumably deliberate — confirm before lowering it.
if version_info >= (3, 12):
    from typing import NotRequired, TypedDict

    # Shape of one chat message: `role` and `content` are required keys,
    # `name` may be omitted.
    # NOTE(review): the `type: ignore` presumably silences a checker complaint
    # about `Message` being declared in both branches — confirm.
    class Message(TypedDict):  # type: ignore
        role: Role
        content: str
        name: NotRequired[str]

else:
    from typing_extensions import NotRequired, TypedDict

    # Same shape as the declaration in the other branch.
    class Message(TypedDict):
        role: Role
        content: str
        name: NotRequired[str]

23 

24 

class MessageBuilder:
    """Small helper that holds a role, a content string and an optional name.

    Instances expose their three fields both as attributes and through item
    access (``builder["role"]``), and offer two operators:

    * ``builder @ "name"`` returns a *new* builder of the same class with the
      same role and content and the given name.
    * ``builder > "text"`` returns a message dict using the builder's role
      (and name, when set) with ``"text"`` as its content.

    While the class attribute ``_initializing`` is true, attribute assignment
    is unchecked. Once the module sets it to false, ``__setattr__`` asserts
    that the module-level builders ``U``, ``A`` and ``S`` are never modified
    and that assigned values are strings (``name`` may also be ``None``).
    These checks are ``assert`` statements, so they are skipped when Python
    runs with ``-O``.
    """

    # Class-level flag, deliberately not part of ``__slots__``: it is shared
    # by every instance and flipped once at module level.
    _initializing = True
    __slots__ = ("role", "content", "name")

    def __init__(self, role: Role, /, content: str = "", name: str | None = None):
        """Store *role* (positional-only), *content* and the optional *name*."""
        self.role: Role = role
        self.content = content
        self.name = name

    def __repr__(self):
        """Return ``<| role |>`` or, when a name is set, ``<| role name |>``."""
        if self.name is not None:
            return f"<| {self.role} {self.name} |>"
        return f"<| {self.role} |>"

    def __getitem__(self, key):
        """Read the attribute called *key* (``builder["content"]``)."""
        return getattr(self, key)

    def __setitem__(self, key, value):
        """Assign the attribute called *key*; goes through ``__setattr__``."""
        return setattr(self, key, value)

    def __setattr__(self, key, value):
        """Assign an attribute, asserting the post-initialization rules.

        Raises:
            AssertionError: after initialization, when the target is one of
                the module-level builders ``U``/``A``/``S`` or when *value*
                is neither a string nor ``None`` for ``name``.
        """
        if not self._initializing:
            assert self is not U and self is not A and self is not S, "the shared builders are read-only"
            # ``name`` must accept None: ``__init__`` assigns its ``None``
            # default through this method, so a strict ``str`` check made
            # ``MessageBuilder("user")`` fail once initialization was over.
            assert isinstance(value, str) or (key == "name" and value is None), f"{key} must be a string"
        return super().__setattr__(key, value)

    def __matmul__(self, name: str):
        """Return a new builder with the same role and content, named *name*.

        Raises:
            AssertionError: if *name* is not a non-empty string.
        """
        assert isinstance(name, str) and name
        return self.__class__(self.role, self.content, name)

    def _message(self, content: str) -> Message:
        """Build a message dict with *content*; ``name`` is included only when truthy."""
        if self.name:
            return {"role": self.role, "content": content, "name": self.name}
        return {"role": self.role, "content": content}

    def dict(self) -> Message:
        """Return this builder's own role, content and (if set) name as a message dict."""
        return self._message(self.content)

    def __gt__(self, content: str) -> Message:
        """Return a message dict that uses *content* instead of the stored content.

        Raises:
            AssertionError: if *content* is not a string.
        """
        assert isinstance(content, str)
        return self._message(content)

65 

66 

# One shared builder per role, each reachable under a short and a long name.
U = MessageBuilder("user")
user = U

A = MessageBuilder("assistant")
assistant = A

S = MessageBuilder("system")
system = S

# Must stay last: the builders above are created while attribute assignment
# is still unchecked; from here on MessageBuilder.__setattr__ runs its asserts.
setattr(MessageBuilder, "_initializing", False)

71 

72 

def ensure(text_or_list: list[Message] | str) -> list[Message]:
    """Return a list of messages, parsing the argument first if it is a string.

    A string is handed to ``parse_chat_markup``; anything else is returned
    unchanged (the very same object, not a copy).
    """
    if isinstance(text_or_list, str):
        return parse_chat_markup(text_or_list)
    return text_or_list

75 

76 

def parse_chat_markup(text: str) -> list[Message]:
    """Split *text* into chat messages using ``is_message_start`` as the header test.

    Every line that ``is_message_start`` matches opens a new message whose
    role is match group 1 and whose name, when truthy, is match group 2. The
    lines that follow, up to the next header, become that message's content,
    joined with ``"\\n"``. Lines that appear before the first header are
    dropped.

    If no header is found at all, the whole text (minus one trailing newline)
    is returned as a single ``user`` message, unless the text is empty or is
    exactly ``"\\n"``, in which case the result is an empty list.
    """
    # Each entry pairs a message with the body lines collected for it so far.
    sections = []

    for line in text.splitlines():
        header = is_message_start.match(line)
        if header:
            role = header.group(1)
            name = header.group(2)

            message = {"role": role, "content": ""}
            if name:
                message["name"] = name

            sections.append((message, []))
        elif sections:
            # Body line of the most recently opened message.
            sections[-1][1].append(line)

    if sections:
        parsed = []
        for message, body in sections:
            message["content"] = "\n".join(body)
            parsed.append(message)
        return parsed

    # No header anywhere: fall back to treating the text as one user message.
    if not text or text == "\n":
        return []
    return [{"role": "user", "content": text.removesuffix("\n")}]