Coverage for promplate/llm/openai/v1.py: 0%
120 statements
« prev ^ index » next coverage.py v7.6.2, created at 2024-10-09 22:54 +0800
1from copy import copy
2from functools import cached_property
3from types import MappingProxyType
4from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TypeVar
6from openai import AsyncClient, Client # type: ignore
8from ...prompt.chat import Message, ensure
9from ...prompt.utils import _get_aclient, _get_client, get_user_agent
10from ..base import *
# Captures a callable's full parameter list; consumed by the typing-only
# ``same_params_as`` helper defined under ``if TYPE_CHECKING`` below.
P = ParamSpec("P")
# Generic type variable; NOTE(review): not referenced in the code visible here.
T = TypeVar("T")
class Config(Configurable):
    """Shared base for the OpenAI-backed callables in this module.

    Constructor keyword arguments go to ``Configurable`` and are later
    forwarded (via ``_config``) to ``openai.Client`` / ``openai.AsyncClient``.
    Per-request parameters are kept separately in ``_run_config`` (see
    :meth:`bind`) and merged into every API call by subclasses.
    """

    def __init__(self, **config):
        super().__init__(**config)
        # Per-request defaults merged into each API call; populated by ``bind``.
        self._run_config = {}

    def bind(self, **run_config):
        """Return a shallow copy whose per-request defaults include ``run_config``.

        Keys in ``run_config`` override previously bound ones; ``self`` is not
        modified. With the default shallow ``copy`` (assuming ``Configurable``
        does not customise copying), clients already cached on ``self`` are
        shared with the returned object.
        """
        obj = copy(self)
        obj._run_config = self._run_config | run_config
        return obj

    @cached_property
    def _user_agent(self):
        """User-Agent value built by ``get_user_agent`` with the installed openai SDK version."""
        from openai.version import VERSION

        return get_user_agent(self, ("OpenAI", VERSION))

    @property
    def _config(self):  # type: ignore
        """Client constructor kwargs with our ``User-Agent`` merged into ``default_headers``.

        The caller's ``default_headers`` may be absent, ``None``, or any mapping.
        Our ``User-Agent`` overrides one supplied by the caller. The result is
        a read-only view that is rebuilt on every access.
        """
        config = dict(super()._config)
        # ``or {}`` tolerates an explicit ``default_headers=None`` (the openai
        # SDK's own default). Unpacking works for any Mapping, whereas ``|``
        # needs a dict-like left operand.
        caller_headers = config.get("default_headers") or {}
        config["default_headers"] = {**caller_headers, "User-Agent": self._user_agent}
        return MappingProxyType(config)

    @cached_property
    def _client(self):
        """Synchronous ``openai.Client``; uses ``_get_client()`` unless ``http_client`` was configured."""
        config = self._config  # build the kwargs once; the property recomputes on each access
        if "http_client" in config:
            return Client(**config)
        return Client(**config, http_client=_get_client())

    @cached_property
    def _aclient(self):
        """Asynchronous ``openai.AsyncClient``; uses ``_get_aclient()`` unless ``http_client`` was configured."""
        config = self._config  # build the kwargs once; the property recomputes on each access
        if "http_client" in config:
            return AsyncClient(**config)
        return AsyncClient(**config, http_client=_get_aclient())
if TYPE_CHECKING:

    def same_params_as(_: Callable[P, Any]) -> Callable[[Callable[..., None]], Callable[P, None]]:
        """Typing-only decorator: give the decorated function the parameters of ``_``."""
        ...

    class ClientConfig(Config):
        """``Config`` whose constructor is type-checked like ``openai.Client``."""

        @same_params_as(Client)
        def __init__(self, **config): ...

    class AsyncClientConfig(Config):
        """``Config`` whose constructor is type-checked like ``openai.AsyncClient``."""

        @same_params_as(AsyncClient)
        def __init__(self, **config): ...

else:
    # At runtime both names are simply ``Config``; the subclasses above exist
    # only so static type checkers can validate constructor keyword arguments.
    ClientConfig = Config
    AsyncClientConfig = Config
class TextComplete(ClientConfig):
    """Blocking, non-streaming call to the Completions endpoint."""

    def __call__(self, text: str, /, **config):
        """Send ``text`` as the prompt and return the first choice's text.

        Keyword arguments override values set via ``bind``; ``stream`` and
        ``prompt`` are always forced.
        """
        request = {**self._run_config, **config, "stream": False, "prompt": text}
        response = self._client.completions.create(**request)
        first_choice = response.choices[0]
        return first_choice.text
class AsyncTextComplete(AsyncClientConfig):
    """Awaitable, non-streaming call to the Completions endpoint."""

    async def __call__(self, text: str, /, **config):
        """Send ``text`` as the prompt and return the first choice's text.

        Keyword arguments override values set via ``bind``; ``stream`` and
        ``prompt`` are always forced.
        """
        request = {**self._run_config, **config, "stream": False, "prompt": text}
        response = await self._aclient.completions.create(**request)
        first_choice = response.choices[0]
        return first_choice.text
class TextGenerate(ClientConfig):
    """Streaming call to the Completions endpoint that yields text pieces."""

    def __call__(self, text: str, /, **config):
        """Stream a completion for ``text``, yielding each chunk's first-choice text.

        Keyword arguments override values set via ``bind``; ``stream`` and
        ``prompt`` are always forced. Chunks without a readable first choice
        are skipped rather than aborting the stream.
        """
        config = self._run_config | config | {"stream": True, "prompt": text}
        stream = self._client.completions.create(**config)
        for event in stream:
            try:
                piece = event.choices[0].text
            except (AttributeError, IndexError):
                # IndexError: chunks with an empty ``choices`` list, such as the
                # trailing usage chunk sent when
                # ``stream_options={"include_usage": True}``, carry no text.
                continue
            # Yield outside the ``try`` so errors thrown into the generator by
            # the consumer are not swallowed.
            yield piece
class AsyncTextGenerate(AsyncClientConfig):
    """Async streaming call to the Completions endpoint that yields text pieces."""

    async def __call__(self, text: str, /, **config):
        """Stream a completion for ``text``, yielding each chunk's first-choice text.

        Keyword arguments override values set via ``bind``; ``stream`` and
        ``prompt`` are always forced. Chunks without a readable first choice
        are skipped rather than aborting the stream.
        """
        config = self._run_config | config | {"stream": True, "prompt": text}
        stream = await self._aclient.completions.create(**config)
        async for event in stream:
            try:
                piece = event.choices[0].text
            except (AttributeError, IndexError):
                # IndexError: chunks with an empty ``choices`` list, such as the
                # trailing usage chunk sent when
                # ``stream_options={"include_usage": True}``, carry no text.
                continue
            # Yield outside the ``try`` so errors thrown into the generator by
            # the consumer are not swallowed.
            yield piece
class ChatComplete(ClientConfig):
    """Blocking, non-streaming call to the Chat Completions endpoint."""

    def __call__(self, messages: list[Message] | str, /, **config):
        """Send ``messages`` (normalised by ``ensure``) and return the first reply's content.

        Keyword arguments override values set via ``bind``; ``stream`` and
        ``messages`` are always forced.
        """
        chat = ensure(messages)
        request = {**self._run_config, **config, "stream": False, "messages": chat}
        response = self._client.chat.completions.create(**request)
        reply = response.choices[0].message
        return reply.content
class AsyncChatComplete(AsyncClientConfig):
    """Awaitable, non-streaming call to the Chat Completions endpoint."""

    async def __call__(self, messages: list[Message] | str, /, **config):
        """Send ``messages`` (normalised by ``ensure``) and return the first reply's content.

        Keyword arguments override values set via ``bind``; ``stream`` and
        ``messages`` are always forced.
        """
        chat = ensure(messages)
        request = {**self._run_config, **config, "stream": False, "messages": chat}
        response = await self._aclient.chat.completions.create(**request)
        reply = response.choices[0].message
        return reply.content
class ChatGenerate(ClientConfig):
    """Streaming call to the Chat Completions endpoint that yields content deltas."""

    def __call__(self, messages: list[Message] | str, /, **config):
        """Stream a chat reply to ``messages``, yielding each chunk's first-choice delta.

        Keyword arguments override values set via ``bind``; ``stream`` and
        ``messages`` are always forced. A chunk whose ``delta.content`` is falsy
        (e.g. ``None``) yields ``""``. Chunks without a readable first choice
        are skipped rather than aborting the stream.
        """
        messages = ensure(messages)
        config = self._run_config | config | {"stream": True, "messages": messages}
        stream = self._client.chat.completions.create(**config)
        for event in stream:
            try:
                piece = event.choices[0].delta.content or ""
            except (AttributeError, IndexError):
                # IndexError: some chunks have an empty ``choices`` list, e.g.
                # the usage chunk sent with ``stream_options={"include_usage":
                # True}`` or Azure's leading content-filter result chunk.
                continue
            # Yield outside the ``try`` so errors thrown into the generator by
            # the consumer are not swallowed.
            yield piece
class AsyncChatGenerate(AsyncClientConfig):
    """Async streaming call to the Chat Completions endpoint that yields content deltas."""

    async def __call__(self, messages: list[Message] | str, /, **config):
        """Stream a chat reply to ``messages``, yielding each chunk's first-choice delta.

        Keyword arguments override values set via ``bind``; ``stream`` and
        ``messages`` are always forced. A chunk whose ``delta.content`` is falsy
        (e.g. ``None``) yields ``""``. Chunks without a readable first choice
        are skipped rather than aborting the stream.
        """
        messages = ensure(messages)
        config = self._run_config | config | {"stream": True, "messages": messages}
        stream = await self._aclient.chat.completions.create(**config)
        async for event in stream:
            try:
                piece = event.choices[0].delta.content or ""
            except (AttributeError, IndexError):
                # IndexError: some chunks have an empty ``choices`` list, e.g.
                # the usage chunk sent with ``stream_options={"include_usage":
                # True}`` or Azure's leading content-filter result chunk.
                continue
            # Yield outside the ``try`` so errors thrown into the generator by
            # the consumer are not swallowed.
            yield piece
class SyncTextOpenAI(ClientConfig, LLM):
    """``LLM`` over the Completions endpoint using the synchronous client."""

    # Reuse the standalone callables' implementations as bound methods.
    complete, generate = TextComplete.__call__, TextGenerate.__call__  # type: ignore
class AsyncTextOpenAI(AsyncClientConfig, LLM):
    """``LLM`` over the Completions endpoint using the asynchronous client."""

    # Reuse the standalone callables' implementations as bound methods.
    complete, generate = AsyncTextComplete.__call__, AsyncTextGenerate.__call__  # type: ignore
class SyncChatOpenAI(ClientConfig, LLM):
    """``LLM`` over the Chat Completions endpoint using the synchronous client."""

    # Reuse the standalone callables' implementations as bound methods.
    complete, generate = ChatComplete.__call__, ChatGenerate.__call__  # type: ignore
class AsyncChatOpenAI(AsyncClientConfig, LLM):
    """``LLM`` over the Chat Completions endpoint using the asynchronous client."""

    # Reuse the standalone callables' implementations as bound methods.
    complete, generate = AsyncChatComplete.__call__, AsyncChatGenerate.__call__  # type: ignore
__all__ = (
    # Standalone callables for the Completions endpoint
    "TextComplete", "AsyncTextComplete", "TextGenerate", "AsyncTextGenerate",
    # Standalone callables for the Chat Completions endpoint
    "ChatComplete", "AsyncChatComplete", "ChatGenerate", "AsyncChatGenerate",
    # LLM implementations pairing ``complete`` with ``generate``
    "SyncTextOpenAI", "AsyncTextOpenAI", "SyncChatOpenAI", "AsyncChatOpenAI",
)