Coverage for promplate/llm/openai/v1.py: 0%

113 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-02-07 15:14 +0800

1from contextlib import suppress 

2from copy import copy 

3from functools import cached_property 

4from types import MappingProxyType 

5from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TypeVar 

6 

7from openai import AsyncClient, Client # type: ignore 

8 

9from ...prompt.chat import Message, ensure 

10from ...prompt.utils import _get_aclient, _get_client, get_user_agent 

11from ..base import * 

12 

# P captures a callable's full parameter list; used by `same_params_as` below
# to copy the openai Client constructor signature onto the Config shims.
P = ParamSpec("P")
# NOTE(review): T appears unused anywhere in this module — candidate for
# removal, but kept since it is an importable module attribute.
T = TypeVar("T")

15 

16 

class Config(Configurable):
    """Configuration holder shared by all OpenAI v1 call wrappers.

    Keeps two layers of settings:

    * constructor config (inherited ``Configurable`` state) — forwarded to the
      underlying ``openai.Client`` / ``openai.AsyncClient``;
    * ``_run_config`` — per-call request parameters merged into each API call.
    """

    def __init__(self, **config):
        super().__init__(**config)
        # Per-call request parameters (model, temperature, ...), layered on
        # top of the config at call time by the subclasses below.
        self._run_config = {}

    def bind(self, **run_config):
        """Return a shallow copy of self with *run_config* merged on top.

        NOTE(review): ``copy`` is shallow, so any already-cached
        ``_client``/``_aclient``/``_user_agent`` is carried over — presumably
        intentional so bound copies share one HTTP client; confirm.
        """
        obj = copy(self)
        obj._run_config = self._run_config | run_config
        return obj

    @cached_property
    def _user_agent(self):
        # Imported lazily so module import does not touch openai.version.
        from openai.version import VERSION

        return get_user_agent(self, ("OpenAI", VERSION))

    @property
    def _config(self):  # type: ignore
        """Constructor config with our User-Agent merged into default headers.

        Rebuilt on every access; returned read-only via ``MappingProxyType``.
        """
        ua_header = {"User-Agent": self._user_agent}
        config = dict(super()._config)
        # Right operand wins: our computed User-Agent overrides any
        # user-supplied one, while other user headers are preserved.
        config["default_headers"] = config.get("default_headers", {}) | ua_header
        return MappingProxyType(config)

    @cached_property
    def _client(self):
        # Honor an explicitly supplied http_client; otherwise fall back to the
        # shared client from promplate.prompt.utils. Cached per instance.
        if "http_client" in self._config:
            return Client(**self._config)
        else:
            return Client(**self._config, http_client=_get_client())

    @cached_property
    def _aclient(self):
        # Async counterpart of `_client`, same fallback rule.
        if "http_client" in self._config:
            return AsyncClient(**self._config)
        else:
            return AsyncClient(**self._config, http_client=_get_aclient())

53 

54 

if TYPE_CHECKING:
    # Static-analysis-only shims: give ClientConfig / AsyncClientConfig the
    # exact constructor signatures of the underlying openai clients so type
    # checkers and IDEs can validate/complete the **config kwargs.

    def same_params_as(_: Callable[P, Any]) -> Callable[[Callable[..., None]], Callable[P, None]]: ...

    class ClientConfig(Config):
        @same_params_as(Client)
        def __init__(self, **config): ...

    class AsyncClientConfig(Config):
        @same_params_as(AsyncClient)
        def __init__(self, **config): ...

else:
    # At runtime both names are plain aliases of Config; the split exists
    # purely for the type-checking branch above.
    ClientConfig = AsyncClientConfig = Config

69 

70 

class TextComplete(ClientConfig):
    """Blocking text completion: returns the first choice's full text."""

    def __call__(self, text: str, /, **config):
        # Per-call kwargs override bound run-config; stream/prompt always win.
        params = {**self._run_config, **config, "stream": False, "prompt": text}
        response = self._client.completions.create(**params)
        return response.choices[0].text

76 

77 

class AsyncTextComplete(AsyncClientConfig):
    """Async text completion: returns the first choice's full text."""

    async def __call__(self, text: str, /, **config):
        # Per-call kwargs override bound run-config; stream/prompt always win.
        params = {**self._run_config, **config, "stream": False, "prompt": text}
        response = await self._aclient.completions.create(**params)
        return response.choices[0].text

83 

84 

class TextGenerate(ClientConfig):
    """Blocking text-completion stream: yields each chunk's text delta."""

    def __call__(self, text: str, /, **config):
        params = {**self._run_config, **config, "stream": True, "prompt": text}
        for chunk in self._client.completions.create(**params):
            # Chunks without choices (or without .text) are skipped silently.
            with suppress(AttributeError, IndexError):
                yield chunk.choices[0].text

92 

93 

class AsyncTextGenerate(AsyncClientConfig):
    """Async text-completion stream: yields each chunk's text delta."""

    async def __call__(self, text: str, /, **config):
        params = {**self._run_config, **config, "stream": True, "prompt": text}
        response = await self._aclient.completions.create(**params)
        async for chunk in response:
            # Chunks without choices (or without .text) are skipped silently.
            with suppress(AttributeError, IndexError):
                yield chunk.choices[0].text

101 

102 

class ChatComplete(ClientConfig):
    """Blocking chat completion: returns the first choice's message content."""

    def __call__(self, messages: list[Message] | str, /, **config):
        # `ensure` normalizes a bare string into a proper message list.
        params = {
            **self._run_config,
            **config,
            "stream": False,
            "messages": ensure(messages),
        }
        reply = self._client.chat.completions.create(**params)
        return reply.choices[0].message.content

109 

110 

class AsyncChatComplete(AsyncClientConfig):
    """Async chat completion: returns the first choice's message content."""

    async def __call__(self, messages: list[Message] | str, /, **config):
        # `ensure` normalizes a bare string into a proper message list.
        params = {
            **self._run_config,
            **config,
            "stream": False,
            "messages": ensure(messages),
        }
        reply = await self._aclient.chat.completions.create(**params)
        return reply.choices[0].message.content

117 

118 

class ChatGenerate(ClientConfig):
    """Blocking chat stream: yields each chunk's delta content ('' if None)."""

    def __call__(self, messages: list[Message] | str, /, **config):
        params = {
            **self._run_config,
            **config,
            "stream": True,
            "messages": ensure(messages),
        }
        for chunk in self._client.chat.completions.create(**params):
            # Skip malformed chunks; map a None delta content to "".
            with suppress(AttributeError, IndexError):
                yield chunk.choices[0].delta.content or ""

127 

128 

class AsyncChatGenerate(AsyncClientConfig):
    """Async chat stream: yields each chunk's delta content ('' if None)."""

    async def __call__(self, messages: list[Message] | str, /, **config):
        params = {
            **self._run_config,
            **config,
            "stream": True,
            "messages": ensure(messages),
        }
        response = await self._aclient.chat.completions.create(**params)
        async for chunk in response:
            # Skip malformed chunks; map a None delta content to "".
            with suppress(AttributeError, IndexError):
                yield chunk.choices[0].delta.content or ""

137 

138 

class SyncTextOpenAI(ClientConfig, LLM):
    """Synchronous text LLM facade: one object exposing complete + generate.

    Borrows the unbound ``__call__`` of the single-purpose wrappers; works
    because all of them share the same Config state layout.
    """

    complete = TextComplete.__call__  # type: ignore
    generate = TextGenerate.__call__  # type: ignore

142 

143 

class AsyncTextOpenAI(AsyncClientConfig, LLM):
    """Async text LLM facade: one object exposing complete + generate.

    Borrows the unbound ``__call__`` of the single-purpose wrappers; works
    because all of them share the same Config state layout.
    """

    complete = AsyncTextComplete.__call__  # type: ignore
    generate = AsyncTextGenerate.__call__  # type: ignore

147 

148 

class SyncChatOpenAI(ClientConfig, LLM):
    """Synchronous chat LLM facade: one object exposing complete + generate.

    Borrows the unbound ``__call__`` of the single-purpose wrappers; works
    because all of them share the same Config state layout.
    """

    complete = ChatComplete.__call__  # type: ignore
    generate = ChatGenerate.__call__  # type: ignore

152 

153 

class AsyncChatOpenAI(AsyncClientConfig, LLM):
    """Async chat LLM facade: one object exposing complete + generate.

    Borrows the unbound ``__call__`` of the single-purpose wrappers; works
    because all of them share the same Config state layout.
    """

    complete = AsyncChatComplete.__call__  # type: ignore
    generate = AsyncChatGenerate.__call__  # type: ignore

157 

158 

# Public API of this module: the eight call wrappers plus the four facades.
__all__ = (
    "TextComplete",
    "AsyncTextComplete",
    "TextGenerate",
    "AsyncTextGenerate",
    "ChatComplete",
    "AsyncChatComplete",
    "ChatGenerate",
    "AsyncChatGenerate",
    "SyncTextOpenAI",
    "AsyncTextOpenAI",
    "SyncChatOpenAI",
    "AsyncChatOpenAI",
)