Coverage for promplate/chain/node.py: 74%
314 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-18 19:52 +0800
1from inspect import isclass
2from itertools import accumulate
3from typing import Callable, Mapping, MutableMapping, TypeVar, overload
5from ..llm.base import *
6from ..prompt.template import Context, Loader, SafeChainMapContext, Template
7from .callback import BaseCallback, Callback
8from .utils import accumulate_any, resolve
# Type variable bound to ChainContext (or a subclass), used by the
# ChainContext.__new__ overloads so that wrapping an instance of a subclass
# is typed as returning that same subclass.
C = TypeVar("C", bound="ChainContext")
class ChainContext(SafeChainMapContext):
    """Layered context object threaded through every node of a chain.

    It is constructed as ``ChainContext(least, *maps)``: ``least`` becomes the
    first mapping (an empty dict when omitted or ``None``) and ``maps`` follow
    it. NOTE(review): lookup/write semantics come from ``SafeChainMapContext``
    (not visible here); presumably ChainMap-like, i.e. reads fall through the
    maps in order and writes land in the first one — confirm.
    """

    @overload
    def __new__(cls): ...

    @overload
    def __new__(cls, least: C, *maps: Mapping) -> C: ...

    @overload
    def __new__(cls, least: MutableMapping | None = None, *maps: Mapping): ...

    def __init__(self, least: MutableMapping | None = None, *maps: Mapping):
        # A missing/None ``least`` is replaced by a fresh dict so that the
        # first mapping always exists.
        super().__init__({} if least is None else least, *maps)  # type: ignore

    def __new__(cls, *args, **kwargs):  # type: ignore
        # ``least`` may be supplied positionally or as a keyword argument.
        try:
            least = args[0]
        except IndexError:
            least = kwargs.get("least")
        # When wrapping an instance of a *subclass* of ``cls``, build that
        # subclass instead, so custom context types survive being re-wrapped.
        if isinstance(least, cls) and least.__class__ is not cls:
            return least.__class__(*args, **kwargs)

        return super().__new__(cls, *args, **kwargs)

    @classmethod
    def ensure(cls, context):
        """Return ``context`` unchanged if it is already a ``cls``, else wrap it."""
        return context if isinstance(context, cls) else cls(context)

    @property
    def result(self):
        """The latest node output, stored under the ``"__result__"`` key."""
        return self.__getitem__("__result__")

    @result.setter
    def result(self, result):
        self.__setitem__("__result__", result)

    @result.deleter
    def result(self):
        self.__delitem__("__result__")

    def __str__(self):
        # Flatten all layers into one plain dict for display.
        return str({**self})
# A process hook: called with the live chain context; it may return a mapping
# that the callback machinery merges into the context, or None for no changes.
Process = Callable[[ChainContext], Context | None]

# Async counterpart of ``Process``: returns an awaitable of the same result.
AsyncProcess = Callable[[ChainContext], Awaitable[Context | None]]
class AbstractNode(Protocol):
    """Structural interface shared by every runnable unit of a chain.

    Implementations can be run synchronously (``invoke``), asynchronously
    (``ainvoke``), or incrementally (``stream`` / ``astream``). They compose
    with ``+`` into a ``Chain``.
    """

    def invoke(
        self,
        context: Context | None = None,
        /,
        complete: Complete | None = None,
        **config,
    ) -> ChainContext: ...

    async def ainvoke(
        self,
        context: Context | None = None,
        /,
        complete: Complete | AsyncComplete | None = None,
        **config,
    ) -> ChainContext: ...

    def stream(
        self,
        context: Context | None = None,
        /,
        generate: Generate | None = None,
        **config,
    ) -> Iterable[ChainContext]: ...

    def astream(
        self,
        context: Context | None = None,
        /,
        generate: Generate | AsyncGenerate | None = None,
        **config,
    ) -> AsyncIterable[ChainContext]: ...

    @classmethod
    def _get_chain_type(cls):
        # Chain class produced by ``+``; ``Chain`` overrides this to return
        # its own (sub)class.
        return Chain

    def __add__(self, chain: "AbstractNode"):
        # A right-hand Chain is spliced in node by node rather than nested.
        if isinstance(chain, Chain):
            return self._get_chain_type()(self, *chain)
        return self._get_chain_type()(self, chain)
def ensure_callbacks(callbacks: list[BaseCallback | type[BaseCallback]]) -> list[BaseCallback]:
    """Return a list of callback instances.

    Entries that are classes are instantiated with no arguments; entries that
    are already instances are passed through unchanged (and thus shared).
    """
    return [i() if isclass(i) else i for i in callbacks]
class Interruptable(AbstractNode, Protocol):
    """Callback-aware node whose execution can be redirected by ``Jump``.

    Concrete subclasses implement the private hooks ``_invoke``, ``_ainvoke``,
    ``_stream`` and ``_astream``. The public ``invoke``, ``ainvoke``,
    ``stream`` and ``astream`` wrap those hooks with:

    * callback ``on_enter`` / ``on_leave`` handling (see ``enter`` / ``leave``);
    * a fresh ``ChainContext(context, self.context)`` that places this node's
      partial context after the caller's context;
    * ``Jump`` routing: a jump is handled here when its ``out_of`` is ``None``
      or this node, otherwise it is re-raised to the enclosing node.
    """

    def _invoke(
        self,
        context: ChainContext,
        /,
        complete: Complete | None,
        callbacks: list[BaseCallback],
        **config,
    ): ...

    async def _ainvoke(
        self,
        context: ChainContext,
        /,
        complete: Complete | AsyncComplete | None,
        callbacks: list[BaseCallback],
        **config,
    ): ...

    def _stream(
        self,
        context: ChainContext,
        /,
        generate: Generate | None,
        callbacks: list[BaseCallback],
        **config,
    ) -> Iterable: ...

    def _astream(
        self,
        context: ChainContext,
        /,
        generate: Generate | AsyncGenerate | None,
        callbacks: list[BaseCallback],
        **config,
    ) -> AsyncIterable: ...

    # Registered callbacks. Classes are instantiated anew on every run (via
    # ``ensure_callbacks`` in ``enter``); instances are reused as-is.
    callbacks: list[BaseCallback | type[BaseCallback]]

    def enter(self, context: Context | None, config: Context):
        """Instantiate callbacks and let each ``on_enter`` rewrite the inputs.

        Returns ``(context, config, callbacks)``. ``callbacks`` is the
        instantiated list that the rest of this run must reuse.
        """
        callbacks: list[BaseCallback] = ensure_callbacks(self.callbacks)
        for callback in callbacks:
            context, config = callback.on_enter(self, context, config)
        return context, config, callbacks

    def leave(self, context: ChainContext, config: Context, callbacks: list[BaseCallback]):
        """Run ``on_leave`` hooks in reverse order and return ``(context, config)``."""
        for callback in reversed(callbacks):
            context, config = callback.on_leave(self, context, config)
        return context, config

    def add_pre_processes(self, *processes: Process | AsyncProcess):
        """Register hooks run before the node's main work; returns ``self``."""
        self.callbacks.extend(Callback(pre_process=i) for i in processes)
        return self

    def add_mid_processes(self, *processes: Process | AsyncProcess):
        """Register hooks run after each step/chunk; returns ``self``."""
        self.callbacks.extend(Callback(mid_process=i) for i in processes)
        return self

    def add_end_processes(self, *processes: Process | AsyncProcess):
        """Register hooks run once the node's main work is done; returns ``self``."""
        self.callbacks.extend(Callback(end_process=i) for i in processes)
        return self

    def add_callbacks(self, *callbacks: BaseCallback | type[BaseCallback]):
        """Register callback instances or classes; returns ``self``."""
        self.callbacks.extend(callbacks)
        return self

    # Decorator forms of the registration helpers above: each registers the
    # decorated object and returns it unchanged.

    def pre_process(self, process: Process | AsyncProcess):
        self.add_pre_processes(process)
        return process

    def mid_process(self, process: Process | AsyncProcess):
        self.add_mid_processes(process)
        return process

    def end_process(self, process: Process | AsyncProcess):
        self.add_end_processes(process)
        return process

    def callback(self, callback: BaseCallback | type[BaseCallback]):
        self.add_callbacks(callback)
        return callback

    # The _apply_* helpers merge each callback's returned mapping (``None``
    # merges nothing) into ``context`` in place via ``|=``. End processes run
    # in reverse registration order, mirroring ``leave``.

    @staticmethod
    def _apply_pre_processes(context: ChainContext, callbacks: list[BaseCallback]):
        for callback in callbacks:
            context |= cast(Context, callback.pre_process(context) or {})

    @staticmethod
    def _apply_mid_processes(context: ChainContext, callbacks: list[BaseCallback]):
        for callback in callbacks:
            context |= cast(Context, callback.mid_process(context) or {})

    @staticmethod
    def _apply_end_processes(context: ChainContext, callbacks: list[BaseCallback]):
        for callback in reversed(callbacks):
            context |= cast(Context, callback.end_process(context) or {})

    # Async variants: hook results go through ``resolve`` so both plain and
    # awaitable return values are supported.

    @staticmethod
    async def _apply_async_pre_processes(context: ChainContext, callbacks: list[BaseCallback]):
        for callback in callbacks:
            context |= cast(Context, await resolve(callback.pre_process(context)) or {})

    @staticmethod
    async def _apply_async_mid_processes(context: ChainContext, callbacks: list[BaseCallback]):
        for callback in callbacks:
            context |= cast(Context, await resolve(callback.mid_process(context)) or {})

    @staticmethod
    async def _apply_async_end_processes(context: ChainContext, callbacks: list[BaseCallback]):
        for callback in reversed(callbacks):
            context |= cast(Context, await resolve(callback.end_process(context)) or {})

    def invoke(self, context=None, /, complete=None, **config) -> ChainContext:
        """Run this node synchronously and return the resulting context.

        If a ``Jump`` addressed to this node (or to no node) is raised, the
        ``on_leave`` hooks still run. Then, if the jump names an ``into``
        node, that node is invoked on the same context.
        """
        context, config, callbacks = self.enter(context, config)
        context = ChainContext.ensure(context)

        try:
            self._invoke(ChainContext(context, self.context), complete, callbacks, **config)
        except Jump as jump:
            context, config = self.leave(context, config, callbacks)
            # Not ours to handle: keep unwinding to the enclosing node.
            if jump.out_of is not None and jump.out_of is not self:
                raise jump from None
            if jump.into is not None:
                jump.into.invoke(context, complete, **config)
        else:
            context, config = self.leave(context, config, callbacks)

        return context

    async def ainvoke(self, context=None, /, complete=None, **config) -> ChainContext:
        """Async counterpart of ``invoke`` with identical ``Jump`` handling."""
        context, config, callbacks = self.enter(context, config)
        context = ChainContext.ensure(context)

        try:
            await self._ainvoke(ChainContext(context, self.context), complete, callbacks, **config)
        except Jump as jump:
            context, config = self.leave(context, config, callbacks)
            if jump.out_of is not None and jump.out_of is not self:
                raise jump from None
            if jump.into is not None:
                await jump.into.ainvoke(context, complete, **config)
        else:
            context, config = self.leave(context, config, callbacks)

        return context

    def stream(self, context=None, /, generate=None, **config) -> Iterable[ChainContext]:
        """Yield the context once per step of ``_stream``.

        On a handled ``Jump`` whose ``into`` is set, the jump target's
        stream is yielded from.
        """
        context, config, callbacks = self.enter(context, config)
        context = ChainContext.ensure(context)

        try:
            for _ in self._stream(ChainContext(context, self.context), generate, callbacks, **config):
                yield context
        except Jump as jump:
            context, config = self.leave(context, config, callbacks)
            if jump.out_of is not None and jump.out_of is not self:
                raise jump from None
            if jump.into is not None:
                yield from jump.into.stream(context, generate, **config)
        else:
            context, config = self.leave(context, config, callbacks)

    async def astream(self, context=None, /, generate=None, **config) -> AsyncIterable[ChainContext]:
        """Async counterpart of ``stream`` with identical ``Jump`` handling."""
        context, config, callbacks = self.enter(context, config)
        context = ChainContext.ensure(context)

        try:
            async for _ in self._astream(ChainContext(context, self.context), generate, callbacks, **config):
                yield context
        except Jump as jump:
            context, config = self.leave(context, config, callbacks)
            if jump.out_of is not None and jump.out_of is not self:
                raise jump from None
            if jump.into is not None:
                # ``yield from`` is not allowed in async generators.
                async for i in jump.into.astream(context, generate, **config):
                    yield i
        else:
            context, config = self.leave(context, config, callbacks)

    # Backing store for ``context``; ``None`` means "not set yet".
    _context: Context | None

    @property
    def context(self):
        """This node's partial context, lazily created as an empty dict."""
        if self._context is None:
            self._context = {}
        return self._context

    @context.setter
    def context(self, context: Context | None):
        self._context = context

    @context.deleter
    def context(self):
        self._context = None
class Node(Loader, Interruptable):
    """A single templated LLM call.

    The node renders ``template`` against the context (after running the
    pre-processes) and sends the prompt to the completion/generation
    function. The output is stored in ``context.result``.
    """

    def __init__(
        self,
        template: Template | str,
        partial_context: Context | None = None,
        llm: LLM | None = None,
        **config,
    ):
        self.template = Template(template) if isinstance(template, str) else template
        self._context = partial_context
        self.callbacks: list[BaseCallback | type[BaseCallback]] = []
        self.llm = llm
        # Default keyword arguments for every LLM call. Per-call ``config``
        # overrides them because the merge is ``self.run_config | config``.
        self.run_config = config

    def _invoke(self, context, /, complete, callbacks, **config):
        # A bound ``llm`` takes precedence over the ``complete`` passed in.
        complete = cast(Complete, self.llm.complete if self.llm else complete)
        assert complete is not None, "no `complete` given and no `llm` bound to this node"

        prompt = self.render(context, callbacks)

        context.result = complete(prompt, **(self.run_config | config))

        self._apply_mid_processes(context, callbacks)

        self._apply_end_processes(context, callbacks)

    def _stream(self, context, /, generate, callbacks, **config):
        generate = cast(Generate, self.llm.generate if self.llm else generate)
        assert generate is not None, "no `generate` given and no `llm` bound to this node"

        prompt = self.render(context, callbacks)

        # ``accumulate`` turns the chunk stream into running totals, so
        # ``context.result`` always holds the full output received so far.
        for result in accumulate(generate(prompt, **(self.run_config | config))):
            context.result = result
            self._apply_mid_processes(context, callbacks)
            yield

        self._apply_end_processes(context, callbacks)

    async def _ainvoke(self, context, /, complete, callbacks, **config):
        complete = cast(Complete | AsyncComplete, self.llm.complete if self.llm else complete)
        assert complete is not None, "no `complete` given and no `llm` bound to this node"

        prompt = await self.arender(context, callbacks)

        # ``resolve`` accepts both sync and async completion functions.
        context.result = await resolve(complete(prompt, **(self.run_config | config)))

        await self._apply_async_mid_processes(context, callbacks)

        await self._apply_async_end_processes(context, callbacks)

    async def _astream(self, context, /, generate, callbacks, **config):
        generate = cast(Generate | AsyncGenerate, self.llm.generate if self.llm else generate)
        assert generate is not None, "no `generate` given and no `llm` bound to this node"

        prompt = await self.arender(context, callbacks)

        async for result in accumulate_any(generate(prompt, **(self.run_config | config))):
            context.result = result
            await self._apply_async_mid_processes(context, callbacks)
            yield

        await self._apply_async_end_processes(context, callbacks)

    def render(self, context: Context | None = None, callbacks: list[BaseCallback] | None = None):
        """Run pre-processes on a layered context and render the template.

        When ``callbacks`` is omitted, fresh instances of ``self.callbacks``
        are used.
        """
        if callbacks is None:
            callbacks = ensure_callbacks(self.callbacks)
        context = ChainContext(context, self.context)
        self._apply_pre_processes(context, callbacks)
        return self.template.render(context)

    async def arender(self, context: Context | None = None, callbacks: list[BaseCallback] | None = None):
        """Async counterpart of ``render``."""
        if callbacks is None:
            callbacks = ensure_callbacks(self.callbacks)
        context = ChainContext(context, self.context)
        await self._apply_async_pre_processes(context, callbacks)
        return await self.template.arender(context)

    def __str__(self):
        # NOTE(review): ``name`` is presumably provided by ``Loader`` — confirm.
        return f"</{self.name}/>"
class Loop(Interruptable):
    """Run ``chain`` over and over on the same context.

    There is no built-in exit condition. The loop only ends when an exception
    escapes it; a ``Jump`` is the intended way out. Each iteration runs the
    pre-processes, the wrapped chain, then the mid- and end-processes. When
    streaming, the mid-processes run after every chunk.
    """

    def __init__(self, chain: AbstractNode, partial_context: Context | None = None):
        self.chain = chain
        self._context = partial_context
        self.callbacks: list[BaseCallback | type[BaseCallback]] = []

    def _invoke(self, context, /, complete, callbacks, **config):
        while True:
            self._apply_pre_processes(context, callbacks)
            self.chain.invoke(context, complete, **config)
            self._apply_mid_processes(context, callbacks)
            self._apply_end_processes(context, callbacks)

    async def _ainvoke(self, context, /, complete, callbacks, **config):
        while True:
            await self._apply_async_pre_processes(context, callbacks)
            await self.chain.ainvoke(context, complete, **config)
            await self._apply_async_mid_processes(context, callbacks)
            await self._apply_async_end_processes(context, callbacks)

    def _stream(self, context, /, generate, callbacks, **config):
        while True:
            self._apply_pre_processes(context, callbacks)
            for _ in self.chain.stream(context, generate, **config):
                self._apply_mid_processes(context, callbacks)
                yield
            self._apply_end_processes(context, callbacks)

    async def _astream(self, context, /, generate, callbacks, **config):
        while True:
            await self._apply_async_pre_processes(context, callbacks)
            async for _ in self.chain.astream(context, generate, **config):
                await self._apply_async_mid_processes(context, callbacks)
                yield
            await self._apply_async_end_processes(context, callbacks)
class Chain(Interruptable):
    """Run ``nodes`` in order on one shared context.

    The pre-processes run once before the first node. The mid-processes run
    after each node, or after each streamed chunk. The end-processes run once
    after the last node.
    """

    def __init__(self, *nodes: AbstractNode, partial_context: Context | None = None):
        self.nodes = list(nodes)
        self._context = partial_context
        self.callbacks: list[BaseCallback | type[BaseCallback]] = []

    @classmethod
    def _get_chain_type(cls):
        # ``a + b`` on a Chain subclass produces the same subclass.
        return cls

    def __iadd__(self, chain: AbstractNode):
        # In-place append; unlike ``+``, a Chain operand is added as one node.
        self.nodes.append(chain)
        return self

    def __iter__(self):
        return iter(self.nodes)

    def _invoke(self, context, /, complete, callbacks: list[BaseCallback], **config):
        self._apply_pre_processes(context, callbacks)
        for node in self.nodes:
            node.invoke(context, complete, **config)
            self._apply_mid_processes(context, callbacks)
        self._apply_end_processes(context, callbacks)

    async def _ainvoke(self, context, /, complete, callbacks: list[BaseCallback], **config):
        await self._apply_async_pre_processes(context, callbacks)
        for node in self.nodes:
            await node.ainvoke(context, complete, **config)
            await self._apply_async_mid_processes(context, callbacks)
        await self._apply_async_end_processes(context, callbacks)

    def _stream(self, context, /, generate, callbacks: list[BaseCallback], **config):
        self._apply_pre_processes(context, callbacks)
        for node in self.nodes:
            for _ in node.stream(context, generate, **config):
                self._apply_mid_processes(context, callbacks)
                yield
        self._apply_end_processes(context, callbacks)

    async def _astream(self, context, /, generate, callbacks: list[BaseCallback], **config):
        await self._apply_async_pre_processes(context, callbacks)
        for node in self.nodes:
            async for _ in node.astream(context, generate, **config):
                await self._apply_async_mid_processes(context, callbacks)
                yield
        await self._apply_async_end_processes(context, callbacks)

    def __repr__(self):
        # Mirrors the ``a + b + c`` expression used to build the chain.
        return " + ".join(map(str, self.nodes))
class Jump(Exception):
    """Control-flow exception that exits and/or redirects a running chain.

    ``out_of``: the node whose run should be left. ``None`` means the
    innermost enclosing ``Interruptable`` handles it. Any other node re-raises
    it until the matching node is reached.
    ``into``: an optional node to run next on the same context once the jump
    has been handled.
    """

    def __init__(self, into: Interruptable | None = None, out_of: Interruptable | None = None):
        self.into = into
        self.out_of = out_of

    def __str__(self):
        # Shown when the jump escapes every node, i.e. no enclosing node
        # matched ``out_of``.
        return f"{self.out_of} does not exist in the chain hierarchy"