Coverage for promplate/chain/node.py: 74%

314 statements  

« prev     ^ index     » next       coverage.py v7.6.2, created at 2024-10-09 22:54 +0800

1from inspect import isclass 

2from itertools import accumulate 

3from typing import Callable, Mapping, MutableMapping, TypeVar, overload 

4 

5from ..llm.base import * 

6from ..prompt.template import Context, Loader, SafeChainMapContext, Template 

7from .callback import BaseCallback, Callback 

8from .utils import accumulate_any, resolve 

9 

C = TypeVar("C", bound="ChainContext")  # ChainContext or any subclass of it


class ChainContext(SafeChainMapContext):
    """Layered context object handed from node to node in a chain.

    Adds to ``SafeChainMapContext`` a ``result`` property backed by the
    ``"__result__"`` key, a type-preserving constructor, and a plain-dict
    ``__str__``.
    """

    @overload
    def __new__(cls): ...

    @overload
    def __new__(cls, least: C, *maps: Mapping) -> C: ...

    @overload
    def __new__(cls, least: MutableMapping | None = None, *maps: Mapping): ...

    def __new__(cls, *args, **kwargs):  # type: ignore
        """Create a context while preserving the concrete type of ``least``.

        ``least`` is the first positional argument, or the ``least`` keyword.
        When it is an instance of a strict subclass of ``cls``, that subclass is
        constructed instead, so wrapping a specialised context keeps its type.

        NOTE(review): the object returned in that case is still an instance of
        ``cls``, so Python runs ``__init__`` on it a second time afterwards.
        """
        least = args[0] if args else kwargs.get("least")
        if isinstance(least, cls):
            subclass = least.__class__
            if subclass is not cls:
                return subclass(*args, **kwargs)
        return super().__new__(cls, *args, **kwargs)

    def __init__(self, least: MutableMapping | None = None, *maps: Mapping):
        # ``least`` becomes the first mapping; a fresh dict is used when omitted.
        first = {} if least is None else least
        super().__init__(first, *maps)  # type: ignore

    @classmethod
    def ensure(cls, context):
        """Return ``context`` itself if it is already a ``cls``, else wrap it in one."""
        if isinstance(context, cls):
            return context
        return cls(context)

    @property
    def result(self):
        """Value stored under the ``"__result__"`` key (e.g. a node's LLM output)."""
        return self["__result__"]

    @result.setter
    def result(self, result):
        self["__result__"] = result

    @result.deleter
    def result(self):
        del self["__result__"]

    def __str__(self):
        # Materialise as a plain dict purely for display.
        return str(dict(self))

54 

55 

# Return type shared by process hooks: an optional Context (``None`` allowed).
_MaybeContext = Context | None

# A synchronous pre/mid/end process: called with the running ChainContext.
Process = Callable[[ChainContext], _MaybeContext]

# The awaitable flavour of ``Process``.
AsyncProcess = Callable[[ChainContext], Awaitable[_MaybeContext]]

59 

60 

class AbstractNode(Protocol):
    """Structural interface of every runnable node.

    Nodes can be run synchronously (``invoke``), asynchronously (``ainvoke``)
    or incrementally (``stream`` / ``astream``), and combined with ``+``.
    """

    def invoke(
        self,
        context: Context | None = None,
        /,
        complete: Complete | None = None,
        **config,
    ) -> ChainContext:
        """Run the node to completion and return the resulting context."""

    async def ainvoke(
        self,
        context: Context | None = None,
        /,
        complete: Complete | AsyncComplete | None = None,
        **config,
    ) -> ChainContext:
        """Asynchronous counterpart of ``invoke``."""

    def stream(
        self,
        context: Context | None = None,
        /,
        generate: Generate | None = None,
        **config,
    ) -> Iterable[ChainContext]:
        """Run the node incrementally, yielding the context as it progresses."""

    def astream(
        self,
        context: Context | None = None,
        /,
        generate: Generate | AsyncGenerate | None = None,
        **config,
    ) -> AsyncIterable[ChainContext]:
        """Asynchronous counterpart of ``stream``."""

    @classmethod
    def _get_chain_type(cls):
        """Class used by ``__add__`` to build the combined chain."""
        return Chain

    def __add__(self, chain: "AbstractNode"):
        """Return a new chain that runs ``self`` and then ``chain``.

        A right-hand ``Chain`` is spliced in node-by-node rather than nested.
        NOTE(review): when splicing, only the nodes are carried over, so that
        chain's own callbacks and partial context are not part of the result.
        """
        members = (self, *chain) if isinstance(chain, Chain) else (self, chain)
        return self._get_chain_type()(*members)

102 

103 

def ensure_callbacks(callbacks: list[BaseCallback | type[BaseCallback]]) -> list[BaseCallback]:
    """Return a new list in which every callback class is replaced by a fresh instance.

    Entries that are already instances are passed through unchanged (same object).
    """
    return [entry() if isclass(entry) else entry for entry in callbacks]

106 

107 

class Interruptable(AbstractNode, Protocol):
    """A node that runs callbacks around its work and can be redirected by ``Jump``.

    Subclasses implement the private hooks ``_invoke``, ``_ainvoke``,
    ``_stream`` and ``_astream``. The public ``invoke``/``ainvoke``/``stream``/
    ``astream`` wrappers:

    1. call each callback's ``on_enter`` in registration order (``enter``),
    2. run the hook on ``ChainContext(context, self.context)``, i.e. the
       caller's context combined with this node's partial context,
    3. call each callback's ``on_leave`` in reverse order (``leave``), and
    4. if the hook raised a ``Jump`` whose ``out_of`` is ``None`` or ``self``,
       continue in ``jump.into`` (when set); any other ``Jump`` is re-raised
       for an enclosing node to handle.
    """

    def _invoke(
        self,
        context: ChainContext,
        /,
        complete: Complete | None,
        callbacks: list[BaseCallback],
        **config,
    ):
        """Hook: synchronous work. Its return value is ignored; write into ``context``."""
        ...

    async def _ainvoke(
        self,
        context: ChainContext,
        /,
        complete: Complete | AsyncComplete | None,
        callbacks: list[BaseCallback],
        **config,
    ):
        """Hook: asynchronous work. Its return value is ignored; write into ``context``."""
        ...

    def _stream(
        self,
        context: ChainContext,
        /,
        generate: Generate | None,
        callbacks: list[BaseCallback],
        **config,
    ) -> Iterable:
        """Hook: iterable that advances once per step; yielded values are ignored."""
        ...

    def _astream(
        self,
        context: ChainContext,
        /,
        generate: Generate | AsyncGenerate | None,
        callbacks: list[BaseCallback],
        **config,
    ) -> AsyncIterable:
        """Hook: async iterable that advances once per step; yielded values are ignored."""
        ...

    # Registered callbacks; classes are instantiated afresh on every run (see ``enter``).
    callbacks: list[BaseCallback | type[BaseCallback]]

    def enter(self, context: Context | None, config: Context):
        """Instantiate callbacks and let each one rewrite ``context``/``config``.

        Returns:
            ``(context, config, callbacks)``; the callback instances must be
            passed back to ``leave`` when the run finishes.
        """
        callbacks: list[BaseCallback] = ensure_callbacks(self.callbacks)
        for callback in callbacks:
            context, config = callback.on_enter(self, context, config)
        return context, config, callbacks

    def leave(self, context: ChainContext, config: Context, callbacks: list[BaseCallback]):
        """Run each callback's ``on_leave`` in reverse order; return ``(context, config)``."""
        for callback in reversed(callbacks):
            context, config = callback.on_leave(self, context, config)
        return context, config

    def add_pre_processes(self, *processes: Process | AsyncProcess):
        """Register ``processes`` as pre-processes; returns ``self`` for chaining."""
        self.callbacks.extend(Callback(pre_process=i) for i in processes)
        return self

    def add_mid_processes(self, *processes: Process | AsyncProcess):
        """Register ``processes`` as mid-processes; returns ``self`` for chaining."""
        self.callbacks.extend(Callback(mid_process=i) for i in processes)
        return self

    def add_end_processes(self, *processes: Process | AsyncProcess):
        """Register ``processes`` as end-processes; returns ``self`` for chaining."""
        self.callbacks.extend(Callback(end_process=i) for i in processes)
        return self

    def add_callbacks(self, *callbacks: BaseCallback | type[BaseCallback]):
        """Register callback instances or classes; returns ``self`` for chaining."""
        self.callbacks.extend(callbacks)
        return self

    def pre_process(self, process: Process | AsyncProcess):
        """Decorator form of ``add_pre_processes``; returns ``process`` unchanged."""
        self.add_pre_processes(process)
        return process

    def mid_process(self, process: Process | AsyncProcess):
        """Decorator form of ``add_mid_processes``; returns ``process`` unchanged."""
        self.add_mid_processes(process)
        return process

    def end_process(self, process: Process | AsyncProcess):
        """Decorator form of ``add_end_processes``; returns ``process`` unchanged."""
        self.add_end_processes(process)
        return process

    def callback(self, callback: BaseCallback | type[BaseCallback]):
        """Decorator form of ``add_callbacks``; returns ``callback`` unchanged."""
        self.add_callbacks(callback)
        return callback

    @staticmethod
    def _apply_pre_processes(context: ChainContext, callbacks: list[BaseCallback]):
        """Merge each callback's pre_process output (if any) into ``context``, in order."""
        for callback in callbacks:
            context |= cast(Context, callback.pre_process(context) or {})

    @staticmethod
    def _apply_mid_processes(context: ChainContext, callbacks: list[BaseCallback]):
        """Merge each callback's mid_process output (if any) into ``context``, in order."""
        for callback in callbacks:
            context |= cast(Context, callback.mid_process(context) or {})

    @staticmethod
    def _apply_end_processes(context: ChainContext, callbacks: list[BaseCallback]):
        """Merge each callback's end_process output (if any) into ``context``, in reverse order."""
        for callback in reversed(callbacks):
            context |= cast(Context, callback.end_process(context) or {})

    @staticmethod
    async def _apply_async_pre_processes(context: ChainContext, callbacks: list[BaseCallback]):
        """Async variant of ``_apply_pre_processes``; awaits results via ``resolve``."""
        for callback in callbacks:
            context |= cast(Context, await resolve(callback.pre_process(context)) or {})

    @staticmethod
    async def _apply_async_mid_processes(context: ChainContext, callbacks: list[BaseCallback]):
        """Async variant of ``_apply_mid_processes``; awaits results via ``resolve``."""
        for callback in callbacks:
            context |= cast(Context, await resolve(callback.mid_process(context)) or {})

    @staticmethod
    async def _apply_async_end_processes(context: ChainContext, callbacks: list[BaseCallback]):
        """Async variant of ``_apply_end_processes``; awaits results via ``resolve``."""
        for callback in reversed(callbacks):
            context |= cast(Context, await resolve(callback.end_process(context)) or {})

    def invoke(self, context=None, /, complete=None, **config) -> ChainContext:
        """Run this node synchronously and return the final context.

        If the hook raises a ``Jump`` addressed to this node with ``into`` set,
        execution continues in ``jump.into`` and that node's final context is
        returned.
        """
        context, config, callbacks = self.enter(context, config)
        context = ChainContext.ensure(context)

        try:
            self._invoke(ChainContext(context, self.context), complete, callbacks, **config)
        except Jump as jump:
            context, config = self.leave(context, config, callbacks)
            # A jump addressed to some other node is not ours to handle.
            if jump.out_of is not None and jump.out_of is not self:
                raise jump from None
            if jump.into is not None:
                # Return the target's final context, consistent with ``stream``,
                # which yields the target's contexts directly.
                context = jump.into.invoke(context, complete, **config)
        else:
            context, config = self.leave(context, config, callbacks)

        return context

    async def ainvoke(self, context=None, /, complete=None, **config) -> ChainContext:
        """Asynchronous counterpart of ``invoke`` (same ``Jump`` handling)."""
        context, config, callbacks = self.enter(context, config)
        context = ChainContext.ensure(context)

        try:
            await self._ainvoke(ChainContext(context, self.context), complete, callbacks, **config)
        except Jump as jump:
            context, config = self.leave(context, config, callbacks)
            # A jump addressed to some other node is not ours to handle.
            if jump.out_of is not None and jump.out_of is not self:
                raise jump from None
            if jump.into is not None:
                # Return the target's final context, consistent with ``astream``.
                context = await jump.into.ainvoke(context, complete, **config)
        else:
            context, config = self.leave(context, config, callbacks)

        return context

    def stream(self, context=None, /, generate=None, **config) -> Iterable[ChainContext]:
        """Run this node step by step, yielding the context after each hook step.

        After a handled ``Jump`` with ``into`` set, the target's stream is yielded.
        """
        context, config, callbacks = self.enter(context, config)
        context = ChainContext.ensure(context)

        try:
            for _ in self._stream(ChainContext(context, self.context), generate, callbacks, **config):
                yield context
        except Jump as jump:
            context, config = self.leave(context, config, callbacks)
            if jump.out_of is not None and jump.out_of is not self:
                raise jump from None
            if jump.into is not None:
                yield from jump.into.stream(context, generate, **config)
        else:
            context, config = self.leave(context, config, callbacks)

    async def astream(self, context=None, /, generate=None, **config) -> AsyncIterable[ChainContext]:
        """Asynchronous counterpart of ``stream`` (same ``Jump`` handling)."""
        context, config, callbacks = self.enter(context, config)
        context = ChainContext.ensure(context)

        try:
            async for _ in self._astream(ChainContext(context, self.context), generate, callbacks, **config):
                yield context
        except Jump as jump:
            context, config = self.leave(context, config, callbacks)
            if jump.out_of is not None and jump.out_of is not self:
                raise jump from None
            if jump.into is not None:
                async for i in jump.into.astream(context, generate, **config):
                    yield i
        else:
            context, config = self.leave(context, config, callbacks)

    # Backing store for ``context``; subclasses assign it in ``__init__``.
    _context: Context | None

    @property
    def context(self):
        """This node's partial context, combined into every run; created lazily as ``{}``."""
        if self._context is None:
            self._context = {}
        return self._context

    @context.setter
    def context(self, context: Context | None):
        self._context = context

    @context.deleter
    def context(self):
        # Resetting to None makes the getter hand out a fresh empty dict next time.
        self._context = None

302 

303 

class Node(Loader, Interruptable):
    """A single prompt step: render ``template``, send it to an LLM, store the reply.

    The completion ends up in ``context.result``. When ``llm`` is given, its
    ``complete``/``generate`` take precedence over the callable passed at call
    time. ``config`` is stored as ``run_config`` and merged with per-call config
    (per-call keys win).
    """

    def __init__(
        self,
        template: Template | str,
        partial_context: Context | None = None,
        llm: LLM | None = None,
        **config,
    ):
        if isinstance(template, str):
            template = Template(template)
        self.template = template
        self._context = partial_context
        self.callbacks: list[BaseCallback | type[BaseCallback]] = []
        self.llm = llm
        self.run_config = config

    def _invoke(self, context, /, complete, callbacks, **config):
        if self.llm:
            complete = self.llm.complete
        complete = cast(Complete, complete)
        assert complete is not None

        prompt = self.render(context, callbacks)
        options = self.run_config | config
        context.result = complete(prompt, **options)

        self._apply_mid_processes(context, callbacks)
        self._apply_end_processes(context, callbacks)

    def _stream(self, context, /, generate, callbacks, **config):
        if self.llm:
            generate = self.llm.generate
        generate = cast(Generate, generate)
        assert generate is not None

        prompt = self.render(context, callbacks)
        chunks = generate(prompt, **(self.run_config | config))

        # ``accumulate`` yields running sums, so each step sees the text so far.
        for so_far in accumulate(chunks):
            context.result = so_far
            self._apply_mid_processes(context, callbacks)
            yield

        self._apply_end_processes(context, callbacks)

    async def _ainvoke(self, context, /, complete, callbacks, **config):
        if self.llm:
            complete = self.llm.complete
        complete = cast(Complete | AsyncComplete, complete)
        assert complete is not None

        prompt = await self.arender(context, callbacks)
        options = self.run_config | config
        # ``resolve`` lets sync and async completers be handled uniformly.
        context.result = await resolve(complete(prompt, **options))

        await self._apply_async_mid_processes(context, callbacks)
        await self._apply_async_end_processes(context, callbacks)

    async def _astream(self, context, /, generate, callbacks, **config):
        if self.llm:
            generate = self.llm.generate
        generate = cast(Generate | AsyncGenerate, generate)
        assert generate is not None

        prompt = await self.arender(context, callbacks)
        chunks = generate(prompt, **(self.run_config | config))

        # presumably the sync/async-agnostic counterpart of ``accumulate``
        async for so_far in accumulate_any(chunks):
            context.result = so_far
            await self._apply_async_mid_processes(context, callbacks)
            yield

        await self._apply_async_end_processes(context, callbacks)

    def render(self, context: Context | None = None, callbacks: list[BaseCallback] | None = None):
        """Run pre-processes on ``context`` + the partial context, then render the template."""
        active = ensure_callbacks(self.callbacks) if callbacks is None else callbacks
        layered = ChainContext(context, self.context)
        self._apply_pre_processes(layered, active)
        return self.template.render(layered)

    async def arender(self, context: Context | None = None, callbacks: list[BaseCallback] | None = None):
        """Async counterpart of ``render``."""
        active = ensure_callbacks(self.callbacks) if callbacks is None else callbacks
        layered = ChainContext(context, self.context)
        await self._apply_async_pre_processes(layered, active)
        return await self.template.arender(layered)

    def __str__(self):
        # ``name`` is not set here; presumably provided by ``Loader``.
        return f"</{self.name}/>"

384 

385 

class Loop(Interruptable):
    """Run ``chain`` repeatedly, with this loop's callbacks wrapped around each pass.

    There is no exit condition here: the loop only stops when something raises,
    e.g. a ``Jump`` whose ``out_of`` is this loop, which the inherited
    ``invoke``/``stream`` wrappers then handle.
    """

    def __init__(self, chain: AbstractNode, partial_context: Context | None = None):
        self.chain = chain
        self._context = partial_context
        self.callbacks: list[BaseCallback | type[BaseCallback]] = []

    def _invoke(self, context, /, complete, callbacks, **config):
        # One pass = pre-processes -> wrapped chain -> mid -> end, forever.
        while True:
            self._apply_pre_processes(context, callbacks)
            self.chain.invoke(context, complete, **config)
            self._apply_mid_processes(context, callbacks)
            self._apply_end_processes(context, callbacks)

    async def _ainvoke(self, context, /, complete, callbacks, **config):
        while True:
            await self._apply_async_pre_processes(context, callbacks)
            await self.chain.ainvoke(context, complete, **config)
            await self._apply_async_mid_processes(context, callbacks)
            await self._apply_async_end_processes(context, callbacks)

    def _stream(self, context, /, generate, callbacks, **config):
        while True:
            self._apply_pre_processes(context, callbacks)
            # Mid-processes run after every step the wrapped chain yields.
            for _step in self.chain.stream(context, generate, **config):
                self._apply_mid_processes(context, callbacks)
                yield
            self._apply_end_processes(context, callbacks)

    async def _astream(self, context, /, generate, callbacks, **config):
        while True:
            await self._apply_async_pre_processes(context, callbacks)
            async for _step in self.chain.astream(context, generate, **config):
                await self._apply_async_mid_processes(context, callbacks)
                yield
            await self._apply_async_end_processes(context, callbacks)

421 

422 

class Chain(Interruptable):
    """Run ``nodes`` one after another over a shared context."""

    def __init__(self, *nodes: AbstractNode, partial_context: Context | None = None):
        self.nodes = list(nodes)
        self._context = partial_context
        self.callbacks: list[BaseCallback | type[BaseCallback]] = []

    @classmethod
    def _get_chain_type(cls):
        # ``a_chain + x`` builds an instance of the same (sub)class.
        return cls

    def __iadd__(self, chain: AbstractNode):
        """Append ``chain`` in place as a single node (a Chain is nested, not spliced)."""
        self.nodes.append(chain)
        return self

    def __iter__(self):
        return iter(self.nodes)

    def _invoke(self, context, /, complete, callbacks: list[BaseCallback], **config):
        self._apply_pre_processes(context, callbacks)
        # Mid-processes run after each node finishes.
        for step in self.nodes:
            step.invoke(context, complete, **config)
            self._apply_mid_processes(context, callbacks)
        self._apply_end_processes(context, callbacks)

    async def _ainvoke(self, context, /, complete, callbacks: list[BaseCallback], **config):
        await self._apply_async_pre_processes(context, callbacks)
        for step in self.nodes:
            await step.ainvoke(context, complete, **config)
            await self._apply_async_mid_processes(context, callbacks)
        await self._apply_async_end_processes(context, callbacks)

    def _stream(self, context, /, generate, callbacks: list[BaseCallback], **config):
        self._apply_pre_processes(context, callbacks)
        # Mid-processes run after every step yielded by any node.
        for step in self.nodes:
            for _tick in step.stream(context, generate, **config):
                self._apply_mid_processes(context, callbacks)
                yield
        self._apply_end_processes(context, callbacks)

    async def _astream(self, context, /, generate, callbacks: list[BaseCallback], **config):
        await self._apply_async_pre_processes(context, callbacks)
        for step in self.nodes:
            async for _tick in step.astream(context, generate, **config):
                await self._apply_async_mid_processes(context, callbacks)
                yield
        await self._apply_async_end_processes(context, callbacks)

    def __repr__(self):
        return " + ".join(str(node) for node in self.nodes)

472 

473 

class Jump(Exception):
    """Control-flow signal that redirects a running chain.

    The nearest enclosing ``Interruptable`` whose identity matches ``out_of``
    (or the first one to catch it, when ``out_of`` is ``None``) stops its own
    work and, if ``into`` is given, continues by running ``into``.
    """

    def __init__(self, into: Interruptable | None = None, out_of: Interruptable | None = None):
        self.into, self.out_of = into, out_of

    def __str__(self):
        # Only seen when no node in the hierarchy matched ``out_of``.
        return f"{self.out_of} does not exist in the chain hierarchy"