3535OUTPUT_DIR = pathlib .Path ("bigframes/bigquery/_operations" )
3636# Directory where the generated test files will be placed
3737TEST_OUTPUT_DIR = pathlib .Path ("tests/unit/bigquery/_operations" )
38+ # Directory containing the Jinja2 templates
39+ TEMPLATE_DIR = pathlib .Path ("scripts/templates" )
3840
3941RUFF_ARGS = [
4042 "ruff" ,
4648 "--line-length=88" ,
4749]
4850
49- LICENSE_HEADER = """# Copyright 2026 Google LLC
50- #
51- # Licensed under the Apache License, Version 2.0 (the "License");
52- # you may not use this file except in compliance with the License.
53- # You may obtain a copy of the License at
54- #
55- # http://www.apache.org/licenses/LICENSE-2.0
56- #
57- # Unless required by applicable law or agreed to in writing, software
58- # distributed under the License is distributed on an "AS IS" BASIS,
59- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
60- # See the License for the specific language governing permissions and
61- # limitations under the License.
62- """
63-
64- TEMPLATE = """{{ license_header }}
65- #
66- # DO NOT MODIFY THIS FILE DIRECTLY.
67- # This file was generated from: {{ yaml_path }}
68- # by the script: {{ script_path }}
69-
70- from __future__ import annotations
71-
72- import datetime
73- from typing import Any, Optional, TypeVar, Union
74-
75- from bigframes import dtypes
76- import bigframes.bigquery._googlesql
77- import bigframes.core.col
78- import bigframes.core.expression as ex
79- import bigframes.core.sentinels as sentinels
80- from bigframes.operations import googlesql
81- import bigframes.operations as ops
82- import bigframes.series as series
83-
84- T = TypeVar("T", series.Series, bigframes.core.col.Expression)
85-
86- {% for op in ops %}
87- {{ op.internal_name }} = googlesql.GoogleSqlScalarOp(
88- "{{ op.sql_name }}",
89- args=({{ op.arg_specs }}),
90- signature={{ op.signature }},
91- )
92- {% endfor %}
93- {% for func in functions %}
94-
95-
96- def {{ func.name }}(
97- {% for arg in func.args %}
98- {{ arg.name }}: Union[T, bigframes.core.col.Expression, {{ arg.type_hint }}]{% if arg.default %} = {{ arg.default }}{% endif %},
99- {% endfor %}
100- ) -> T:
101- \" \" \" {{ func.description }}\" \" \"
102- return bigframes.bigquery._googlesql.apply_googlesql_scalar_op(
103- {{ func.op_name }},
104- {% for arg in func.args %}
105- {{ arg.name }},
106- {% endfor %}
107- ) # type: ignore
108- {% endfor %}
109- """
110-
111- TEST_TEMPLATE = r"""{{ license_header }}
112- #
113- # DO NOT MODIFY THIS FILE DIRECTLY.
114- # This file was generated from: {{ yaml_path }}
115- # by the script: {{ script_path }}
116-
117- from typing import cast
118-
119- import pytest
120-
121- import bigframes.pandas as bpd
122- import {{ import_path }} as {{ short_name }}
123-
124- pytest.importorskip("pytest_snapshot")
125-
126-
127- {% for func in functions %}
128- def test_{{ func.name }}(scalar_types_df: bpd.DataFrame, snapshot):
129- result = {{ short_name }}.{{ func.name }}(
130- {% for arg in func.test_args %}
131- cast(bpd.Series, scalar_types_df["{{ arg.col_name }}"]),
132- {% endfor %}
133- ).to_frame()
134- snapshot.assert_match(result.sql.rstrip() + "\n", "out.sql")
135-
136-
137- {% endfor %}
138- """
139-
14051DTYPE_MAP = {
14152 "binary" : "dtypes.BYTES_DTYPE" ,
14253 "string" : "dtypes.STRING_DTYPE" ,
@@ -191,9 +102,13 @@ def to_snake_case(name):
191102
192103
193104def main ():
194- env = jinja2 .Environment (trim_blocks = True , lstrip_blocks = True )
195- template = env .from_string (TEMPLATE )
196- test_template = env .from_string (TEST_TEMPLATE )
105+ env = jinja2 .Environment (
106+ loader = jinja2 .FileSystemLoader (TEMPLATE_DIR ),
107+ trim_blocks = True ,
108+ lstrip_blocks = True ,
109+ )
110+ template = env .get_template ("operation.py.j2" )
111+ test_template = env .get_template ("test_operation.py.j2" )
197112
198113 for yaml_file in DATA_DIR .glob ("**/*.yaml" ):
199114 print (f"Processing { yaml_file } ..." )
@@ -266,13 +181,13 @@ def main():
266181 func_args = []
267182 for name in arg_order :
268183 arg_info = args_by_name [name ]
269- types = [PY_TYPE_MAP .get (t , "Any" ) for t in arg_info ["types" ]]
184+ types = [PY_TYPE_MAP .get (t , "Any" ) for t in arg_info ["types" ]] + [ "Literal[sentinels.Sentinel.ARGUMENT_DEFAULT]" ]
270185 type_hint = (
271186 "Union[" + ", " .join (sorted (set (types ))) + "]"
272187 if len (types ) > 1
273188 else types [0 ]
274189 )
275- default = "sentinels.DEFAULT " if arg_info ["optional" ] else ""
190+ default = "sentinels.Sentinel.ARGUMENT_DEFAULT " if arg_info ["optional" ] else ""
276191 func_args .append (
277192 {
278193 "name" : name ,
@@ -308,7 +223,6 @@ def main():
308223 # Render and write
309224 output_file .parent .mkdir (parents = True , exist_ok = True )
310225 content = template .render (
311- license_header = LICENSE_HEADER ,
312226 yaml_path = str (yaml_file ),
313227 script_path = "scripts/generate_bigframes_bigquery.py" ,
314228 ops = ops_list ,
@@ -334,7 +248,6 @@ def main():
334248
335249 test_output_file .parent .mkdir (parents = True , exist_ok = True )
336250 test_content = test_template .render (
337- license_header = LICENSE_HEADER ,
338251 yaml_path = str (yaml_file ),
339252 script_path = "scripts/generate_bigframes_bigquery.py" ,
340253 import_path = import_path ,
0 commit comments