Coverage for promplate/llm/openai/v1.py: 0%
113 statements
coverage.py v7.6.10, created at 2025-02-07 15:14 +0800
1from contextlib import suppress
2from copy import copy
3from functools import cached_property
4from types import MappingProxyType
5from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TypeVar
7from openai import AsyncClient, Client # type: ignore
9from ...prompt.chat import Message, ensure
10from ...prompt.utils import _get_aclient, _get_client, get_user_agent
11from ..base import *
# ParamSpec used by ``same_params_as`` (TYPE_CHECKING block) to copy the
# parameter list of ``Client`` / ``AsyncClient`` onto the config constructors.
P = ParamSpec("P")
# NOTE(review): ``T`` is not referenced anywhere in this module; kept in case
# other modules import it.
T = TypeVar("T")
class Config(Configurable):
    """Shared configuration base for the OpenAI v1 wrappers in this module.

    Constructor keyword arguments are handed to ``Configurable``. Per-call
    defaults live in ``_run_config`` and are layered on with :meth:`bind`.
    The OpenAI clients are built lazily and cached on first access.
    """

    def __init__(self, **config):
        super().__init__(**config)
        # Request defaults merged into every API call; extended via ``bind``.
        self._run_config = {}

    def bind(self, **run_config):
        """Return a shallow copy whose run-time defaults also include *run_config*.

        The receiver is left untouched: the copy gets a new merged dict in which
        keys from *run_config* take precedence over existing ones.
        """
        bound = copy(self)
        bound._run_config = {**self._run_config, **run_config}
        return bound

    @cached_property
    def _user_agent(self):
        """User-Agent value built by ``get_user_agent`` with an ``("OpenAI", VERSION)`` entry."""
        from openai.version import VERSION as openai_version

        return get_user_agent(self, ("OpenAI", openai_version))

    @property
    def _config(self):  # type: ignore
        """Read-only client options with the User-Agent header injected.

        Any ``default_headers`` already configured are kept; a ``User-Agent``
        key among them is overridden by :attr:`_user_agent`.
        """
        agent = self._user_agent
        options = dict(super()._config)
        existing_headers = options.get("default_headers", {})
        options["default_headers"] = existing_headers | {"User-Agent": agent}
        return MappingProxyType(options)

    @cached_property
    def _client(self):
        """Lazily build the sync ``Client``; uses ``_get_client()`` unless ``http_client`` was configured."""
        options = self._config
        if "http_client" not in options:
            return Client(**options, http_client=_get_client())
        return Client(**options)

    @cached_property
    def _aclient(self):
        """Lazily build the ``AsyncClient``; uses ``_get_aclient()`` unless ``http_client`` was configured."""
        options = self._config
        if "http_client" not in options:
            return AsyncClient(**options, http_client=_get_aclient())
        return AsyncClient(**options)
if TYPE_CHECKING:
    # Static-typing only (``TYPE_CHECKING`` is False at runtime): give the config
    # classes the constructor signature of ``Client`` / ``AsyncClient`` so type
    # checkers validate the keyword arguments passed to them.

    def same_params_as(_: Callable[P, Any]) -> Callable[[Callable[..., None]], Callable[P, None]]: ...

    class ClientConfig(Config):
        @same_params_as(Client)
        def __init__(self, **config): ...

    class AsyncClientConfig(Config):
        @same_params_as(AsyncClient)
        def __init__(self, **config): ...

else:
    # At runtime both names are plain aliases of ``Config``.
    ClientConfig = AsyncClientConfig = Config
class TextComplete(ClientConfig):
    """Non-streaming legacy text completion via ``client.completions.create``."""

    def __call__(self, text: str, /, **config):
        """Return the ``text`` of the first choice produced for prompt *text*.

        Keyword arguments override the bound ``_run_config``; ``stream`` is
        always forced to ``False`` and ``prompt`` to *text*.
        """
        forced = {"stream": False, "prompt": text}
        request = self._run_config | config | forced
        response = self._client.completions.create(**request)
        first_choice = response.choices[0]
        return first_choice.text
class AsyncTextComplete(AsyncClientConfig):
    """Async, non-streaming legacy text completion via ``aclient.completions.create``."""

    async def __call__(self, text: str, /, **config):
        """Await and return the ``text`` of the first choice for prompt *text*.

        Keyword arguments override the bound ``_run_config``; ``stream`` is
        always forced to ``False`` and ``prompt`` to *text*.
        """
        forced = {"stream": False, "prompt": text}
        request = self._run_config | config | forced
        response = await self._aclient.completions.create(**request)
        first_choice = response.choices[0]
        return first_choice.text
class TextGenerate(ClientConfig):
    """Streaming legacy text completion: yields text chunks as they arrive."""

    def __call__(self, text: str, /, **config):
        """Yield ``choices[0].text`` from each streamed event for prompt *text*.

        Keyword arguments override the bound ``_run_config``; ``stream`` is
        forced to ``True`` and ``prompt`` to *text*. Events with no first choice,
        or lacking the expected attributes, are skipped.
        """
        config = self._run_config | config | {"stream": True, "prompt": text}
        stream = self._client.completions.create(**config)
        for event in stream:
            # Guard only the attribute/index access. Keeping ``yield`` outside
            # the handler means exceptions thrown into this generator by the
            # consumer (``gen.throw``) propagate instead of being swallowed.
            try:
                chunk = event.choices[0].text
            except (AttributeError, IndexError):
                continue
            yield chunk
class AsyncTextGenerate(AsyncClientConfig):
    """Async streaming legacy text completion: yields text chunks as they arrive."""

    async def __call__(self, text: str, /, **config):
        """Asynchronously yield ``choices[0].text`` from each streamed event.

        Keyword arguments override the bound ``_run_config``; ``stream`` is
        forced to ``True`` and ``prompt`` to *text*. Events with no first choice,
        or lacking the expected attributes, are skipped.
        """
        config = self._run_config | config | {"stream": True, "prompt": text}
        stream = await self._aclient.completions.create(**config)
        async for event in stream:
            # Guard only the attribute/index access. Keeping ``yield`` outside
            # the handler means exceptions thrown in via ``agen.athrow`` propagate
            # instead of being silently suppressed.
            try:
                chunk = event.choices[0].text
            except (AttributeError, IndexError):
                continue
            yield chunk
class ChatComplete(ClientConfig):
    """Non-streaming chat completion via ``client.chat.completions.create``."""

    def __call__(self, messages: list[Message] | str, /, **config):
        """Return the first choice's ``message.content`` for *messages*.

        *messages* is normalized with ``ensure`` first. Keyword arguments
        override the bound ``_run_config``; ``stream`` is forced to ``False``.
        NOTE(review): ``message.content`` can presumably be ``None`` (e.g. for
        tool-call responses); it is returned unchanged.
        """
        chat = ensure(messages)
        request = self._run_config | config | {"stream": False, "messages": chat}
        response = self._client.chat.completions.create(**request)
        reply = response.choices[0].message
        return reply.content
class AsyncChatComplete(AsyncClientConfig):
    """Async, non-streaming chat completion via ``aclient.chat.completions.create``."""

    async def __call__(self, messages: list[Message] | str, /, **config):
        """Await and return the first choice's ``message.content`` for *messages*.

        *messages* is normalized with ``ensure`` first. Keyword arguments
        override the bound ``_run_config``; ``stream`` is forced to ``False``.
        NOTE(review): ``message.content`` can presumably be ``None`` (e.g. for
        tool-call responses); it is returned unchanged.
        """
        chat = ensure(messages)
        request = self._run_config | config | {"stream": False, "messages": chat}
        response = await self._aclient.chat.completions.create(**request)
        reply = response.choices[0].message
        return reply.content
class ChatGenerate(ClientConfig):
    """Streaming chat completion: yields content deltas as they arrive."""

    def __call__(self, messages: list[Message] | str, /, **config):
        """Yield ``choices[0].delta.content`` (``""`` when empty/None) per event.

        *messages* is normalized with ``ensure`` first. Keyword arguments
        override the bound ``_run_config``; ``stream`` is forced to ``True``.
        Events with no first choice, or lacking the expected attributes, are
        skipped.
        """
        messages = ensure(messages)
        config = self._run_config | config | {"stream": True, "messages": messages}
        stream = self._client.chat.completions.create(**config)
        for event in stream:
            # Guard only the attribute/index access. Keeping ``yield`` outside
            # the handler means exceptions thrown into this generator by the
            # consumer (``gen.throw``) propagate instead of being swallowed.
            try:
                chunk = event.choices[0].delta.content or ""
            except (AttributeError, IndexError):
                continue
            yield chunk
class AsyncChatGenerate(AsyncClientConfig):
    """Async streaming chat completion: yields content deltas as they arrive."""

    async def __call__(self, messages: list[Message] | str, /, **config):
        """Asynchronously yield ``choices[0].delta.content`` (``""`` when empty/None).

        *messages* is normalized with ``ensure`` first. Keyword arguments
        override the bound ``_run_config``; ``stream`` is forced to ``True``.
        Events with no first choice, or lacking the expected attributes, are
        skipped.
        """
        messages = ensure(messages)
        config = self._run_config | config | {"stream": True, "messages": messages}
        stream = await self._aclient.chat.completions.create(**config)
        async for event in stream:
            # Guard only the attribute/index access. Keeping ``yield`` outside
            # the handler means exceptions thrown in via ``agen.athrow`` propagate
            # instead of being silently suppressed.
            try:
                chunk = event.choices[0].delta.content or ""
            except (AttributeError, IndexError):
                continue
            yield chunk
class SyncTextOpenAI(ClientConfig, LLM):
    """Synchronous text-completion model combining ``ClientConfig`` with the ``LLM`` base."""

    # Reuse the exact ``__call__`` functions of the single-purpose classes as
    # methods; they only touch ``_run_config`` and ``_client``, which this class
    # also inherits from ``ClientConfig``.
    complete = TextComplete.__call__  # type: ignore
    generate = TextGenerate.__call__  # type: ignore
class AsyncTextOpenAI(AsyncClientConfig, LLM):
    """Async text-completion model combining ``AsyncClientConfig`` with the ``LLM`` base."""

    # Reuse the exact ``__call__`` functions (a coroutine function and an async
    # generator function) so their kind is preserved; they only touch
    # ``_run_config`` and ``_aclient``.
    complete = AsyncTextComplete.__call__  # type: ignore
    generate = AsyncTextGenerate.__call__  # type: ignore
class SyncChatOpenAI(ClientConfig, LLM):
    """Synchronous chat model combining ``ClientConfig`` with the ``LLM`` base."""

    # Reuse the exact ``__call__`` functions of the single-purpose chat classes;
    # they only touch ``_run_config`` and ``_client``.
    complete = ChatComplete.__call__  # type: ignore
    generate = ChatGenerate.__call__  # type: ignore
class AsyncChatOpenAI(AsyncClientConfig, LLM):
    """Async chat model combining ``AsyncClientConfig`` with the ``LLM`` base."""

    # Reuse the exact ``__call__`` functions (a coroutine function and an async
    # generator function) so their kind is preserved; they only touch
    # ``_run_config`` and ``_aclient``.
    complete = AsyncChatComplete.__call__  # type: ignore
    generate = AsyncChatGenerate.__call__  # type: ignore
# Public API of this module; also controls what ``from ... import *`` exports.
__all__ = (
    # Single-purpose callables (sync/async, text/chat, complete/stream).
    "TextComplete",
    "AsyncTextComplete",
    "TextGenerate",
    "AsyncTextGenerate",
    "ChatComplete",
    "AsyncChatComplete",
    "ChatGenerate",
    "AsyncChatGenerate",
    # Combined models exposing both ``complete`` and ``generate``.
    "SyncTextOpenAI",
    "AsyncTextOpenAI",
    "SyncChatOpenAI",
    "AsyncChatOpenAI",
)