Skip to content

Commit 7ad2c2b

Browse files
committed
split main
1 parent a463ea6 commit 7ad2c2b

1 file changed

Lines changed: 183 additions & 166 deletions

File tree

packages/bigframes/scripts/generate_bigframes_bigquery.py

Lines changed: 183 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -101,181 +101,198 @@ def to_snake_case(name):
101101
return name
102102

103103

104-
def main():
104+
def load_templates():
105105
env = jinja2.Environment(
106106
loader=jinja2.FileSystemLoader(TEMPLATE_DIR),
107107
trim_blocks=True,
108108
lstrip_blocks=True,
109109
)
110-
template = env.get_template("operation.py.j2")
111-
test_template = env.get_template("test_operation.py.j2")
112-
113-
for yaml_file in DATA_DIR.glob("**/*.yaml"):
114-
print(f"Processing {yaml_file}...")
115-
with open(yaml_file, "r") as f:
116-
data = yaml.safe_load(f)
117-
118-
rel_path = yaml_file.relative_to(DATA_DIR)
119-
module_path = rel_path.with_suffix("")
120-
module_name = module_path.name
121-
output_file = OUTPUT_DIR.joinpath(module_path).with_suffix(".py")
122-
123-
ops_list = []
124-
functions_list = []
125-
126-
if "scalar_functions" in data:
127-
for func_data in data["scalar_functions"]:
128-
sql_name = func_data["name"]
129-
python_name = to_snake_case(sql_name)
130-
if python_name.startswith(module_name + "_"):
131-
python_name = python_name[len(module_name) + 1 :]
132-
133-
internal_op_name = f"_{python_name.upper()}_OP"
134-
135-
# Aggregate args across impls
136-
args_by_name = {}
137-
arg_order = []
138-
for impl in func_data["impls"]:
139-
for arg in impl["args"]:
140-
name = arg["name"]
141-
if name not in args_by_name:
142-
args_by_name[name] = {
143-
"types": set(),
144-
"optional": arg["optional"],
145-
"keyword_only": arg["keyword_only"],
146-
}
147-
arg_order.append(name)
148-
args_by_name[name]["types"].add(arg["value"])
149-
150-
# Build ArgSpecs
151-
arg_specs = []
152-
for name in arg_order:
153-
arg_info = args_by_name[name]
154-
spec = "googlesql.ArgSpec("
155-
if arg_info["keyword_only"]:
156-
spec += f'arg_name="{name}", '
157-
if arg_info["optional"]:
158-
spec += "optional=True, "
159-
spec = spec.rstrip(", ") + ")"
160-
arg_specs.append(spec)
161-
162-
# Determine return dtype
163-
return_types = {impl["return"] for impl in func_data["impls"]}
164-
if len(return_types) == 1:
165-
ret_type = list(return_types)[0]
166-
signature = f"lambda *args: {DTYPE_MAP.get(ret_type, 'None')}"
167-
else:
168-
# Fallback to Any/None if ambiguous
169-
signature = "lambda *args: None"
170-
171-
ops_list.append(
172-
{
173-
"internal_name": internal_op_name,
174-
"sql_name": sql_name.upper(),
175-
"arg_specs": ", ".join(arg_specs),
176-
"signature": signature,
177-
}
178-
)
179-
180-
# Function args
181-
func_args = []
182-
for name in arg_order:
183-
arg_info = args_by_name[name]
184-
types = [PY_TYPE_MAP.get(t, "Any") for t in arg_info["types"]] + ["Literal[sentinels.Sentinel.ARGUMENT_DEFAULT]"]
185-
type_hint = (
186-
"Union[" + ", ".join(sorted(set(types))) + "]"
187-
if len(types) > 1
188-
else types[0]
189-
)
190-
default = "sentinels.Sentinel.ARGUMENT_DEFAULT" if arg_info["optional"] else ""
191-
func_args.append(
192-
{
193-
"name": name,
194-
"type_hint": type_hint,
195-
"default": default,
196-
}
197-
)
198-
199-
# Clean up default values for mandatory args
200-
# In Python, mandatory args come first.
201-
for arg in func_args:
202-
if not arg["default"]:
203-
del arg["default"]
204-
205-
# Test args
206-
test_args = []
207-
for name in arg_order:
208-
arg_info = args_by_name[name]
209-
some_type = list(arg_info["types"])[0]
210-
col_name = YAML_TYPE_TO_COL.get(some_type, "string_col")
211-
test_args.append({"col_name": col_name})
212-
213-
functions_list.append(
214-
{
215-
"name": python_name,
216-
"op_name": internal_op_name,
217-
"description": func_data["description"],
218-
"args": func_args,
219-
"test_args": test_args,
110+
return env.get_template("operation.py.j2"), env.get_template("test_operation.py.j2")
111+
112+
113+
def parse_scalar_functions(data, module_name):
114+
ops_list = []
115+
functions_list = []
116+
117+
if "scalar_functions" not in data:
118+
return ops_list, functions_list
119+
120+
for func_data in data["scalar_functions"]:
121+
sql_name = func_data["name"]
122+
python_name = to_snake_case(sql_name)
123+
if python_name.startswith(module_name + "_"):
124+
python_name = python_name[len(module_name) + 1 :]
125+
126+
internal_op_name = f"_{python_name.upper()}_OP"
127+
128+
# Aggregate args across impls
129+
args_by_name = {}
130+
arg_order = []
131+
for impl in func_data["impls"]:
132+
for arg in impl["args"]:
133+
name = arg["name"]
134+
if name not in args_by_name:
135+
args_by_name[name] = {
136+
"types": set(),
137+
"optional": arg["optional"],
138+
"keyword_only": arg["keyword_only"],
220139
}
221-
)
222-
223-
# Render and write
224-
output_file.parent.mkdir(parents=True, exist_ok=True)
225-
content = template.render(
226-
yaml_path=str(yaml_file),
227-
script_path="scripts/generate_bigframes_bigquery.py",
228-
ops=ops_list,
229-
functions=functions_list,
230-
)
231-
with open(output_file, "w") as f:
232-
f.write(content)
233-
234-
subprocess.run(
235-
RUFF_ARGS
236-
+ [
237-
str(output_file),
238-
],
239-
check=True,
240-
)
241-
print(f" Generated {output_file}")
242-
243-
# Render and write test
244-
import_path = "bigframes.bigquery._operations." + ".".join(module_path.parts)
245-
test_output_file = TEST_OUTPUT_DIR.joinpath(
246-
module_path.with_name(f"test_{module_path.name}")
247-
).with_suffix(".py")
248-
249-
test_output_file.parent.mkdir(parents=True, exist_ok=True)
250-
test_content = test_template.render(
251-
yaml_path=str(yaml_file),
252-
script_path="scripts/generate_bigframes_bigquery.py",
253-
import_path=import_path,
254-
short_name=module_path.name,
255-
functions=functions_list,
256-
)
257-
with open(test_output_file, "w") as f:
258-
f.write(test_content)
259-
260-
subprocess.run(
261-
RUFF_ARGS
262-
+ [
263-
str(test_output_file),
264-
],
265-
check=True,
140+
arg_order.append(name)
141+
args_by_name[name]["types"].add(arg["value"])
142+
143+
# Build ArgSpecs
144+
arg_specs = []
145+
for name in arg_order:
146+
arg_info = args_by_name[name]
147+
spec = "googlesql.ArgSpec("
148+
if arg_info["keyword_only"]:
149+
spec += f'arg_name="{name}", '
150+
if arg_info["optional"]:
151+
spec += "optional=True, "
152+
spec = spec.rstrip(", ") + ")"
153+
arg_specs.append(spec)
154+
155+
# Determine return dtype
156+
return_types = {impl["return"] for impl in func_data["impls"]}
157+
if len(return_types) == 1:
158+
ret_type = list(return_types)[0]
159+
signature = f"lambda *args: {DTYPE_MAP.get(ret_type, 'None')}"
160+
else:
161+
# Fallback to Any/None if ambiguous
162+
signature = "lambda *args: None"
163+
164+
ops_list.append(
165+
{
166+
"internal_name": internal_op_name,
167+
"sql_name": sql_name.upper(),
168+
"arg_specs": ", ".join(arg_specs),
169+
"signature": signature,
170+
}
266171
)
267-
print(f" Generated {test_output_file}")
268-
269-
print(f" Updating snapshots for {test_output_file}...")
270-
subprocess.run(
271-
[
272-
"pytest",
273-
str(test_output_file),
274-
"--snapshot-update",
275-
],
276-
check=False,
172+
173+
# Function args
174+
func_args = []
175+
for name in arg_order:
176+
arg_info = args_by_name[name]
177+
types = [PY_TYPE_MAP.get(t, "Any") for t in arg_info["types"]] + [
178+
"Literal[sentinels.Sentinel.ARGUMENT_DEFAULT]"
179+
]
180+
type_hint = (
181+
"Union[" + ", ".join(sorted(set(types))) + "]"
182+
if len(types) > 1
183+
else types[0]
184+
)
185+
default = (
186+
"sentinels.Sentinel.ARGUMENT_DEFAULT" if arg_info["optional"] else ""
187+
)
188+
func_args.append(
189+
{
190+
"name": name,
191+
"type_hint": type_hint,
192+
"default": default,
193+
}
194+
)
195+
196+
# Clean up default values for mandatory args
197+
# In Python, mandatory args come first.
198+
for arg in func_args:
199+
if not arg.get("default"):
200+
arg.pop("default", None)
201+
202+
# Test args
203+
test_args = []
204+
for name in arg_order:
205+
arg_info = args_by_name[name]
206+
some_type = list(arg_info["types"])[0]
207+
col_name = YAML_TYPE_TO_COL.get(some_type, "string_col")
208+
test_args.append({"col_name": col_name})
209+
210+
functions_list.append(
211+
{
212+
"name": python_name,
213+
"op_name": internal_op_name,
214+
"description": func_data["description"],
215+
"args": func_args,
216+
"test_args": test_args,
217+
}
277218
)
278219

220+
return ops_list, functions_list
221+
222+
223+
def run_ruff(path: pathlib.Path):
224+
subprocess.run(
225+
RUFF_ARGS
226+
+ [
227+
str(path),
228+
],
229+
check=True,
230+
)
231+
232+
233+
def process_yaml_file(yaml_file, template, test_template):
234+
print(f"Processing {yaml_file}...")
235+
with open(yaml_file, "r") as f:
236+
data = yaml.safe_load(f)
237+
238+
rel_path = yaml_file.relative_to(DATA_DIR)
239+
module_path = rel_path.with_suffix("")
240+
module_name = module_path.name
241+
output_file = OUTPUT_DIR.joinpath(module_path).with_suffix(".py")
242+
243+
ops_list, functions_list = parse_scalar_functions(data, module_name)
244+
245+
# Render and write
246+
output_file.parent.mkdir(parents=True, exist_ok=True)
247+
content = template.render(
248+
yaml_path=str(yaml_file),
249+
script_path="scripts/generate_bigframes_bigquery.py",
250+
ops=ops_list,
251+
functions=functions_list,
252+
)
253+
with open(output_file, "w") as f:
254+
f.write(content)
255+
256+
run_ruff(output_file)
257+
print(f" Generated {output_file}")
258+
259+
# Render and write test
260+
import_path = "bigframes.bigquery._operations." + ".".join(module_path.parts)
261+
test_output_file = TEST_OUTPUT_DIR.joinpath(
262+
module_path.with_name(f"test_{module_path.name}")
263+
).with_suffix(".py")
264+
265+
test_output_file.parent.mkdir(parents=True, exist_ok=True)
266+
test_content = test_template.render(
267+
yaml_path=str(yaml_file),
268+
script_path="scripts/generate_bigframes_bigquery.py",
269+
import_path=import_path,
270+
short_name=module_path.name,
271+
functions=functions_list,
272+
)
273+
with open(test_output_file, "w") as f:
274+
f.write(test_content)
275+
276+
run_ruff(test_output_file)
277+
print(f" Generated {test_output_file}")
278+
279+
print(f" Updating snapshots for {test_output_file}...")
280+
subprocess.run(
281+
[
282+
"pytest",
283+
str(test_output_file),
284+
"--snapshot-update",
285+
],
286+
check=False,
287+
)
288+
289+
290+
def main():
291+
template, test_template = load_templates()
292+
293+
for yaml_file in DATA_DIR.glob("**/*.yaml"):
294+
process_yaml_file(yaml_file, template, test_template)
295+
279296

280297
if __name__ == "__main__":
281298
main()

0 commit comments

Comments
 (0)