Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
322 changes: 286 additions & 36 deletions src/ecooptimizer/refactorers/concrete/long_parameter_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,53 +260,195 @@ def update_parameter_usages(
):
"""
Updates the function body to use encapsulated parameter objects.
This method transforms parameter references in the function body to use new data_params
and config_params objects.

Args:
function_node: CST node of the function to transform
classified_params: Dictionary mapping parameter groups ('data_params' or 'config_params')
to lists of parameter names in each group

Returns:
The transformed function node with updated parameter usages
"""
# Create a module with just the function to get metadata
module = cst.Module(body=[function_node])
wrapper = MetadataWrapper(module)

class ParameterUsageTransformer(cst.CSTTransformer):
def __init__(self, classified_params: dict[str, list[str]]):
self.param_to_group = {}
"""
A CST transformer that updates parameter references to use the new parameter objects.
"""

METADATA_DEPENDENCIES = (ParentNodeProvider,)

def __init__(
self, classified_params: dict[str, list[str]], metadata_wrapper: MetadataWrapper
):
super().__init__()
# map each parameter to its group (data_params or config_params)
self.param_to_group = {}
self.parent_provider = metadata_wrapper.resolve(ParentNodeProvider)
# flatten classified_params to map each param to its group (data_params or config_params)
for group, params in classified_params.items():
for param in params:
self.param_to_group[param] = group

def leave_Assign(
self,
original_node: cst.Assign, # noqa: ARG002
updated_node: cst.Assign,
) -> cst.Assign:
def is_in_assignment_target(self, node: cst.CSTNode) -> bool:
"""
Transform only right-hand side references to parameters that need to be updated.
Ensure left-hand side (self attributes) remain unchanged.
Check if a node is part of an assignment target (left side of =).

Args:
node: The CST node to check

Returns:
True if the node is part of an assignment target that should not be transformed,
False otherwise
"""
current = node
while current:
parent = self.parent_provider.get(current)

# if we're at an AssignTarget, check if it's a simple Name assignment
if isinstance(parent, cst.AssignTarget):
if isinstance(current, cst.Name):
# allow transformation for simple parameter assignments
return False
return True

if isinstance(parent, cst.Assign):
# if we reach an Assign node, check if we came from the targets
for target in parent.targets:
if target.target.deep_equals(current):
if isinstance(current, cst.Name):
# allow transformation for simple parameter assignments
return False
return True
return False

if isinstance(parent, cst.Module):
return False

current = parent
return False

def leave_Name(
self, original_node: cst.Name, updated_node: cst.Name
) -> cst.BaseExpression:
"""
if not isinstance(updated_node.value, cst.Name):
Transform standalone parameter references.

Skip transformation if:
1. The name is part of an attribute access (eg: self.param)
2. The name is part of a complex assignment target (eg: self.x = y)

Transform if:
1. The name is a simple parameter being assigned (eg: param1 = value)
2. The name is used as a value (eg: x = param1)

Args:
original_node: The original Name node
updated_node: The current state of the Name node

Returns:
The transformed node or the original if no transformation is needed
"""
# don't transform if this is part of a complex assignment target
if self.is_in_assignment_target(original_node):
return updated_node

var_name = updated_node.value.value
# don't transform if this is part of an attribute access (e.g., self.param)
parent = self.parent_provider.get(original_node)
if isinstance(parent, cst.Attribute) and original_node is parent.attr:
return updated_node

if var_name in self.param_to_group:
new_value = cst.Attribute(
value=cst.Name(self.param_to_group[var_name]), attr=cst.Name(var_name)
name_value = updated_node.value
if name_value in self.param_to_group:
# transform the name into an attribute access on the appropriate parameter object
return cst.Attribute(
value=cst.Name(self.param_to_group[name_value]), attr=cst.Name(name_value)
)
return updated_node

def leave_Attribute(
self, original_node: cst.Attribute, updated_node: cst.Attribute
) -> cst.BaseExpression:
"""
Handle method calls and attribute access on parameters.
This method handles several cases:

1. Assignment targets (eg: self.x = y)
2. Simple attribute access (eg: self.x or report.x)
3. Nested attribute access (eg: data_params.user_id)
4. Subscript access (eg: self.settings["timezone"])
5. Parameter attribute access (eg: username.strip())

Args:
original_node: The original Attribute node
updated_node: The current state of the Attribute node

Returns:
The transformed node or the original if no transformation is needed
"""
# don't transform if this is part of an assignment target
if self.is_in_assignment_target(original_node):
# if this is a simple attribute access (eg: self.x or report.x), don't transform it
if isinstance(updated_node.value, cst.Name) and updated_node.value.value in {
"self",
"report",
}:
return original_node
return updated_node

# if this is a nested attribute access (eg: data_params.user_id), don't transform it further
if (
isinstance(updated_node.value, cst.Attribute)
and isinstance(updated_node.value.value, cst.Name)
and updated_node.value.value.value in {"data_params", "config_params"}
):
return updated_node

# if this is a simple attribute access (eg: self.x or report.x), don't transform it
if isinstance(updated_node.value, cst.Name) and updated_node.value.value in {
"self",
"report",
}:
# check if this is part of a subscript target (eg: self.settings["timezone"])
parent = self.parent_provider.get(original_node)
if isinstance(parent, cst.Subscript):
return original_node
# check if this is part of a subscript value
if isinstance(parent, cst.SubscriptElement):
return original_node
return original_node

# if the attribute's value is a parameter name, update it to use the encapsulated parameter object
if (
isinstance(updated_node.value, cst.Name)
and updated_node.value.value in self.param_to_group
):
param_name = updated_node.value.value
return cst.Attribute(
value=cst.Name(self.param_to_group[param_name]), attr=updated_node.attr
)
return updated_node.with_changes(value=new_value)

return updated_node

# wrap CST node in a MetadataWrapper to enable metadata analysis
transformer = ParameterUsageTransformer(classified_params)
return function_node.visit(transformer)
# create transformer with metadata wrapper
transformer = ParameterUsageTransformer(classified_params, wrapper)
# transform the function body
updated_module = module.visit(transformer)
# return the transformed function
return updated_module.body[0]

@staticmethod
def get_enclosing_class_name(
tree: cst.Module, # noqa: ARG004
init_node: cst.FunctionDef,
parent_metadata: Mapping[cst.CSTNode, cst.CSTNode],
) -> Optional[str]:
"""
Finds the class name enclosing the given __init__ function node.
"""
# wrapper = MetadataWrapper(tree)
current_node = init_node
while current_node in parent_metadata:
parent = parent_metadata[current_node]
Expand All @@ -324,15 +466,7 @@ def update_function_calls(
classified_param_names: tuple[str, str],
enclosing_class_name: str,
) -> cst.Module:
"""
Updates all calls to a given function in the provided CST tree to reflect new encapsulated parameters
:param tree: CST tree of the code.
:param function_node: CST node of the function to update calls for.
:param params: A dictionary containing 'data' and 'config' parameters.
:return: The updated CST tree
"""
param_to_group = {}

for group_name, params in zip(classified_param_names, classified_params.values()):
for param in params:
param_to_group[param] = group_name
Expand All @@ -341,6 +475,15 @@ def update_function_calls(
if function_name == "__init__":
function_name = enclosing_class_name

# Get all parameter names from the function definition
all_param_names = [p.name.value for p in function_node.params.params]
# Find where variadic args start (if any)
variadic_start = len(all_param_names)
for i, param in enumerate(function_node.params.params):
if param.star == "*" or param.star == "**":
variadic_start = i
break

class FunctionCallTransformer(cst.CSTTransformer):
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: # noqa: ARG002
"""Transforms function calls to use grouped parameters."""
Expand All @@ -361,13 +504,27 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal

positional_args = []
keyword_args = {}

# Separate positional and keyword arguments
for arg in updated_node.args:
if arg.keyword is None:
positional_args.append(arg.value)
else:
keyword_args[arg.keyword.value] = arg.value
variadic_args = []
variadic_kwargs = {}

# Separate positional, keyword, and variadic arguments
for i, arg in enumerate(updated_node.args):
if isinstance(arg, cst.Arg):
if arg.keyword is None:
# If this is a positional argument beyond the number of parameters,
# it's a variadic arg
if i >= variadic_start:
variadic_args.append(arg.value)
elif i < len(used_params):
positional_args.append(arg.value)
else:
# If this is a keyword argument for a used parameter, keep it
if arg.keyword.value in param_to_group:
keyword_args[arg.keyword.value] = arg.value
# If this is a keyword argument not in the original parameters,
# it's a variadic kwarg
elif arg.keyword.value not in all_param_names:
variadic_kwargs[arg.keyword.value] = arg.value

# Group arguments based on classified_params
grouped_args = {group: [] for group in classified_param_names}
Expand Down Expand Up @@ -397,6 +554,94 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal
if grouped_args[group_name] # Skip empty groups
]

# Add variadic positional arguments
new_args.extend([cst.Arg(value=arg) for arg in variadic_args])

# Add variadic keyword arguments
new_args.extend(
[
cst.Arg(keyword=cst.Name(key), value=value)
for key, value in variadic_kwargs.items()
]
)

return updated_node.with_changes(args=new_args)

transformer = FunctionCallTransformer()
return tree.visit(transformer)

@staticmethod
def update_function_calls_unclassified(
tree: cst.Module,
function_node: cst.FunctionDef,
used_params: list[str],
enclosing_class_name: str,
) -> cst.Module:
"""
Updates all calls to a given function to only include used parameters.
This is used when parameters are removed without being classified into objects.

Args:
tree: CST tree of the code
function_node: CST node of the function to update calls for
used_params: List of parameter names that are actually used in the function
enclosing_class_name: Name of the enclosing class if this is a method

Returns:
Updated CST tree with modified function calls
"""
function_name = function_node.name.value
if function_name == "__init__":
function_name = enclosing_class_name

class FunctionCallTransformer(cst.CSTTransformer):
    """Rewrite calls to the target function so they pass only the used parameters.

    Relies on ``function_name``, ``function_node`` and ``used_params`` from the
    enclosing scope.
    """

    def leave_Call(
        self,
        original_node: cst.Call,  # noqa: ARG002
        updated_node: cst.Call,
    ) -> cst.Call:
        """Drop arguments that correspond to unused parameters.

        Args:
            original_node: The call node before child transformations (unused).
            updated_node: The call node after child transformations.

        Returns:
            ``updated_node`` unchanged when the callee is neither a plain name
            nor an attribute access, or when its name differs from
            ``function_name``; otherwise a copy whose ``args`` hold only the
            arguments kept for used parameters.
        """
        callee = updated_node.func
        # attribute calls (obj.method(...)) are matched on the attribute name,
        # plain calls (func(...)) on the name itself
        if isinstance(callee, cst.Attribute):
            callee_name = callee.attr.value
        elif isinstance(callee, cst.Name):
            callee_name = callee.value
        else:
            return updated_node

        if callee_name != function_name:
            return updated_node

        # indices, within the function definition's parameter list, of the
        # parameters that are still used
        # NOTE(review): these indices count every entry of
        # function_node.params.params, so a leading `self` occupies index 0,
        # while the counter below starts at 0 for the first call-site
        # argument - confirm this lines up for bound-method/constructor calls
        kept_indices = {
            index
            for index, param in enumerate(function_node.params.params)
            if param.name.value in used_params
        }

        kept_args = []
        positional_index = 0
        for arg in updated_node.args:
            if arg.keyword is not None:
                # keyword argument: keep it only when it names a used parameter
                if arg.keyword.value in used_params:
                    kept_args.append(arg)
                continue

            # positional argument: keep it when its ordinal is a used index;
            # the ordinal advances for every positional argument, kept or not
            if positional_index in kept_indices:
                kept_args.append(arg)
            positional_index += 1

        # reset the comma on the final kept argument so a comma inherited from
        # its old position is not emitted at the end of the call
        if kept_args:
            kept_args[-1] = kept_args[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)

        return updated_node.with_changes(args=kept_args)

transformer = FunctionCallTransformer()
Expand Down Expand Up @@ -499,7 +744,7 @@ def refactor(
self.is_constructor = self.function_node.name.value == "__init__"
if self.is_constructor:
self.enclosing_class_name = FunctionCallUpdater.get_enclosing_class_name(
tree, self.function_node, parent_metadata
self.function_node, parent_metadata
)
param_names = [
param.name.value
Expand Down Expand Up @@ -562,6 +807,11 @@ def refactor(
self.function_node, self.used_params, default_value_params
)

# update all calls to match the new signature
tree = self.function_updater.update_function_calls_unclassified(
tree, self.function_node, self.used_params, self.enclosing_class_name
)

class FunctionReplacer(cst.CSTTransformer):
def __init__(
self, original_function: cst.FunctionDef, updated_function: cst.FunctionDef
Expand Down
Loading
Loading