Coverage for promplate/chain/node.py: 73%

319 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-02-07 15:14 +0800

1from inspect import isclass 

2from itertools import accumulate 

3from typing import TYPE_CHECKING, Callable, Mapping, MutableMapping, TypeVar, overload 

4 

5from ..llm.base import * 

6from ..prompt.template import Context, Loader, SafeChainMapContext, Template 

7from .callback import BaseCallback, Callback 

8from .utils import accumulate_any, resolve 

9 

# Type variable bound to ChainContext. The second ``ChainContext.__new__`` overload uses it
# so that passing an existing ChainContext (sub)class instance is typed as returning that type.
C = TypeVar("C", bound="ChainContext")

11 

12 

class ChainContext(SafeChainMapContext):
    """Layered context object threaded through chain nodes.

    Wraps ``least`` (a fresh ``{}`` when omitted) followed by any extra ``maps``. This
    presumably follows ChainMap layering semantics inherited from ``SafeChainMapContext``.
    The latest LLM output is exposed through the ``result`` property, which is stored
    under the ``"__result__"`` key.
    """

    @overload
    def __new__(cls): ...

    @overload
    def __new__(cls, least: C, *maps: Mapping) -> C: ...

    @overload
    def __new__(cls, least: MutableMapping | None = None, *maps: Mapping): ...

    def __new__(cls, *args, **kwargs):  # type: ignore
        # The first map may come positionally or as the ``least`` keyword.
        least = args[0] if args else kwargs.get("least")
        # Ordinary case: build an instance of ``cls`` itself.
        if not isinstance(least, cls) or least.__class__ is cls:
            return super().__new__(cls, *args, **kwargs)
        # ``least`` is an instance of a *subclass* of ``cls``: construct that subclass
        # instead, so specialised context types survive being re-wrapped.
        return least.__class__(*args, **kwargs)

    def __init__(self, least: MutableMapping | None = None, *maps: Mapping):
        first = least if least is not None else {}
        super().__init__(first, *maps)  # type: ignore

    @classmethod
    def ensure(cls, context):
        """Return ``context`` itself if it is already a ``cls`` instance, otherwise wrap it."""
        if isinstance(context, cls):
            return context
        return cls(context)

    @property
    def result(self):
        """The value stored under ``"__result__"``."""
        return self["__result__"]

    @result.setter
    def result(self, result):
        self["__result__"] = result

    @result.deleter
    def result(self):
        del self["__result__"]

    def __str__(self):
        # Flatten the visible key/value pairs into a plain dict for display.
        snapshot = {**self}
        return str(snapshot)

54 

55 

# Synchronous process hook: called with the chain context, may return a Context or None.
# (The ``add_*_processes`` helpers wrap these in ``Callback``; returned mappings are presumably
# what the ``_apply_*_processes`` helpers merge into the context — confirm in ``Callback``.)
Process = Callable[[ChainContext], Context | None]

# Asynchronous counterpart of ``Process``: returns an awaitable of ``Context | None``.
AsyncProcess = Callable[[ChainContext], Awaitable[Context | None]]

59 

60 

class AbstractNode(Protocol):
    """Structural interface shared by every runnable node.

    A node offers four entry points: synchronous and asynchronous, one-shot and
    streaming. Each takes an optional initial context (positional-only), an LLM
    callable and arbitrary run config. Nodes can be combined with ``+`` into a chain.
    """

    def invoke(
        self,
        context: Context | None = None,
        /,
        complete: Complete | None = None,
        **config,
    ) -> ChainContext:
        """Run the node once and return the resulting context."""
        ...

    async def ainvoke(
        self,
        context: Context | None = None,
        /,
        complete: Complete | AsyncComplete | None = None,
        **config,
    ) -> ChainContext:
        """Asynchronous version of ``invoke``; ``complete`` may be sync or async."""
        ...

    def stream(
        self,
        context: Context | None = None,
        /,
        generate: Generate | None = None,
        **config,
    ) -> Iterable[ChainContext]:
        """Run the node with a streaming ``generate`` callable, yielding the context as it updates."""
        ...

    def astream(
        self,
        context: Context | None = None,
        /,
        generate: Generate | AsyncGenerate | None = None,
        **config,
    ) -> AsyncIterable[ChainContext]:
        """Asynchronous version of ``stream``; ``generate`` may be sync or async."""
        ...

    @classmethod
    def _get_chain_type(cls):
        # The class used to build ``a + b``; ``Chain`` overrides this to return itself.
        return Chain

    def __add__(self, chain: "AbstractNode"):
        # A plain node on the right becomes a single element. A Chain on the right has
        # its nodes spliced in after ``self``.
        if not isinstance(chain, Chain):
            return self._get_chain_type()(self, chain)
        return self._get_chain_type()(self, *chain)

102 

103 

def ensure_callbacks(callbacks: list[BaseCallback | type[BaseCallback]]) -> list[BaseCallback]:
    """Return a new list in which every callback class has been instantiated.

    Entries for which ``inspect.isclass`` is true are called with no arguments.
    Every other entry is passed through unchanged. The original order is preserved.
    """
    instances = []
    for entry in callbacks:
        instances.append(entry() if isclass(entry) else entry)
    return instances

106 

107 

class Interruptible(AbstractNode, Protocol):
    """Shared runtime for nodes that support callbacks and ``Jump`` interruption.

    Subclasses implement the ``_invoke`` / ``_ainvoke`` / ``_stream`` / ``_astream``
    hooks. The public ``invoke`` / ``ainvoke`` / ``stream`` / ``astream`` defined here
    wrap those hooks. Each one:

    * runs the callbacks' ``on_enter`` via ``enter``,
    * calls the hook on a ``ChainContext`` that layers the caller's context over this
      node's partial ``context``,
    * runs ``on_leave`` via ``leave`` on normal completion or when a ``Jump`` is caught.

    Other exceptions propagate without ``leave`` being called.
    """

    # Hook: perform this node's work synchronously, mutating ``context`` in place.
    def _invoke(
        self,
        context: ChainContext,
        /,
        complete: Complete | None,
        callbacks: list[BaseCallback],
        **config,
    ): ...

    # Hook: async counterpart of ``_invoke``.
    async def _ainvoke(
        self,
        context: ChainContext,
        /,
        complete: Complete | AsyncComplete | None,
        callbacks: list[BaseCallback],
        **config,
    ): ...

    # Hook: streaming work. Each item it yields makes ``stream`` yield the context once.
    def _stream(
        self,
        context: ChainContext,
        /,
        generate: Generate | None,
        callbacks: list[BaseCallback],
        **config,
    ) -> Iterable: ...

    # Hook: async streaming counterpart of ``_stream``.
    def _astream(
        self,
        context: ChainContext,
        /,
        generate: Generate | AsyncGenerate | None,
        callbacks: list[BaseCallback],
        **config,
    ) -> AsyncIterable: ...

    # Registered callbacks: instances or classes. Classes are instantiated afresh on every
    # run by ``enter``. Subclasses assign the list in their ``__init__``.
    callbacks: list[BaseCallback | type[BaseCallback]]

    def enter(self, context: Context | None, config: Context):
        """Instantiate callbacks and let each ``on_enter`` rewrite ``context`` and ``config``.

        Callbacks are applied in registration order.

        Returns:
            A tuple ``(context, config, callbacks)``. ``callbacks`` holds the instantiated
            callbacks, so the same instances are reused by ``leave`` and the
            ``_apply_*`` helpers.
        """
        callbacks: list[BaseCallback] = ensure_callbacks(self.callbacks)
        for callback in callbacks:
            context, config = callback.on_enter(self, context, config)
        return context, config, callbacks

    def leave(self, context: ChainContext, config: Context, callbacks: list[BaseCallback]):
        """Let each callback's ``on_leave`` rewrite ``context``/``config``, in reverse order.

        Returns:
            A tuple ``(context, config)``.
        """
        # Reverse order so enter/leave nest like context managers.
        for callback in reversed(callbacks):
            context, config = callback.on_leave(self, context, config)
        return context, config

    def add_pre_processes(self, *processes: Process | AsyncProcess):
        """Register each process as a ``Callback`` pre-process; returns ``self`` for chaining."""
        self.callbacks.extend(Callback(pre_process=i) for i in processes)
        return self

    def add_mid_processes(self, *processes: Process | AsyncProcess):
        """Register each process as a ``Callback`` mid-process; returns ``self`` for chaining."""
        self.callbacks.extend(Callback(mid_process=i) for i in processes)
        return self

    def add_end_processes(self, *processes: Process | AsyncProcess):
        """Register each process as a ``Callback`` end-process; returns ``self`` for chaining."""
        self.callbacks.extend(Callback(end_process=i) for i in processes)
        return self

    def add_callbacks(self, *callbacks: BaseCallback | type[BaseCallback]):
        """Append callback instances or classes; returns ``self`` for chaining."""
        self.callbacks.extend(callbacks)
        return self

    def pre_process(self, process: Process | AsyncProcess):
        """Decorator form of ``add_pre_processes``; returns ``process`` unchanged."""
        self.add_pre_processes(process)
        return process

    def mid_process(self, process: Process | AsyncProcess):
        """Decorator form of ``add_mid_processes``; returns ``process`` unchanged."""
        self.add_mid_processes(process)
        return process

    def end_process(self, process: Process | AsyncProcess):
        """Decorator form of ``add_end_processes``; returns ``process`` unchanged."""
        self.add_end_processes(process)
        return process

    def callback(self, callback: BaseCallback | type[BaseCallback]):
        """Decorator form of ``add_callbacks``; returns ``callback`` unchanged."""
        self.add_callbacks(callback)
        return callback

    # The ``_apply_*`` helpers merge each hook's returned mapping into ``context`` with
    # ``|=``. A falsy return (e.g. None) merges nothing. The rebinding of the local name
    # is discarded, so these rely on ChainContext's ``|=`` mutating in place (as
    # ChainMap's does) — NOTE(review): confirm for SafeChainMapContext.

    @staticmethod
    def _apply_pre_processes(context: ChainContext, callbacks: list[BaseCallback]):
        # Registration order.
        for callback in callbacks:
            context |= cast(Context, callback.pre_process(context) or {})

    @staticmethod
    def _apply_mid_processes(context: ChainContext, callbacks: list[BaseCallback]):
        # Registration order.
        for callback in callbacks:
            context |= cast(Context, callback.mid_process(context) or {})

    @staticmethod
    def _apply_end_processes(context: ChainContext, callbacks: list[BaseCallback]):
        # Reverse registration order, mirroring ``leave``.
        for callback in reversed(callbacks):
            context |= cast(Context, callback.end_process(context) or {})

    # Async variants: ``resolve`` presumably awaits the hook's result when it is awaitable,
    # so hooks may be sync or async.

    @staticmethod
    async def _apply_async_pre_processes(context: ChainContext, callbacks: list[BaseCallback]):
        for callback in callbacks:
            context |= cast(Context, await resolve(callback.pre_process(context)) or {})

    @staticmethod
    async def _apply_async_mid_processes(context: ChainContext, callbacks: list[BaseCallback]):
        for callback in callbacks:
            context |= cast(Context, await resolve(callback.mid_process(context)) or {})

    @staticmethod
    async def _apply_async_end_processes(context: ChainContext, callbacks: list[BaseCallback]):
        for callback in reversed(callbacks):
            context |= cast(Context, await resolve(callback.end_process(context)) or {})

    def invoke(self, context=None, /, complete=None, **config) -> ChainContext:
        """Run this node synchronously and return the caller-level context.

        Handling of a ``Jump`` raised inside ``_invoke``:

        * ``leave`` callbacks always run first.
        * If ``jump.out_of`` is set and is not this node, the jump is re-raised for an
          outer node to handle.
        * Otherwise, if ``jump.into`` is set, that node is invoked on the same context.
          Its return value is not used; this method returns its own ``context``.
        """
        context, config, callbacks = self.enter(context, config)
        context = ChainContext.ensure(context)

        try:
            # Layer this node's partial context behind the caller's context for the hook.
            self._invoke(ChainContext(context, self.context), complete, callbacks, **config)
        except Jump as jump:
            context, config = self.leave(context, config, callbacks)
            if jump.out_of is not None and jump.out_of is not self:
                # Targeted at an enclosing node; ``from None`` hides the re-raise context.
                raise jump from None
            if jump.into is not None:
                jump.into.invoke(context, complete, **config)
        else:
            context, config = self.leave(context, config, callbacks)

        return context

    async def ainvoke(self, context=None, /, complete=None, **config) -> ChainContext:
        """Async counterpart of ``invoke``, with the same ``Jump`` handling."""
        context, config, callbacks = self.enter(context, config)
        context = ChainContext.ensure(context)

        try:
            await self._ainvoke(ChainContext(context, self.context), complete, callbacks, **config)
        except Jump as jump:
            context, config = self.leave(context, config, callbacks)
            if jump.out_of is not None and jump.out_of is not self:
                raise jump from None
            if jump.into is not None:
                await jump.into.ainvoke(context, complete, **config)
        else:
            context, config = self.leave(context, config, callbacks)

        return context

    def stream(self, context=None, /, generate=None, **config) -> Iterable[ChainContext]:
        """Yield the caller-level context once per item produced by ``_stream``.

        ``Jump`` handling matches ``invoke``, except that a ``jump.into`` target is
        streamed and its yielded values are passed through. ``leave`` does not run if
        the consumer stops iterating early, because ``GeneratorExit`` is not caught here.
        """
        context, config, callbacks = self.enter(context, config)
        context = ChainContext.ensure(context)

        try:
            for _ in self._stream(ChainContext(context, self.context), generate, callbacks, **config):
                yield context
        except Jump as jump:
            context, config = self.leave(context, config, callbacks)
            if jump.out_of is not None and jump.out_of is not self:
                raise jump from None
            if jump.into is not None:
                yield from jump.into.stream(context, generate, **config)
        else:
            context, config = self.leave(context, config, callbacks)

    async def astream(self, context=None, /, generate=None, **config) -> AsyncIterable[ChainContext]:
        """Async counterpart of ``stream``, with the same ``Jump`` handling."""
        context, config, callbacks = self.enter(context, config)
        context = ChainContext.ensure(context)

        try:
            async for _ in self._astream(ChainContext(context, self.context), generate, callbacks, **config):
                yield context
        except Jump as jump:
            context, config = self.leave(context, config, callbacks)
            if jump.out_of is not None and jump.out_of is not self:
                raise jump from None
            if jump.into is not None:
                # Async generators cannot use ``yield from``; forward items manually.
                async for i in jump.into.astream(context, generate, **config):
                    yield i
        else:
            context, config = self.leave(context, config, callbacks)

    # Backing store for ``context``. Subclasses set it from ``partial_context`` in ``__init__``.
    _context: Context | None

    @property
    def context(self):
        """This node's partial context, lazily created as an empty dict on first access."""
        if self._context is None:
            self._context = {}
        return self._context

    @context.setter
    def context(self, context: Context | None):
        self._context = context

    @context.deleter
    def context(self):
        # Reset to None; the next read recreates an empty dict.
        self._context = None

302 

303 

if not TYPE_CHECKING:
    # At runtime the historical misspelling is simply an alias of the real class.
    Interruptable = Interruptible

else:
    # Type checkers see a deprecated subclass, so uses of the old name get flagged.
    from typing_extensions import deprecated  # type: ignore

    @deprecated("Use `Interruptible` instead")
    class Interruptable(Interruptible, Protocol): ...

312 

313 

class Node(Loader, Interruptible):
    """A single LLM step: render a template, call the model, store the reply.

    The rendered prompt is passed to ``complete`` (or ``generate`` when streaming), and
    the model output is written to ``context.result``. If ``llm`` is set, its
    ``complete`` / ``generate`` are used instead of the callables passed per call.
    Keyword arguments given at construction (``run_config``) are merged with the
    per-call config, with per-call values winning.
    """

    def __init__(
        self,
        template: Template | str,
        partial_context: Context | None = None,
        llm: LLM | None = None,
        **config,
    ):
        # Accept either raw template text or a ready-made Template.
        if isinstance(template, str):
            template = Template(template)
        self.template = template
        self._context = partial_context
        self.callbacks: list[BaseCallback | type[BaseCallback]] = []
        self.llm = llm
        self.run_config = config

    def _invoke(self, context, /, complete, callbacks, **config):
        # A bound LLM takes precedence over the ``complete`` argument.
        if self.llm:
            complete = self.llm.complete
        complete = cast(Complete, complete)
        assert complete is not None

        prompt = self.render(context, callbacks)
        context.result = complete(prompt, **(self.run_config | config))

        self._apply_mid_processes(context, callbacks)
        self._apply_end_processes(context, callbacks)

    def _stream(self, context, /, generate, callbacks, **config):
        if self.llm:
            generate = self.llm.generate
        generate = cast(Generate, generate)
        assert generate is not None

        prompt = self.render(context, callbacks)
        chunks = generate(prompt, **(self.run_config | config))

        # ``accumulate`` turns chunks into running totals (running concatenations for str
        # chunks), so ``result`` always holds the output so far.
        for partial in accumulate(chunks):
            context.result = partial
            self._apply_mid_processes(context, callbacks)
            yield

        self._apply_end_processes(context, callbacks)

    async def _ainvoke(self, context, /, complete, callbacks, **config):
        if self.llm:
            complete = self.llm.complete
        complete = cast(Complete | AsyncComplete, complete)
        assert complete is not None

        prompt = await self.arender(context, callbacks)
        # ``resolve`` lets ``complete`` be either sync or async.
        pending = complete(prompt, **(self.run_config | config))
        context.result = await resolve(pending)

        await self._apply_async_mid_processes(context, callbacks)
        await self._apply_async_end_processes(context, callbacks)

    async def _astream(self, context, /, generate, callbacks, **config):
        if self.llm:
            generate = self.llm.generate
        generate = cast(Generate | AsyncGenerate, generate)
        assert generate is not None

        prompt = await self.arender(context, callbacks)
        chunks = generate(prompt, **(self.run_config | config))

        async for partial in accumulate_any(chunks):
            context.result = partial
            await self._apply_async_mid_processes(context, callbacks)
            yield

        await self._apply_async_end_processes(context, callbacks)

    def render(self, context: Context | None = None, callbacks: list[BaseCallback] | None = None):
        """Apply pre-processes to a layered context and render the template with it.

        When ``callbacks`` is None, this node's registered callbacks are instantiated.
        """
        active = ensure_callbacks(self.callbacks) if callbacks is None else callbacks
        layered = ChainContext(context, self.context)
        self._apply_pre_processes(layered, active)
        return self.template.render(layered)

    async def arender(self, context: Context | None = None, callbacks: list[BaseCallback] | None = None):
        """Async counterpart of ``render``."""
        active = ensure_callbacks(self.callbacks) if callbacks is None else callbacks
        layered = ChainContext(context, self.context)
        await self._apply_async_pre_processes(layered, active)
        return await self.template.arender(layered)

    def __str__(self):
        return "</{}/>".format(self.name)

394 

395 

class Loop(Interruptible):
    """Run ``chain`` over and over on the same context.

    There is no exit condition here: the loop only ends when an exception such as
    ``Jump`` propagates out of it. Each pass applies this loop's pre-processes, runs the
    wrapped node, then applies mid- and end-processes. When streaming, the
    mid-processes run after every inner step instead.
    """

    def __init__(self, chain: AbstractNode, partial_context: Context | None = None):
        self.callbacks: list[BaseCallback | type[BaseCallback]] = []
        self._context = partial_context
        self.chain = chain

    def _invoke(self, context, /, complete, callbacks, **config):
        while True:
            self._apply_pre_processes(context, callbacks)
            self.chain.invoke(context, complete, **config)
            self._apply_mid_processes(context, callbacks)
            self._apply_end_processes(context, callbacks)

    async def _ainvoke(self, context, /, complete, callbacks, **config):
        while True:
            await self._apply_async_pre_processes(context, callbacks)
            await self.chain.ainvoke(context, complete, **config)
            await self._apply_async_mid_processes(context, callbacks)
            await self._apply_async_end_processes(context, callbacks)

    def _stream(self, context, /, generate, callbacks, **config):
        while True:
            self._apply_pre_processes(context, callbacks)
            inner_steps = self.chain.stream(context, generate, **config)
            for _step in inner_steps:
                self._apply_mid_processes(context, callbacks)
                yield
            self._apply_end_processes(context, callbacks)

    async def _astream(self, context, /, generate, callbacks, **config):
        while True:
            await self._apply_async_pre_processes(context, callbacks)
            inner_steps = self.chain.astream(context, generate, **config)
            async for _step in inner_steps:
                await self._apply_async_mid_processes(context, callbacks)
                yield
            await self._apply_async_end_processes(context, callbacks)

431 

432 

class Chain(Interruptible):
    """Run a sequence of nodes one after another on a shared context.

    Pre-processes run once before the first node. Mid-processes run after every node,
    or after every streamed step when streaming. End-processes run once after the
    last node.
    """

    def __init__(self, *nodes: AbstractNode, partial_context: Context | None = None):
        self.callbacks: list[BaseCallback | type[BaseCallback]] = []
        self._context = partial_context
        self.nodes = [*nodes]

    @classmethod
    def _get_chain_type(cls):
        # ``a + b`` on a Chain (sub)class builds that same class.
        return cls

    def __iadd__(self, chain: AbstractNode):
        # ``+=`` appends the operand as one node; a Chain operand is not flattened.
        self.nodes.append(chain)
        return self

    def __iter__(self):
        return iter(self.nodes)

    def _invoke(self, context, /, complete, callbacks: list[BaseCallback], **config):
        self._apply_pre_processes(context, callbacks)
        for child in self.nodes:
            child.invoke(context, complete, **config)
            self._apply_mid_processes(context, callbacks)
        self._apply_end_processes(context, callbacks)

    async def _ainvoke(self, context, /, complete, callbacks: list[BaseCallback], **config):
        await self._apply_async_pre_processes(context, callbacks)
        for child in self.nodes:
            await child.ainvoke(context, complete, **config)
            await self._apply_async_mid_processes(context, callbacks)
        await self._apply_async_end_processes(context, callbacks)

    def _stream(self, context, /, generate, callbacks: list[BaseCallback], **config):
        self._apply_pre_processes(context, callbacks)
        for child in self.nodes:
            for _step in child.stream(context, generate, **config):
                self._apply_mid_processes(context, callbacks)
                yield
        self._apply_end_processes(context, callbacks)

    async def _astream(self, context, /, generate, callbacks: list[BaseCallback], **config):
        await self._apply_async_pre_processes(context, callbacks)
        for child in self.nodes:
            async for _step in child.astream(context, generate, **config):
                await self._apply_async_mid_processes(context, callbacks)
                yield
        await self._apply_async_end_processes(context, callbacks)

    def __repr__(self):
        return " + ".join(str(node) for node in self.nodes)

482 

483 

class Jump(Exception):
    """Control-flow exception used to leave (and optionally redirect) a running node.

    ``out_of`` names the node that should handle the jump. With ``None``, the first
    ``Interruptible`` whose entry point catches it handles it. Any other node that
    catches it re-raises it. ``into``, if given, is the node the handler runs next on
    the same context.
    """

    def __init__(self, into: Interruptible | None = None, out_of: Interruptible | None = None):
        self.into, self.out_of = into, out_of

    def __str__(self):
        # Shown when the jump escapes every node, i.e. ``out_of`` was never found.
        target = self.out_of
        return f"{target} does not exist in the chain hierarchy"