@@ -101,181 +101,198 @@ def to_snake_case(name):
101101 return name
102102
103103
104- def main ():
def load_templates():
    """Build the Jinja2 environment and load both code-generation templates.

    Returns:
        A ``(template, test_template)`` tuple: the compiled
        ``operation.py.j2`` template followed by the compiled
        ``test_operation.py.j2`` template, both looked up in ``TEMPLATE_DIR``.
    """
    # trim_blocks / lstrip_blocks stop block tags from leaking stray newlines
    # and leading whitespace into the rendered Python source.
    environment = jinja2.Environment(
        loader=jinja2.FileSystemLoader(TEMPLATE_DIR),
        trim_blocks=True,
        lstrip_blocks=True,
    )
    operation_template = environment.get_template("operation.py.j2")
    test_operation_template = environment.get_template("test_operation.py.j2")
    return operation_template, test_operation_template
111+
112+
def _merge_impl_args(impls):
    """Merge the argument lists of all implementations, one entry per arg name.

    Returns a dict keyed by argument name in first-seen order. ``optional``
    and ``keyword_only`` are taken from the first implementation that mentions
    the argument; ``types`` is a dict used as an insertion-ordered set of the
    YAML type names seen for that argument.
    """
    merged = {}
    for impl in impls:
        for arg in impl["args"]:
            name = arg["name"]
            if name not in merged:
                merged[name] = {
                    "types": {},
                    "optional": arg["optional"],
                    "keyword_only": arg["keyword_only"],
                }
            merged[name]["types"][arg["value"]] = None
    return merged


def _build_arg_spec(name, arg_info):
    """Return the ``googlesql.ArgSpec(...)`` source text for one argument."""
    parts = []
    if arg_info["keyword_only"]:
        parts.append(f'arg_name="{name}"')
    if arg_info["optional"]:
        parts.append("optional=True")
    return "googlesql.ArgSpec(" + ", ".join(parts) + ")"


def parse_scalar_functions(data, module_name):
    """Convert the ``scalar_functions`` section of a YAML document into template context.

    Args:
        data: The parsed YAML document. May be ``None`` (what
            ``yaml.safe_load`` returns for an empty file) or lack the
            ``scalar_functions`` key; both yield empty results.
        module_name: Name of the module being generated. A leading
            ``<module_name>_`` is stripped from each function's snake_case
            name.

    Returns:
        A ``(ops_list, functions_list)`` tuple.

        * ``ops_list``: one dict per function with the keys
          ``internal_name``, ``sql_name``, ``arg_specs`` and ``signature``.
        * ``functions_list``: one dict per function with the keys ``name``,
          ``op_name``, ``description``, ``args`` and ``test_args``.

    Raises:
        KeyError: If a function entry is missing a required field such as
            ``name``, ``impls``, ``description`` or an argument attribute.
    """
    ops_list = []
    functions_list = []

    # An empty YAML file loads as None, and a key with no value loads as None
    # too; treat both as "nothing to generate" instead of crashing.
    if not data or "scalar_functions" not in data:
        return ops_list, functions_list

    for func_data in data["scalar_functions"] or []:
        sql_name = func_data["name"]
        python_name = to_snake_case(sql_name)
        prefix = module_name + "_"
        if python_name.startswith(prefix):
            python_name = python_name[len(prefix):]

        internal_op_name = f"_{python_name.upper()}_OP"

        # Aggregate args across impls (dict order == first-seen order).
        args_by_name = _merge_impl_args(func_data["impls"])

        arg_specs = [
            _build_arg_spec(name, arg_info)
            for name, arg_info in args_by_name.items()
        ]

        # Determine return dtype; it is only unambiguous when every
        # implementation agrees on it.
        return_types = {impl["return"] for impl in func_data["impls"]}
        if len(return_types) == 1:
            (ret_type,) = return_types
            signature = f"lambda *args: {DTYPE_MAP.get(ret_type, 'None')}"
        else:
            # Fallback to Any/None if ambiguous
            signature = "lambda *args: None"

        ops_list.append(
            {
                "internal_name": internal_op_name,
                "sql_name": sql_name.upper(),
                "arg_specs": ", ".join(arg_specs),
                "signature": signature,
            }
        )

        func_args = []
        test_args = []
        for name, arg_info in args_by_name.items():
            # The sentinel literal is always part of the hint, so the hint is
            # always a Union; sorting keeps the rendered text deterministic.
            hint_members = {PY_TYPE_MAP.get(t, "Any") for t in arg_info["types"]}
            hint_members.add("Literal[sentinels.Sentinel.ARGUMENT_DEFAULT]")
            func_arg = {
                "name": name,
                "type_hint": "Union[" + ", ".join(sorted(hint_members)) + "]",
            }
            # Mandatory args carry no "default" key at all.
            if arg_info["optional"]:
                func_arg["default"] = "sentinels.Sentinel.ARGUMENT_DEFAULT"
            func_args.append(func_arg)

            # Use the first type seen for this arg so the generated test is
            # stable from run to run (set iteration order is not).
            first_type = next(iter(arg_info["types"]))
            test_args.append(
                {"col_name": YAML_TYPE_TO_COL.get(first_type, "string_col")}
            )

        functions_list.append(
            {
                "name": python_name,
                "op_name": internal_op_name,
                "description": func_data["description"],
                "args": func_args,
                "test_args": test_args,
            }
        )

    return ops_list, functions_list
221+
222+
def run_ruff(path: pathlib.Path):
    """Run the configured ruff command on a single generated file.

    Args:
        path: File to pass to ruff; it is appended to ``RUFF_ARGS`` as the
            final command-line argument.

    Raises:
        subprocess.CalledProcessError: If ruff exits with a non-zero status.
    """
    command = [*RUFF_ARGS, str(path)]
    subprocess.run(command, check=True)
231+
232+
def process_yaml_file(yaml_file, template, test_template):
    """Generate the operation module and its test module for one YAML file.

    Steps: parse the YAML, render and write the operation module, format it
    with ruff, render and write the matching test module, format it, then run
    pytest on the test module with ``--snapshot-update``.

    Args:
        yaml_file: Path of the YAML definition; must be located under
            ``DATA_DIR``.
        template: Template used to render the operation module.
        test_template: Template used to render the test module.

    Raises:
        subprocess.CalledProcessError: If ruff fails on a generated file.
    """
    script_path = "scripts/generate_bigframes_bigquery.py"

    print(f"Processing {yaml_file}...")
    # Explicit UTF-8 so results do not depend on the platform's locale.
    with open(yaml_file, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    rel_path = yaml_file.relative_to(DATA_DIR)
    module_path = rel_path.with_suffix("")
    module_name = module_path.name
    output_file = OUTPUT_DIR.joinpath(module_path).with_suffix(".py")

    ops_list, functions_list = parse_scalar_functions(data, module_name)

    # Render and write
    output_file.parent.mkdir(parents=True, exist_ok=True)
    content = template.render(
        yaml_path=str(yaml_file),
        script_path=script_path,
        ops=ops_list,
        functions=functions_list,
    )
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(content)

    run_ruff(output_file)
    print(f" Generated {output_file}")

    # Render and write test
    import_path = "bigframes.bigquery._operations." + ".".join(module_path.parts)
    test_output_file = TEST_OUTPUT_DIR.joinpath(
        module_path.with_name(f"test_{module_path.name}")
    ).with_suffix(".py")

    test_output_file.parent.mkdir(parents=True, exist_ok=True)
    test_content = test_template.render(
        yaml_path=str(yaml_file),
        script_path=script_path,
        import_path=import_path,
        short_name=module_path.name,
        functions=functions_list,
    )
    with open(test_output_file, "w", encoding="utf-8") as f:
        f.write(test_content)

    run_ruff(test_output_file)
    print(f" Generated {test_output_file}")

    print(f" Updating snapshots for {test_output_file}...")
    # check=False: a failing pytest run must not abort generation of the
    # remaining files.
    subprocess.run(
        [
            "pytest",
            str(test_output_file),
            "--snapshot-update",
        ],
        check=False,
    )
288+
289+
def main():
    """Generate code for every ``*.yaml`` file found recursively under ``DATA_DIR``."""
    # Load the templates once and reuse them for every YAML file.
    templates = load_templates()

    for yaml_path in DATA_DIR.glob("**/*.yaml"):
        process_yaml_file(yaml_path, *templates)
295+
279296
# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()
0 commit comments