Coverage for promplate/prompt/template.py: 81%

173 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-08 04:34 +0800

1from ast import Expr, parse, unparse 

2from collections import ChainMap 

3from functools import cached_property, partial 

4from pathlib import Path 

5from sys import path as sys_path 

6from textwrap import dedent 

7from typing import TYPE_CHECKING, Any, Literal, Protocol 

8 

9from .builder import * 

10from .utils import * 

11 

# Rendering context type. It is handed to `eval` as the globals mapping,
# and eval requires globals to be a real dict -- so this is deliberately a
# `dict`, not a generic Mapping.
Context = dict[str, Any]

13 

14 

class Component(Protocol):
    """Structural interface for objects invoked from a ``{% name ... %}`` tag.

    The compiled template calls ``name.render(ctx)`` when compiled in sync
    mode and ``await name.arender(ctx)`` otherwise, appending the result.
    """

    def render(self, context: Context) -> str:
        """Render synchronously against *context* and return the text."""

    async def arender(self, context: Context) -> str:
        """Render asynchronously against *context* and return the text."""

19 

20 

21class TemplateCore(AutoNaming): 

22 """A simple template compiler, for a jinja2-like syntax.""" 

23 

24 def __init__(self, text: str): 

25 """Construct a Templite with the given `text`.""" 

26 

27 self.text = text 

28 

29 def _flush(self): 

30 for line in self._buffer: 

31 self._builder.add_line(line) 

32 self._buffer.clear() 

33 

34 @staticmethod 

35 def _unwrap_token(token: str): 

36 return dedent(token.strip()[2:-2].strip("-")).strip() 

37 

38 def _on_literal_token(self, token: str): 

39 self._buffer.append(f"__append__({repr(token)})") 

40 

41 def _on_eval_token(self, token): 

42 token = self._unwrap_token(token) 

43 if "\n" in token: 

44 mod = parse(token) 

45 [*rest, last] = mod.body 

46 assert isinstance(last, Expr), "{{ }} block must end with an expression, or you should use {# #} block" 

47 self._buffer.extend(unparse(rest).splitlines()) # type: ignore 

48 exp = unparse(last) 

49 else: 

50 exp = token 

51 self._buffer.append(f"__append__({exp})") 

52 

53 def _on_exec_token(self, token): 

54 self._buffer.extend(self._unwrap_token(token).splitlines()) 

55 

56 def _on_special_token(self, token, sync: bool): 

57 inner = self._unwrap_token(token) 

58 

59 if inner.startswith("end"): 

60 last = self._ops_stack.pop() 

61 assert last == inner.removeprefix("end") 

62 self._flush() 

63 self._builder.dedent() 

64 

65 else: 

66 op = inner.split(" ", 1)[0] 

67 

68 if op == "if" or op == "for" or op == "while": 

69 self._ops_stack.append(op) 

70 self._flush() 

71 self._builder.add_line(f"{inner}:") 

72 self._builder.indent() 

73 

74 elif op == "else" or op == "elif": 

75 self._flush() 

76 self._builder.dedent() 

77 self._builder.add_line(f"{inner}:") 

78 self._builder.indent() 

79 

80 else: 

81 params: str = self._make_context(inner) 

82 if sync: 82 ↛ 85line 82 didn't jump to line 85 because the condition on line 82 was always true

83 self._buffer.append(f"__append__({op}.render({params}))") 

84 else: 

85 self._buffer.append(f"__append__(await {op}.arender({params}))") 

86 

87 @staticmethod 

88 def _make_context(text: str): 

89 """generate context parameter if specified otherwise use locals() by default""" 

90 

91 return f"locals() | dict({text[text.index(' ') + 1:]})" if " " in text else "locals()" 

92 

93 def compile(self, sync=True, indent_str="\t"): 

94 self._buffer = [] 

95 self._ops_stack = [] 

96 self._builder = get_base_builder(sync, indent_str) 

97 

98 for token in split_template_tokens(self.text): 

99 if not token: 

100 continue 

101 s_token = token.strip() 

102 if s_token.startswith("{{") and s_token.endswith("}}"): 

103 self._on_eval_token(token) 

104 elif s_token.startswith("{#") and s_token.endswith("#}"): 

105 self._on_exec_token(token) 

106 elif s_token.startswith("{%") and s_token.endswith("%}") and "\n" not in s_token: 

107 self._on_special_token(token, sync) 

108 else: 

109 self._on_literal_token(token) 

110 

111 if self._ops_stack: 

112 raise SyntaxError(self._ops_stack) 

113 

114 self._flush() 

115 self._builder.add_line("return ''.join(map(str, __parts__))") 

116 self._builder.dedent() 

117 

118 error_handling: Literal["linecache", "tempfile", "file"] = "file" if __debug__ else "tempfile" 

119 

120 def _patch_for_error_handling(self, sync: bool): 

121 match self.error_handling: 

122 case "linecache": 

123 add_linecache(self.name, partial(self.get_script, sync, "\t")) 

124 case "file" | "tempfile": 

125 file = save_tempfile(self.name, self.get_script(sync, "\t"), self.error_handling == "tempfile") 

126 sys_path.append(str(file.parent)) 

127 

128 @cached_property 

129 def _render_code(self): 

130 self.compile() 

131 return self._builder.get_render_function().__code__.replace(co_filename=self.name, co_name="render") 

132 

133 def render(self, context: Context) -> str: 

134 try: 

135 return eval(self._render_code, context) 

136 except Exception: 

137 self._patch_for_error_handling(sync=True) 

138 raise 

139 

140 @cached_property 

141 def _arender_code(self): 

142 self.compile(sync=False) 

143 return self._builder.get_render_function().__code__.replace(co_filename=self.name, co_name="arender") 

144 

145 async def arender(self, context: Context) -> str: 

146 try: 

147 return await eval(self._arender_code, context) 

148 except Exception: 

149 self._patch_for_error_handling(sync=False) 

150 raise 

151 

152 def get_script(self, sync=True, indent_str=" "): 

153 """compile template string into python script""" 

154 self.compile(sync, indent_str) 

155 return str(self._builder) 

156 

157 

class Loader(AutoNaming):
    """Mixin of alternate constructors that load template text from a file or URL.

    Every constructor builds ``cls(text)`` and sets ``name`` to the stem of
    the path or URL it was loaded from.
    """

    @classmethod
    def read(cls, path: str | Path, encoding="utf-8"):
        """Read a template from a local file."""
        path = Path(path)
        obj = cls(path.read_text(encoding))
        obj.name = path.stem
        return obj

    @classmethod
    async def aread(cls, path: str | Path, encoding="utf-8"):
        """Read a template from a local file asynchronously (needs ``aiofiles``)."""
        # Aliased so the builtin `open` is not shadowed inside this function.
        from aiofiles import open as aopen

        async with aopen(path, encoding=encoding) as f:
            content = await f.read()

        path = Path(path)
        obj = cls(content)
        obj.name = path.stem
        return obj

    @classmethod
    def _patch_kwargs(cls, kwargs: dict) -> dict:
        """Default the request ``headers`` to a User-Agent from ``get_user_agent``.

        A caller-supplied ``headers`` replaces the default entirely, because
        the right-hand side of ``|`` wins.
        """
        return {"headers": {"User-Agent": get_user_agent(cls)}} | kwargs

    @staticmethod
    def _join_url(url: str):
        """Resolve *url* against ``https://promplate.dev/`` unless it is already absolute."""
        # Match the full scheme prefix: a bare `startswith("http")` also
        # matched relative names such as "httpbin.j2" and left them unresolved.
        if url.startswith(("http://", "https://")):
            return url

        from urllib.parse import urljoin

        return urljoin("https://promplate.dev/", url)

    @classmethod
    def fetch(cls, url: str, **kwargs):
        """Download a template over HTTP; extra kwargs go to the client's ``get``."""
        from .utils import _get_client

        response = _get_client().get(cls._join_url(url), **cls._patch_kwargs(kwargs))
        obj = cls(response.raise_for_status().text)
        obj.name = Path(url).stem
        return obj

    @classmethod
    async def afetch(cls, url: str, **kwargs):
        """Download a template over HTTP asynchronously; extra kwargs go to ``get``."""
        from .utils import _get_aclient

        response = await _get_aclient().get(cls._join_url(url), **cls._patch_kwargs(kwargs))
        obj = cls(response.raise_for_status().text)
        obj.name = Path(url).stem
        return obj

208 

209 

class SafeChainMapContext(ChainMap, dict):
    """A ``ChainMap`` that is also an instance of ``dict``.

    Inheriting from ``dict`` lets an instance be passed where a real dict is
    required (e.g. as ``eval`` globals). ``ChainMap`` comes first in the MRO,
    so lookups, writes and ``copy`` use the chained-map behaviour.
    """

    if TYPE_CHECKING:  # fix type from `collections.ChainMap`
        from sys import version_info
        from typing import Callable

        # `typing.Self` only exists on Python 3.11+; older versions need the
        # `typing_extensions` backport (the branches were previously swapped).
        if version_info >= (3, 11):
            from typing import Self
        else:
            from typing_extensions import Self

        copy: Callable[[Self], Self]

220 

221 

class Template(TemplateCore, Loader):
    """A template that carries its own default context.

    Values passed to ``render``/``arender`` take precedence over the
    template's own ``context``. Both sit behind a fresh empty mapping, so
    ChainMap writes land there rather than in either supplied dict.
    """

    def __init__(self, text: str, /, context: Context | None = None):
        super().__init__(text)
        self.context = context if context is not None else {}

    def _layer_context(self, context: Context | None):
        """Stack a scratch dict, the caller's context (if any) and the defaults."""
        layers = [{}] if context is None else [{}, context]
        return SafeChainMapContext(*layers, self.context)

    def render(self, context: Context | None = None):
        return super().render(self._layer_context(context))

    async def arender(self, context: Context | None = None):
        return await super().arender(self._layer_context(context))