Coverage for promplate/prompt/template.py: 80%
176 statements
« prev ^ index » next coverage.py v7.6.2, created at 2024-10-09 22:54 +0800
1from ast import Expr, parse, unparse
2from collections import ChainMap
3from functools import cached_property, partial
4from pathlib import Path
5from sys import path as sys_path
6from sys import version_info
7from textwrap import dedent
8from typing import TYPE_CHECKING, Any, Literal, Protocol
10from .builder import *
11from .utils import *
# Rendering context. It is handed to `eval` as the globals mapping, and
# globals must be a real dict -- hence `dict` here rather than `Mapping`.
Context = dict[str, Any]


class Component(Protocol):
    """Structural type for anything a template can embed via ``{% name args %}``."""

    def render(self, context: Context) -> str:
        """Render synchronously against *context* and return the produced text."""

    async def arender(self, context: Context) -> str:
        """Render asynchronously against *context* and return the produced text."""
22class TemplateCore(AutoNaming):
23 """A simple template compiler, for a jinja2-like syntax."""
25 def __init__(self, text: str):
26 """Construct a Templite with the given `text`."""
28 self.text = text
30 def _flush(self):
31 for line in self._buffer:
32 self._builder.add_line(line)
33 self._buffer.clear()
35 @staticmethod
36 def _unwrap_token(token: str):
37 return dedent(token.strip()[2:-2].strip("-")).strip()
39 def _on_literal_token(self, token: str):
40 self._buffer.append(f"__append__({repr(token)})")
42 def _on_eval_token(self, token):
43 token = self._unwrap_token(token)
44 if "\n" in token:
45 mod = parse(token)
46 [*rest, last] = mod.body
47 assert isinstance(last, Expr), "{{ }} block must end with an expression, or you should use {# #} block"
48 self._buffer.extend(unparse(rest).splitlines()) # type: ignore
49 exp = unparse(last)
50 else:
51 exp = token
52 self._buffer.append(f"__append__({exp})")
54 def _on_exec_token(self, token):
55 self._buffer.extend(self._unwrap_token(token).splitlines())
57 def _on_special_token(self, token, sync: bool):
58 inner = self._unwrap_token(token)
60 if inner.startswith("end"):
61 last = self._ops_stack.pop()
62 assert last == inner.removeprefix("end")
63 self._flush()
64 self._builder.dedent()
66 else:
67 op = inner.split(" ", 1)[0]
69 if op == "if" or op == "for" or op == "while":
70 self._ops_stack.append(op)
71 self._flush()
72 self._builder.add_line(f"{inner}:")
73 self._builder.indent()
75 elif op == "else" or op == "elif":
76 self._flush()
77 self._builder.dedent()
78 self._builder.add_line(f"{inner}:")
79 self._builder.indent()
81 else:
82 params: str = self._make_context(inner)
83 if sync: 83 ↛ 86line 83 didn't jump to line 86 because the condition on line 83 was always true
84 self._buffer.append(f"__append__({op}.render({params}))")
85 else:
86 self._buffer.append(f"__append__(await {op}.arender({params}))")
88 @staticmethod
89 def _make_context(text: str):
90 """generate context parameter if specified otherwise use locals() by default"""
91 if version_info >= (3, 13): 91 ↛ 92line 91 didn't jump to line 92 because the condition on line 91 was never true
92 return f"globals() | locals() | dict({text[text.index(' ') + 1:]})" if " " in text else "globals() | locals()"
93 return f"locals() | dict({text[text.index(' ') + 1:]})" if " " in text else "locals()"
95 def compile(self, sync=True, indent_str="\t"):
96 self._buffer = []
97 self._ops_stack = []
98 self._builder = get_base_builder(sync, indent_str)
100 for token in split_template_tokens(self.text):
101 if not token:
102 continue
103 s_token = token.strip()
104 if s_token.startswith("{{") and s_token.endswith("}}"):
105 self._on_eval_token(token)
106 elif s_token.startswith("{#") and s_token.endswith("#}"):
107 self._on_exec_token(token)
108 elif s_token.startswith("{%") and s_token.endswith("%}") and "\n" not in s_token:
109 self._on_special_token(token, sync)
110 else:
111 self._on_literal_token(token)
113 if self._ops_stack:
114 raise SyntaxError(self._ops_stack)
116 self._flush()
117 self._builder.add_line("return ''.join(map(str, __parts__))")
118 self._builder.dedent()
120 error_handling: Literal["linecache", "tempfile", "file"] = "file" if __debug__ else "tempfile"
122 def _patch_for_error_handling(self, sync: bool):
123 match self.error_handling:
124 case "linecache":
125 add_linecache(self.name, partial(self.get_script, sync, "\t"))
126 case "file" | "tempfile":
127 file = save_tempfile(self.name, self.get_script(sync, "\t"), self.error_handling == "tempfile")
128 sys_path.append(str(file.parent))
130 @cached_property
131 def _render_code(self):
132 self.compile()
133 return self._builder.get_render_function().__code__.replace(co_filename=self.name, co_name="render")
135 def render(self, context: Context) -> str:
136 try:
137 return eval(self._render_code, context)
138 except Exception:
139 self._patch_for_error_handling(sync=True)
140 raise
142 @cached_property
143 def _arender_code(self):
144 self.compile(sync=False)
145 return self._builder.get_render_function().__code__.replace(co_filename=self.name, co_name="arender")
147 async def arender(self, context: Context) -> str:
148 try:
149 return await eval(self._arender_code, context)
150 except Exception:
151 self._patch_for_error_handling(sync=False)
152 raise
154 def get_script(self, sync=True, indent_str=" "):
155 """compile template string into python script"""
156 self.compile(sync, indent_str)
157 return str(self._builder)
class Loader(AutoNaming):
    """Alternate constructors that build ``cls(text)`` from a file or a URL.

    Each constructor sets ``name`` on the new object to the stem of the path/URL.
    """

    @classmethod
    def read(cls, path: str | Path, encoding="utf-8"):
        """Load template text from a local file."""
        path = Path(path)
        obj = cls(path.read_text(encoding))
        obj.name = path.stem
        return obj

    @classmethod
    async def aread(cls, path: str | Path, encoding="utf-8"):
        """Async variant of :meth:`read`; ``aiofiles`` is imported only when called."""
        from aiofiles import open

        async with open(path, encoding=encoding) as f:
            content = await f.read()

        path = Path(path)
        obj = cls(content)
        obj.name = path.stem
        return obj

    @classmethod
    def _patch_kwargs(cls, kwargs: dict) -> dict:
        """Add a default ``User-Agent`` header.

        A caller-supplied ``headers`` entry replaces the default mapping entirely.
        """
        return {"headers": {"User-Agent": get_user_agent(cls)}} | kwargs

    @staticmethod
    def _join_url(url: str):
        """Return absolute http(s) URLs unchanged; resolve anything else against https://promplate.dev/."""
        # Require the full scheme: a bare "http" prefix would also match relative
        # paths such as "httpbin/x.j2" and leave them unresolved.
        if url.startswith(("http://", "https://")):
            return url

        from urllib.parse import urljoin

        return urljoin("https://promplate.dev/", url)

    @classmethod
    def fetch(cls, url: str, **kwargs):
        """Download template text with the shared client.

        Raises if the response status is an error (via ``raise_for_status``).
        Extra ``kwargs`` are forwarded to the client's ``get``.
        """
        from .utils import _get_client

        response = _get_client().get(cls._join_url(url), **cls._patch_kwargs(kwargs))
        obj = cls(response.raise_for_status().text)
        obj.name = Path(url).stem
        return obj

    @classmethod
    async def afetch(cls, url: str, **kwargs):
        """Async variant of :meth:`fetch` using the shared async client."""
        from .utils import _get_aclient

        response = await _get_aclient().get(cls._join_url(url), **cls._patch_kwargs(kwargs))
        obj = cls(response.raise_for_status().text)
        obj.name = Path(url).stem
        return obj
class SafeChainMapContext(ChainMap, dict):
    """A ``ChainMap`` that is also a ``dict`` subclass.

    ``ChainMap`` comes first in the MRO, so item access searches the chained
    maps in order and writes go to the first map. The ``dict`` base lets an
    instance be used where a real dict is required, such as ``eval`` globals.
    """

    if TYPE_CHECKING:  # fix type from `collections.ChainMap`
        from collections.abc import Callable
        from sys import version_info

        # `typing.Self` exists only on Python 3.11+; older versions need the backport.
        if version_info >= (3, 11):
            from typing import Self
        else:
            from typing_extensions import Self

        copy: Callable[[Self], Self]
class Template(TemplateCore, Loader):
    """A template that carries its own default context.

    Rendering layers a fresh empty mapping, then the call-time context (if
    any), then ``self.context``; earlier layers win on key lookups.
    """

    def __init__(self, text: str, /, context: Context | None = None):
        super().__init__(text)
        self.context = {} if context is None else context

    def _layered_context(self, context: Context | None) -> Context:
        """Build the layered lookup context used by both render paths."""
        if context is None:
            return SafeChainMapContext({}, self.context)
        return SafeChainMapContext({}, context, self.context)

    def render(self, context: Context | None = None) -> str:
        """Render synchronously; keys in *context* override ``self.context``."""
        return super().render(self._layered_context(context))

    async def arender(self, context: Context | None = None) -> str:
        """Render asynchronously; keys in *context* override ``self.context``."""
        return await super().arender(self._layered_context(context))