Coverage for promplate/llm/openai/v0.py: 0%
75 statements
« prev ^ index » next coverage.py v7.6.2, created at 2024-10-09 22:54 +0800
1from importlib.metadata import metadata
2from typing import TYPE_CHECKING, Any
4import openai
5from openai import ChatCompletion, Completion # type: ignore
7from ...prompt.chat import Message, ensure
8from ..base import *
# Package metadata, used below to identify Promplate to the OpenAI service.
meta = metadata("promplate")

# Normalize an unset api_key to an empty string — presumably so the SDK does
# not reject missing-key configurations before a per-call key is supplied;
# TODO confirm against the legacy openai v0 behavior.
if openai.api_key is None:
    openai.api_key = ""

# Advertise this integration in outgoing requests, merging our identity over
# any app info the host application already registered.
openai.app_info = (openai.app_info or {}) | {  # type: ignore
    "name": "Promplate",
    "version": meta["version"],
    "url": meta["home-page"],
}
if TYPE_CHECKING:
    # Type-checking-only facade: spells out the keyword arguments commonly
    # passed to OpenAI calls so editors and type checkers can surface them.
    # At runtime the `else` branch makes `Config` a plain alias of
    # `Configurable`, so none of this class body ever executes.

    class Config(Configurable):
        """Static-typing view of `Configurable` listing common OpenAI options."""

        def __init__(
            self,
            model: str,
            temperature: float | int | None = None,
            top_p: float | int | None = None,
            stop: str | list[str] | None = None,
            max_tokens: int | None = None,
            api_key: str | None = None,
            api_base: str | None = None,
            **other_config,
        ):
            self.model = model
            self.temperature = temperature
            self.top_p = top_p
            self.stop = stop
            self.max_tokens = max_tokens
            self.api_key = api_key
            self.api_base = api_base

            # Any extra keyword arguments become attributes as well.
            for key, val in other_config.items():
                setattr(self, key, val)

        # Permissive no-op stubs so the checker accepts reads/writes of
        # arbitrary config attributes on instances.
        def __setattr__(self, *_): ...

        def __getattr__(self, _): ...

else:
    Config = Configurable
class TextComplete(Config, Complete):
    """Blocking text completion via the legacy `Completion.create` endpoint."""

    def __call__(self, text: str, /, **config):
        # Per-call config overrides stored config; stream/prompt always win.
        merged = {**self._config, **config, "stream": False, "prompt": text}
        response: Any = Completion.create(**merged)
        return response["choices"][0]["text"]
class AsyncTextComplete(Config, AsyncComplete):
    """Async counterpart of `TextComplete`, using `Completion.acreate`."""

    async def __call__(self, text: str, /, **config):
        # Per-call config overrides stored config; stream/prompt always win.
        merged = {**self._config, **config, "stream": False, "prompt": text}
        response: Any = await Completion.acreate(**merged)
        return response["choices"][0]["text"]
class TextGenerate(Config, Generate):
    """Streaming text completion; yields each incremental text chunk."""

    def __call__(self, text: str, /, **config):
        merged = {**self._config, **config, "stream": True, "prompt": text}
        for chunk in Completion.create(**merged):
            yield chunk["choices"][0]["text"]
class AsyncTextGenerate(Config, AsyncGenerate):
    """Async streaming text completion; yields each incremental text chunk."""

    async def __call__(self, text: str, /, **config):
        merged = {**self._config, **config, "stream": True, "prompt": text}
        async for chunk in await Completion.acreate(**merged):
            yield chunk["choices"][0]["text"]
class ChatComplete(Config, Complete):
    """Blocking chat completion; accepts a message list or a bare string."""

    def __call__(self, messages: list[Message] | str, /, **config):
        # `ensure` normalizes a plain string into a proper message list.
        merged = {**self._config, **config, "stream": False, "messages": ensure(messages)}
        response: Any = ChatCompletion.create(**merged)
        return response["choices"][0]["message"]["content"]
class AsyncChatComplete(Config, AsyncComplete):
    """Async chat completion; accepts a message list or a bare string."""

    async def __call__(self, messages: list[Message] | str, /, **config):
        # `ensure` normalizes a plain string into a proper message list.
        merged = {**self._config, **config, "stream": False, "messages": ensure(messages)}
        response: Any = await ChatCompletion.acreate(**merged)
        return response["choices"][0]["message"]["content"]
class ChatGenerate(Config, Generate):
    """Streaming chat completion; yields each delta's content (or "")."""

    def __call__(self, messages: list[Message] | str, /, **config):
        merged = {**self._config, **config, "stream": True, "messages": ensure(messages)}
        for chunk in ChatCompletion.create(**merged):
            piece: dict = chunk["choices"][0]["delta"]
            # Role-only / terminal deltas carry no "content" key.
            yield piece.get("content", "")
class AsyncChatGenerate(Config, AsyncGenerate):
    """Async streaming chat completion; yields each delta's content (or "")."""

    async def __call__(self, messages: list[Message] | str, /, **config):
        merged = {**self._config, **config, "stream": True, "messages": ensure(messages)}
        async for chunk in await ChatCompletion.acreate(**merged):
            piece: dict = chunk["choices"][0]["delta"]
            # Role-only / terminal deltas carry no "content" key.
            yield piece.get("content", "")