Coverage for promplate/llm/openai/v0.py: 0%

75 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-08 04:34 +0800

1from importlib.metadata import metadata 

2from typing import TYPE_CHECKING, Any 

3 

4import openai 

5from openai import ChatCompletion, Completion # type: ignore 

6 

7from ...prompt.chat import Message, ensure 

8from ..base import * 

9 

# Installed distribution metadata for promplate; used to identify the library below.
meta = metadata("promplate")

# openai v0 leaves ``api_key`` as None when no key is configured. Defaulting it to ""
# presumably keeps requests (e.g. to backends that need no key, or ones given a
# per-call ``api_key``) from failing on a missing module-level key — confirm.
if openai.api_key is None:
    openai.api_key = ""

# Identify Promplate to the OpenAI SDK, merging with any app_info already set.
openai.app_info = (openai.app_info or {}) | {  # type: ignore
    "name": "Promplate",
    "version": meta["version"],
    # ``.get`` rather than ``[]``: indexing PackageMetadata with a missing key is
    # deprecated (and raises KeyError in newer importlib_metadata). The url stays
    # None when the distribution declares no Home-page, as before.
    "url": meta.get("home-page"),
}

20 

21 

if TYPE_CHECKING:

    class Config(Configurable):
        """Type-checking-only view of ``Configurable`` exposing common OpenAI options.

        This class is seen only by static type checkers; at runtime ``Config`` is
        simply ``Configurable`` (the ``else`` branch of this ``if``).
        """

        def __init__(
            self,
            model: str,
            temperature: float | int | None = None,
            top_p: float | int | None = None,
            stop: str | list[str] | None = None,
            max_tokens: int | None = None,
            api_key: str | None = None,
            api_base: str | None = None,
            **other_config: Any,
        ) -> None:
            # Explicit assignments let type checkers infer typed attributes
            # (``self.model: str`` and so on). This body never executes at runtime.
            self.model = model
            self.temperature = temperature
            self.top_p = top_p
            self.stop = stop
            self.max_tokens = max_tokens
            self.api_key = api_key
            self.api_base = api_base

            # Any additional keyword options also become attributes.
            for key, val in other_config.items():
                setattr(self, key, val)

        # Stubs so type checkers accept setting and reading arbitrary extra
        # config attributes beyond the ones declared above.
        def __setattr__(self, *_): ...

        def __getattr__(self, _): ...

else:
    # At runtime no typing scaffolding is needed: use the real base class.
    Config = Configurable

53 

54 

class TextComplete(Config, Complete):
    """Blocking text completion through the legacy ``openai.Completion`` endpoint."""

    def __call__(self, text: str, /, **config):
        """Send *text* as the prompt and return the text of the first choice.

        Keyword arguments override the instance configuration for this call only;
        ``stream`` is forced to ``False`` and ``prompt`` to *text*.
        """
        request = {**self._config, **config, "stream": False, "prompt": text}
        response: Any = Completion.create(**request)
        first_choice = response["choices"][0]
        return first_choice["text"]

60 

61 

class AsyncTextComplete(Config, AsyncComplete):
    """Async text completion through the legacy ``openai.Completion`` endpoint."""

    async def __call__(self, text: str, /, **config):
        """Await a completion for prompt *text* and return the first choice's text.

        Keyword arguments override the instance configuration for this call only;
        ``stream`` is forced to ``False`` and ``prompt`` to *text*.
        """
        request = {**self._config, **config, "stream": False, "prompt": text}
        response: Any = await Completion.acreate(**request)
        first_choice = response["choices"][0]
        return first_choice["text"]

67 

68 

class TextGenerate(Config, Generate):
    """Streaming text completion through the legacy ``openai.Completion`` endpoint."""

    def __call__(self, text: str, /, **config):
        """Yield the text of each streamed chunk for prompt *text*.

        Keyword arguments override the instance configuration for this call;
        ``stream`` is forced to ``True`` and ``prompt`` to *text*. Chunks that
        carry no choices are skipped instead of raising ``IndexError``.
        """
        config = self._config | config | {"stream": True, "prompt": text}
        stream: Any = Completion.create(**config)
        for event in stream:
            choices = event["choices"]
            # Not every chunk carries a choice: e.g. the trailing usage-only chunk
            # sent when ``stream_options={"include_usage": True}`` has no choices.
            if not choices:
                continue
            yield choices[0]["text"]

75 

76 

class AsyncTextGenerate(Config, AsyncGenerate):
    """Async streaming text completion through the legacy ``openai.Completion`` endpoint."""

    async def __call__(self, text: str, /, **config):
        """Asynchronously yield the text of each streamed chunk for prompt *text*.

        Keyword arguments override the instance configuration for this call;
        ``stream`` is forced to ``True`` and ``prompt`` to *text*. Chunks that
        carry no choices are skipped instead of raising ``IndexError``.
        """
        config = self._config | config | {"stream": True, "prompt": text}
        stream: Any = await Completion.acreate(**config)
        async for event in stream:
            choices = event["choices"]
            # Not every chunk carries a choice: e.g. the trailing usage-only chunk
            # sent when ``stream_options={"include_usage": True}`` has no choices.
            if not choices:
                continue
            yield choices[0]["text"]

83 

84 

class ChatComplete(Config, Complete):
    """Blocking chat completion through the legacy ``openai.ChatCompletion`` endpoint."""

    def __call__(self, messages: list[Message] | str, /, **config):
        """Send *messages* and return the content of the first choice's message.

        *messages* is normalised with ``ensure`` first. Keyword arguments override
        the instance configuration for this call; ``stream`` is forced to ``False``.
        """
        normalized = ensure(messages)
        request = {**self._config, **config, "stream": False, "messages": normalized}
        response: Any = ChatCompletion.create(**request)
        reply = response["choices"][0]["message"]
        return reply["content"]

91 

92 

class AsyncChatComplete(Config, AsyncComplete):
    """Async chat completion through the legacy ``openai.ChatCompletion`` endpoint."""

    async def __call__(self, messages: list[Message] | str, /, **config):
        """Await a chat completion for *messages* and return the first reply's content.

        *messages* is normalised with ``ensure`` first. Keyword arguments override
        the instance configuration for this call; ``stream`` is forced to ``False``.
        """
        normalized = ensure(messages)
        request = {**self._config, **config, "stream": False, "messages": normalized}
        response: Any = await ChatCompletion.acreate(**request)
        reply = response["choices"][0]["message"]
        return reply["content"]

99 

100 

class ChatGenerate(Config, Generate):
    """Streaming chat completion through the legacy ``openai.ChatCompletion`` endpoint."""

    def __call__(self, messages: list[Message] | str, /, **config):
        """Yield the text content of each streamed delta for *messages*.

        *messages* is normalised with ``ensure`` first. Keyword arguments override
        the instance configuration for this call; ``stream`` is forced to ``True``.
        Always yields ``str``: chunks without choices are skipped and a missing or
        null ``content`` becomes ``""``.
        """
        messages = ensure(messages)
        config = self._config | config | {"stream": True, "messages": messages}
        stream: Any = ChatCompletion.create(**config)
        for event in stream:
            choices = event["choices"]
            # Some chunks carry no choices, e.g. the trailing usage-only chunk sent
            # when ``stream_options={"include_usage": True}``.
            if not choices:
                continue
            delta: dict = choices[0]["delta"]
            # ``content`` may be present but None (function/tool-call chunks), which
            # ``.get("content", "")`` would pass through; normalise it to "".
            yield delta.get("content") or ""

109 

110 

class AsyncChatGenerate(Config, AsyncGenerate):
    """Async streaming chat completion through the legacy ``openai.ChatCompletion`` endpoint."""

    async def __call__(self, messages: list[Message] | str, /, **config):
        """Asynchronously yield the text content of each streamed delta for *messages*.

        *messages* is normalised with ``ensure`` first. Keyword arguments override
        the instance configuration for this call; ``stream`` is forced to ``True``.
        Always yields ``str``: chunks without choices are skipped and a missing or
        null ``content`` becomes ``""``.
        """
        messages = ensure(messages)
        config = self._config | config | {"stream": True, "messages": messages}
        stream: Any = await ChatCompletion.acreate(**config)
        async for event in stream:
            choices = event["choices"]
            # Some chunks carry no choices, e.g. the trailing usage-only chunk sent
            # when ``stream_options={"include_usage": True}``.
            if not choices:
                continue
            delta: dict = choices[0]["delta"]
            # ``content`` may be present but None (function/tool-call chunks), which
            # ``.get("content", "")`` would pass through; normalise it to "".
            yield delta.get("content") or ""