Coverage for promplate/prompt/template.py: 80%

176 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-18 19:52 +0800

1from ast import Expr, parse, unparse 

2from collections import ChainMap 

3from functools import cached_property, partial 

4from pathlib import Path 

5from sys import path as sys_path 

6from sys import version_info 

7from textwrap import dedent 

8from typing import TYPE_CHECKING, Any, Literal, Protocol 

9 

10from .builder import * 

11from .utils import * 

12 

Context = dict[str, Any]  # globals must be a real dict


class Component(Protocol):
    """Structural interface for objects that can be rendered from a ``{% %}`` tag.

    When a ``{% name ... %}`` tag is not a control-flow keyword, the compiled
    template calls ``name.render(...)`` (sync) or ``await name.arender(...)``
    (async), so anything providing these two methods qualifies.
    """

    def render(self, context: Context) -> str:
        """Render synchronously against *context* and return the resulting text."""
        ...

    async def arender(self, context: Context) -> str:
        """Render asynchronously against *context* and return the resulting text."""
        ...

20 

21 

22class TemplateCore(AutoNaming): 

23 """A simple template compiler, for a jinja2-like syntax.""" 

24 

25 def __init__(self, text: str): 

26 """Construct a Templite with the given `text`.""" 

27 

28 self.text = text 

29 

30 def _flush(self): 

31 for line in self._buffer: 

32 self._builder.add_line(line) 

33 self._buffer.clear() 

34 

35 @staticmethod 

36 def _unwrap_token(token: str): 

37 return dedent(token.strip()[2:-2].strip("-")).strip() 

38 

39 def _on_literal_token(self, token: str): 

40 self._buffer.append(f"__append__({repr(token)})") 

41 

42 def _on_eval_token(self, token): 

43 token = self._unwrap_token(token) 

44 if "\n" in token: 

45 mod = parse(token) 

46 [*rest, last] = mod.body 

47 assert isinstance(last, Expr), "{{ }} block must end with an expression, or you should use {# #} block" 

48 self._buffer.extend(unparse(rest).splitlines()) # type: ignore 

49 exp = unparse(last) 

50 else: 

51 exp = token 

52 self._buffer.append(f"__append__({exp})") 

53 

54 def _on_exec_token(self, token): 

55 self._buffer.extend(self._unwrap_token(token).splitlines()) 

56 

57 def _on_special_token(self, token, sync: bool): 

58 inner = self._unwrap_token(token) 

59 

60 if inner.startswith("end"): 

61 last = self._ops_stack.pop() 

62 assert last == inner.removeprefix("end") 

63 self._flush() 

64 self._builder.dedent() 

65 

66 else: 

67 op = inner.split(" ", 1)[0] 

68 

69 if op == "if" or op == "for" or op == "while": 

70 self._ops_stack.append(op) 

71 self._flush() 

72 self._builder.add_line(f"{inner}:") 

73 self._builder.indent() 

74 

75 elif op == "else" or op == "elif": 

76 self._flush() 

77 self._builder.dedent() 

78 self._builder.add_line(f"{inner}:") 

79 self._builder.indent() 

80 

81 else: 

82 params: str = self._make_context(inner) 

83 if sync: 83 ↛ 86line 83 didn't jump to line 86 because the condition on line 83 was always true

84 self._buffer.append(f"__append__({op}.render({params}))") 

85 else: 

86 self._buffer.append(f"__append__(await {op}.arender({params}))") 

87 

88 @staticmethod 

89 def _make_context(text: str): 

90 """generate context parameter if specified otherwise use locals() by default""" 

91 if version_info >= (3, 13): 91 ↛ 92line 91 didn't jump to line 92 because the condition on line 91 was never true

92 return f"globals() | locals() | dict({text[text.index(' ') + 1:]})" if " " in text else "globals() | locals()" 

93 return f"locals() | dict({text[text.index(' ') + 1:]})" if " " in text else "locals()" 

94 

95 def compile(self, sync=True, indent_str="\t"): 

96 self._buffer = [] 

97 self._ops_stack = [] 

98 self._builder = get_base_builder(sync, indent_str) 

99 

100 for token in split_template_tokens(self.text): 

101 if not token: 

102 continue 

103 s_token = token.strip() 

104 if s_token.startswith("{{") and s_token.endswith("}}"): 

105 self._on_eval_token(token) 

106 elif s_token.startswith("{#") and s_token.endswith("#}"): 

107 self._on_exec_token(token) 

108 elif s_token.startswith("{%") and s_token.endswith("%}") and "\n" not in s_token: 

109 self._on_special_token(token, sync) 

110 else: 

111 self._on_literal_token(token) 

112 

113 if self._ops_stack: 

114 raise SyntaxError(self._ops_stack) 

115 

116 self._flush() 

117 self._builder.add_line("return ''.join(map(str, __parts__))") 

118 self._builder.dedent() 

119 

120 error_handling: Literal["linecache", "tempfile", "file"] = "file" if __debug__ else "tempfile" 

121 

122 def _patch_for_error_handling(self, sync: bool): 

123 match self.error_handling: 

124 case "linecache": 

125 add_linecache(self.name, partial(self.get_script, sync, "\t")) 

126 case "file" | "tempfile": 

127 file = save_tempfile(self.name, self.get_script(sync, "\t"), self.error_handling == "tempfile") 

128 sys_path.append(str(file.parent)) 

129 

130 @cached_property 

131 def _render_code(self): 

132 self.compile() 

133 return self._builder.get_render_function().__code__.replace(co_filename=self.name, co_name="render") 

134 

135 def render(self, context: Context) -> str: 

136 try: 

137 return eval(self._render_code, context) 

138 except Exception: 

139 self._patch_for_error_handling(sync=True) 

140 raise 

141 

142 @cached_property 

143 def _arender_code(self): 

144 self.compile(sync=False) 

145 return self._builder.get_render_function().__code__.replace(co_filename=self.name, co_name="arender") 

146 

147 async def arender(self, context: Context) -> str: 

148 try: 

149 return await eval(self._arender_code, context) 

150 except Exception: 

151 self._patch_for_error_handling(sync=False) 

152 raise 

153 

154 def get_script(self, sync=True, indent_str=" "): 

155 """compile template string into python script""" 

156 self.compile(sync, indent_str) 

157 return str(self._builder) 

158 

159 

class Loader(AutoNaming):
    """Alternate constructors that load template text from disk or over HTTP(S).

    Every constructor calls ``cls(text)`` with the loaded text, then sets
    ``name`` on the result to the stem of the given path or URL.
    """

    @classmethod
    def read(cls, path: str | Path, encoding="utf-8"):
        """Load from a local file decoded with *encoding*."""
        path = Path(path)
        obj = cls(path.read_text(encoding))
        obj.name = path.stem
        return obj

    @classmethod
    async def aread(cls, path: str | Path, encoding="utf-8"):
        """Async variant of ``read``; ``aiofiles`` is imported lazily, only when this is called."""
        from aiofiles import open

        async with open(path, encoding=encoding) as f:
            content = await f.read()

        path = Path(path)
        obj = cls(content)
        obj.name = path.stem
        return obj

    @classmethod
    def _patch_kwargs(cls, kwargs: dict) -> dict:
        """Return request kwargs with a default ``User-Agent`` header.

        The merge is shallow: a caller-supplied ``headers`` entry replaces the
        default headers dict entirely rather than being combined with it.
        """
        return {"headers": {"User-Agent": get_user_agent(cls)}} | kwargs

    @staticmethod
    def _join_url(url: str):
        """Return absolute http(s) URLs unchanged; resolve anything else against ``https://promplate.dev/``."""
        # Match the whole scheme prefix: a bare `startswith("http")` would also
        # treat relative paths such as "httpx/intro.j2" as absolute URLs.
        if url.startswith(("http://", "https://")):
            return url

        from urllib.parse import urljoin

        return urljoin("https://promplate.dev/", url)

    @classmethod
    def fetch(cls, url: str, **kwargs):
        """Download a template and build an instance from the response body.

        Extra *kwargs* are forwarded to the client's ``get`` after ``_patch_kwargs``.
        NOTE(review): assumes an httpx-style client whose ``raise_for_status()``
        raises on error statuses and returns the response -- confirm.
        """
        from .utils import _get_client

        response = _get_client().get(cls._join_url(url), **cls._patch_kwargs(kwargs))
        obj = cls(response.raise_for_status().text)
        obj.name = Path(url).stem
        return obj

    @classmethod
    async def afetch(cls, url: str, **kwargs):
        """Async variant of ``fetch`` using the async client."""
        from .utils import _get_aclient

        response = await _get_aclient().get(cls._join_url(url), **cls._patch_kwargs(kwargs))
        obj = cls(response.raise_for_status().text)
        obj.name = Path(url).stem
        return obj

210 

211 

class SafeChainMapContext(ChainMap, dict):
    """A ``ChainMap`` that is also a ``dict`` subclass.

    ``ChainMap`` precedes ``dict`` in the MRO, so lookups, writes (which go to
    the first mapping) and ``copy`` follow ``ChainMap`` semantics. The ``dict``
    base lets instances be used where a real ``dict`` is required, such as the
    globals passed to ``eval`` when rendering.
    """

    if TYPE_CHECKING:  # fix type from `collections.ChainMap`
        from collections.abc import Callable
        from sys import version_info

        # `typing.Self` was added in Python 3.11; older versions need the backport.
        if version_info >= (3, 11):
            from typing import Self
        else:
            from typing_extensions import Self

        # Presumably reconciles the differing `copy` signatures inherited from
        # `ChainMap` and `dict`; at runtime `ChainMap.copy` is the one used.
        copy: Callable[[Self], Self]

222 

223 

class Template(TemplateCore, Loader):
    """A compilable template with file/URL loaders and a default rendering context.

    ``self.context`` supplies fallback variables. A context passed to
    ``render``/``arender`` shadows them, and a fresh empty layer on top of both
    receives writes made through the chain.
    """

    def __init__(self, text: str, /, context: Context | None = None):
        super().__init__(text)
        self.context = context if context is not None else {}

    def _chain_context(self, context: Context | None) -> SafeChainMapContext:
        """Stack a new writable layer, the caller's *context* (when given) and the defaults."""
        layers = [{}] if context is None else [{}, context]
        return SafeChainMapContext(*layers, self.context)

    def render(self, context: Context | None = None):
        return super().render(self._chain_context(context))

    async def arender(self, context: Context | None = None):
        return await super().arender(self._chain_context(context))