Coverage for promplate/llm/openai/v1.py: 0%

120 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-08 04:34 +0800

1from copy import copy 

2from functools import cached_property 

3from types import MappingProxyType 

4from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TypeVar 

5 

6from openai import AsyncClient, Client # type: ignore 

7 

8from ...prompt.chat import Message, ensure 

9from ...prompt.utils import _get_aclient, _get_client, get_user_agent 

10from ..base import * 

11 

12P = ParamSpec("P") 

13T = TypeVar("T") 

14 

15 

class Config(Configurable):
    """Shared configuration base for the OpenAI-backed callables in this module.

    Constructor keyword arguments go to ``Configurable`` and are later
    forwarded (through ``_config``) to ``openai.Client`` / ``openai.AsyncClient``.
    Per-request options live separately in ``_run_config`` and are attached
    with :meth:`bind`.
    """

    def __init__(self, **config):
        super().__init__(**config)
        # Request options merged into every call made by subclasses.
        self._run_config = {}

    def bind(self, **run_config):
        """Return a copy of this object whose run config is extended by ``run_config``.

        ``self`` is not modified; keys in ``run_config`` override existing ones.
        The copy is made with ``copy.copy`` (shallow).
        """
        obj = copy(self)
        obj._run_config = self._run_config | run_config
        return obj

    @cached_property
    def _user_agent(self):
        """User-Agent string built from this object and the installed openai version (cached)."""
        from openai.version import VERSION

        return get_user_agent(self, ("OpenAI", VERSION))

    @property
    def _config(self):  # type: ignore
        """Client constructor kwargs with the User-Agent header injected.

        ``default_headers`` may be absent, ``None`` (accepted by the openai
        client itself) or any mapping. The computed User-Agent overrides a
        caller-supplied one. The result is a read-only ``MappingProxyType``.
        """
        ua_header = {"User-Agent": self._user_agent}
        config = dict(super()._config)
        # ``or {}`` tolerates an explicit ``default_headers=None``; unpacking
        # (instead of ``|``) works for any Mapping, not only ``dict``.
        existing_headers = config.get("default_headers") or {}
        config["default_headers"] = {**existing_headers, **ua_header}
        return MappingProxyType(config)

    @cached_property
    def _client(self):
        """Synchronous ``openai.Client`` built lazily from ``_config`` and cached.

        Uses the client from ``_get_client()`` as ``http_client`` unless the
        configuration already provides one.
        """
        config = self._config
        if "http_client" in config:
            return Client(**config)
        return Client(**config, http_client=_get_client())

    @cached_property
    def _aclient(self):
        """Asynchronous ``openai.AsyncClient`` built lazily from ``_config`` and cached.

        Uses the client from ``_get_aclient()`` as ``http_client`` unless the
        configuration already provides one.
        """
        config = self._config
        if "http_client" in config:
            return AsyncClient(**config)
        return AsyncClient(**config, http_client=_get_aclient())

52 

53 

if TYPE_CHECKING:
    # Typing-only helper: re-labels the decorated ``__init__`` with the
    # parameter list of ``func`` so type checkers see the client's signature.
    def same_params_as(func: Callable[P, Any]) -> Callable[[Callable[..., None]], Callable[P, None]]: ...

    class ClientConfig(Config):
        @same_params_as(Client)
        def __init__(self, **config): ...

    class AsyncClientConfig(Config):
        @same_params_as(AsyncClient)
        def __init__(self, **config): ...

else:
    # At runtime both names are simply aliases of ``Config``.
    ClientConfig = Config
    AsyncClientConfig = Config

68 

69 

class TextComplete(ClientConfig):
    """Non-streaming text completion that returns the generated text."""

    def __call__(self, text: str, /, **config):
        """Return the text of the first choice generated for ``text``.

        Per-call ``config`` overrides the bound run config; ``stream`` and
        ``prompt`` are always set by this method.
        """
        request = {**self._run_config, **config, "stream": False, "prompt": text}
        completion = self._client.completions.create(**request)
        first_choice = completion.choices[0]
        return first_choice.text

75 

76 

class AsyncTextComplete(AsyncClientConfig):
    """Asynchronous, non-streaming text completion that returns the generated text."""

    async def __call__(self, text: str, /, **config):
        """Await a completion for ``text`` and return its first choice's text.

        Per-call ``config`` overrides the bound run config; ``stream`` and
        ``prompt`` are always set by this method.
        """
        request = {**self._run_config, **config, "stream": False, "prompt": text}
        completion = await self._aclient.completions.create(**request)
        first_choice = completion.choices[0]
        return first_choice.text

82 

83 

class TextGenerate(ClientConfig):
    """Streaming text completion that yields text fragments as they arrive."""

    def __call__(self, text: str, /, **config):
        """Yield the text of each streamed completion chunk for ``text``.

        Per-call ``config`` overrides the bound run config; ``stream`` and
        ``prompt`` are always set by this method. Chunks with no first choice
        (e.g. an empty ``choices`` list on a usage-only chunk) or lacking the
        expected attributes are skipped instead of aborting the stream.
        """
        config = self._run_config | config | {"stream": True, "prompt": text}
        stream = self._client.completions.create(**config)
        for event in stream:
            try:
                piece = event.choices[0].text
            except (AttributeError, IndexError):
                continue
            # Yield outside the ``try`` so an exception thrown into this
            # generator by the consumer is not mistaken for a malformed chunk.
            yield piece

93 

94 

class AsyncTextGenerate(AsyncClientConfig):
    """Asynchronous streaming text completion that yields text fragments as they arrive."""

    async def __call__(self, text: str, /, **config):
        """Asynchronously yield the text of each streamed completion chunk for ``text``.

        Per-call ``config`` overrides the bound run config; ``stream`` and
        ``prompt`` are always set by this method. Chunks with no first choice
        (e.g. an empty ``choices`` list on a usage-only chunk) or lacking the
        expected attributes are skipped instead of aborting the stream.
        """
        config = self._run_config | config | {"stream": True, "prompt": text}
        stream = await self._aclient.completions.create(**config)
        async for event in stream:
            try:
                piece = event.choices[0].text
            except (AttributeError, IndexError):
                continue
            # Yield outside the ``try`` so an exception thrown into this
            # generator by the consumer is not mistaken for a malformed chunk.
            yield piece

104 

105 

class ChatComplete(ClientConfig):
    """Non-streaming chat completion that returns the reply content."""

    def __call__(self, messages: list[Message] | str, /, **config):
        """Return the message content of the first choice for ``messages``.

        ``messages`` is passed through ``ensure`` first. Per-call ``config``
        overrides the bound run config; ``stream`` and ``messages`` are always
        set by this method.
        """
        prepared = ensure(messages)
        request = {**self._run_config, **config, "stream": False, "messages": prepared}
        completion = self._client.chat.completions.create(**request)
        reply = completion.choices[0].message
        return reply.content

112 

113 

class AsyncChatComplete(AsyncClientConfig):
    """Asynchronous, non-streaming chat completion that returns the reply content."""

    async def __call__(self, messages: list[Message] | str, /, **config):
        """Await a chat completion for ``messages`` and return the first choice's content.

        ``messages`` is passed through ``ensure`` first. Per-call ``config``
        overrides the bound run config; ``stream`` and ``messages`` are always
        set by this method.
        """
        prepared = ensure(messages)
        request = {**self._run_config, **config, "stream": False, "messages": prepared}
        completion = await self._aclient.chat.completions.create(**request)
        reply = completion.choices[0].message
        return reply.content

120 

121 

class ChatGenerate(ClientConfig):
    """Streaming chat completion that yields content fragments as they arrive."""

    def __call__(self, messages: list[Message] | str, /, **config):
        """Yield the content delta of each streamed chat completion chunk.

        ``messages`` is passed through ``ensure`` first. Per-call ``config``
        overrides the bound run config; ``stream`` and ``messages`` are always
        set by this method. A ``None``/empty delta content is yielded as ``""``.
        Chunks with no first choice (e.g. the usage-only final chunk sent when
        ``stream_options={"include_usage": True}``) or lacking the expected
        attributes are skipped instead of aborting the stream.
        """
        messages = ensure(messages)
        config = self._run_config | config | {"stream": True, "messages": messages}
        stream = self._client.chat.completions.create(**config)
        for event in stream:
            try:
                piece = event.choices[0].delta.content or ""
            except (AttributeError, IndexError):
                continue
            # Yield outside the ``try`` so an exception thrown into this
            # generator by the consumer is not mistaken for a malformed chunk.
            yield piece

132 

133 

class AsyncChatGenerate(AsyncClientConfig):
    """Asynchronous streaming chat completion that yields content fragments as they arrive."""

    async def __call__(self, messages: list[Message] | str, /, **config):
        """Asynchronously yield the content delta of each streamed chat chunk.

        ``messages`` is passed through ``ensure`` first. Per-call ``config``
        overrides the bound run config; ``stream`` and ``messages`` are always
        set by this method. A ``None``/empty delta content is yielded as ``""``.
        Chunks with no first choice (e.g. the usage-only final chunk sent when
        ``stream_options={"include_usage": True}``) or lacking the expected
        attributes are skipped instead of aborting the stream.
        """
        messages = ensure(messages)
        config = self._run_config | config | {"stream": True, "messages": messages}
        stream = await self._aclient.chat.completions.create(**config)
        async for event in stream:
            try:
                piece = event.choices[0].delta.content or ""
            except (AttributeError, IndexError):
                continue
            # Yield outside the ``try`` so an exception thrown into this
            # generator by the consumer is not mistaken for a malformed chunk.
            yield piece

144 

145 

class SyncTextOpenAI(ClientConfig, LLM):
    """Synchronous text-completion LLM built from ``TextComplete`` and ``TextGenerate``.

    The unbound ``__call__`` functions of those callables are reused directly
    as the ``complete`` and ``generate`` methods.
    """

    generate = TextGenerate.__call__  # type: ignore
    complete = TextComplete.__call__  # type: ignore

149 

150 

class AsyncTextOpenAI(AsyncClientConfig, LLM):
    """Asynchronous text-completion LLM built from ``AsyncTextComplete`` and ``AsyncTextGenerate``.

    The unbound ``__call__`` functions of those callables are reused directly
    as the ``complete`` and ``generate`` methods.
    """

    generate = AsyncTextGenerate.__call__  # type: ignore
    complete = AsyncTextComplete.__call__  # type: ignore

154 

155 

class SyncChatOpenAI(ClientConfig, LLM):
    """Synchronous chat LLM built from ``ChatComplete`` and ``ChatGenerate``.

    The unbound ``__call__`` functions of those callables are reused directly
    as the ``complete`` and ``generate`` methods.
    """

    generate = ChatGenerate.__call__  # type: ignore
    complete = ChatComplete.__call__  # type: ignore

159 

160 

class AsyncChatOpenAI(AsyncClientConfig, LLM):
    """Asynchronous chat LLM built from ``AsyncChatComplete`` and ``AsyncChatGenerate``.

    The unbound ``__call__`` functions of those callables are reused directly
    as the ``complete`` and ``generate`` methods.
    """

    generate = AsyncChatGenerate.__call__  # type: ignore
    complete = AsyncChatComplete.__call__  # type: ignore

164 

165 

# Public API: standalone callables first, then the combined LLM classes.
__all__ = (
    # text completion callables
    "TextComplete", "AsyncTextComplete", "TextGenerate", "AsyncTextGenerate",
    # chat completion callables
    "ChatComplete", "AsyncChatComplete", "ChatGenerate", "AsyncChatGenerate",
    # LLM classes bundling complete + generate
    "SyncTextOpenAI", "AsyncTextOpenAI", "SyncChatOpenAI", "AsyncChatOpenAI",
)