Coverage for promplate/chain/callback.py: 89%
36 statements
« prev ^ index » next coverage.py v7.6.2, created at 2024-10-09 22:54 +0800
1from typing import TYPE_CHECKING, Awaitable, Callable, Protocol
3from ..prompt import Context
5if TYPE_CHECKING: 5 ↛ 6line 5 didn't jump to line 6 because the condition on line 5 was never true
6 from .node import AsyncProcess, ChainContext, Interruptable, Process
class BaseCallback(Protocol):
    """Structural interface for chain callbacks.

    The three ``*_process`` hooks each take a ``ChainContext``. They are annotated
    to return a ``Context``, an awaitable resolving to ``Context | None``, or
    ``None``. The protocol's own stubs return ``None``.

    ``on_enter`` and ``on_leave`` take the node plus a ``(context, config)`` pair
    and return a pair of the same shape. By default they hand both values back
    unchanged.

    NOTE(review): when and how each hook is invoked is decided by the caller
    (presumably the node classes in ``.node``, which are not shown here).
    """

    def pre_process(self, context: "ChainContext") -> Context | Awaitable[Context | None] | None:
        """Stub for the pre-process hook; has no default behaviour (returns ``None``)."""

    def mid_process(self, context: "ChainContext") -> Context | Awaitable[Context | None] | None:
        """Stub for the mid-process hook; has no default behaviour (returns ``None``)."""

    def end_process(self, context: "ChainContext") -> Context | Awaitable[Context | None] | None:
        """Stub for the end-process hook; has no default behaviour (returns ``None``)."""

    def on_enter(self, node: "Interruptable", context: Context | None, config: Context) -> tuple[Context | None, Context]:
        """Default entry hook: return ``context`` and ``config`` untouched."""
        return (context, config)

    def on_leave(self, node: "Interruptable", context: "ChainContext", config: Context) -> tuple["ChainContext", Context]:
        """Default exit hook: return ``context`` and ``config`` untouched."""
        return (context, config)
class Callback(BaseCallback):
    """``BaseCallback`` implementation assembled from optional per-hook callables.

    Each keyword argument supplies the hook of the same name. If a hook is
    supplied, its result is returned as-is. For an ``AsyncProcess`` that result
    is presumably an awaitable for the caller to await; this is not visible here.

    If a hook is left as ``None``, the fallback applies:
    the ``*_process`` methods return ``None``, and ``on_enter``/``on_leave``
    return ``(context, config)`` unchanged.
    """

    def __init__(
        self,
        *,
        pre_process: "Process | AsyncProcess | None" = None,
        mid_process: "Process | AsyncProcess | None" = None,
        end_process: "Process | AsyncProcess | None" = None,
        on_enter: Callable[["Interruptable", Context | None, Context], tuple[Context | None, Context]] | None = None,
        on_leave: Callable[["Interruptable", "ChainContext", Context], tuple["ChainContext", Context]] | None = None,
    ):
        # Kept under underscore names: a plain ``self.pre_process = ...`` would
        # shadow the hook methods defined below on the instance.
        self._pre_process, self._mid_process, self._end_process = pre_process, mid_process, end_process
        self._on_enter, self._on_leave = on_enter, on_leave

    def pre_process(self, context):
        """Run the supplied pre-process callable on ``context``, or return ``None``."""
        hook = self._pre_process
        return None if hook is None else hook(context)

    def mid_process(self, context):
        """Run the supplied mid-process callable on ``context``, or return ``None``."""
        hook = self._mid_process
        return None if hook is None else hook(context)

    def end_process(self, context):
        """Run the supplied end-process callable on ``context``, or return ``None``."""
        hook = self._end_process
        return None if hook is None else hook(context)

    def on_enter(self, node, context, config):
        """Delegate to the supplied entry callable, else pass the pair through."""
        hook = self._on_enter
        if hook is None:
            return context, config
        return hook(node, context, config)

    def on_leave(self, node, context, config):
        """Delegate to the supplied exit callable, else pass the pair through."""
        hook = self._on_leave
        if hook is None:
            return context, config
        return hook(node, context, config)