diff --git a/projects/mock_transformers/dist_infer_opt.py b/projects/mock_transformers/dist_infer_opt.py index 169b382f3..117e35db9 100644 --- a/projects/mock_transformers/dist_infer_opt.py +++ b/projects/mock_transformers/dist_infer_opt.py @@ -107,12 +107,18 @@ def __init__(self, *args, **kwargs): ) # generate id - for i in range(100): - with global_mode(True, **placement_sbp_dict): - model = init_env.compile_auto_placement( - model, - input_ids - ) - generated_ids = model.generate(input_ids, max_length=30) - out_put_ids = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - print(out_put_ids) + # for i in range(100): + + # generated_ids = model.generate(input_ids, max_length=30) + # raise KeyError + with global_mode(True, **placement_sbp_dict): + compiled_model = init_env.compile_auto_placement( + model, + input_ids=input_ids, + ) + # print(model.code) # use this to print the compiled module code + generated_ids = compiled_model.run(input_ids) + print(generated_ids) + # generated_ids = model(input_ids) + # out_put_ids = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + # print(generated_ids) diff --git a/projects/mock_transformers/init_env.py b/projects/mock_transformers/init_env.py index 0eb9b3f1f..c8f4ee8e8 100644 --- a/projects/mock_transformers/init_env.py +++ b/projects/mock_transformers/init_env.py @@ -20,6 +20,10 @@ import copy # noqa +import sys +sys.path.append("../") +sys.path.append("./") +sys.path.append("./libai/") import onefx as fx # noqa from typing import List, Dict, Any # noqa from oneflow import Tensor, nn # noqa @@ -212,10 +216,235 @@ def auto_set_pipeline_stage_id(model, pipeline_parallel_size=1): # ---------------def fx for auto changing placement ---------------------- +import inspect +import math +from typing import Tuple, Dict, Optional, Any, Callable, Union +from copy import deepcopy +import traceback +import builtins + +_customized_not_wrapped_oneflow_functions = [ + flow.ones_like, + flow.zeros_like, + flow.randn, + flow.randn_like, + flow.randint, flow.randint_like, + flow.device +] + +class CustomiziedTracer(fx.Tracer): + def __init__(self, autowrap_modules = (math, ), autowrap_functions: Tuple[Callable, ...] = (), param_shapes_constant: bool = False, + not_wrapped_oneflow_functions=_customized_not_wrapped_oneflow_functions, input_args=None) -> None: + super().__init__(autowrap_modules, autowrap_functions, param_shapes_constant, not_wrapped_oneflow_functions) + self.registered_values = {} + self.args_iter = iter(input_args) + + def to_bool(self, obj: fx.Proxy) -> bool: #override + if obj.node.name in self.registered_values: + return self.registered_values[obj.node.name] + return super().to_bool(obj) + + def create_proxy(self, kind: str, target, args: Tuple[Any, ...], kwargs: Dict[str, Any], + name: Optional[str] = None, type_expr : Optional[Any] = None, + proxy_factory_fn: Callable[[fx.Node], fx.Proxy] = None): #override + arg_values = [] + for i, arg in enumerate(args): + # if isinstance(arg, tuple) and isinstance(arg[0], fx.Proxy) and callable(target) + # and list(inspect.signature(target).parameters.keys())[0].startswith('*'): + # arg = arg[0] + if isinstance(arg, tuple): + has_proxy = [1 if isinstance(item, fx.Proxy) else 0 for item in arg] + has_proxy = sum(has_proxy) + current_arg_value = [] + if has_proxy > 0: + for proxy in arg: + if not isinstance(proxy, fx.Proxy): + current_arg_value.append(proxy) + continue + if not proxy.node.name in self.registered_values: + raise ValueError(f"{arg.node.name} cannot be found.") + else: + current_arg_value.append(self.registered_values[proxy.node.name]) + arg_values.append(tuple(current_arg_value)) + continue + if not isinstance(arg, fx.Proxy): + arg_values.append(arg) + continue + if not arg.node.name in self.registered_values: + raise ValueError(f"{arg.node.name} cannot be found.") + else: + arg_values.append(self.registered_values[arg.node.name]) + + kwarg_values = {} + for arg_name, arg in kwargs.items(): + if isinstance(arg, tuple): + has_proxy = [1 if isinstance(item, fx.Proxy) else 0 for item in arg] + has_proxy = sum(has_proxy) + current_arg_value = [] + if has_proxy > 0: + for proxy in arg: + if not isinstance(proxy, fx.Proxy): + current_arg_value.append(proxy) + continue + if not proxy.node.name in self.registered_values: + raise ValueError(f"{arg.node.name} cannot be found.") + else: + current_arg_value.append(self.registered_values[proxy.node.name]) + kwarg_values[arg_name] = tuple(current_arg_value) + continue + if not isinstance(arg, fx.Proxy): + kwarg_values[arg_name] = arg + continue + if not arg.node.name in self.registered_values: + raise ValueError(f"{arg.node.name} cannot be found.") + else: + kwarg_values[arg_name] = self.registered_values[arg.node.name] + + assert kind != "call_function" or callable(target) + + with fx.fx_no_wrap_context(self): + if kind == "call_function": + result_value = target(*arg_values, **kwarg_values) + elif kind == "call_method": + self_obj, *args_tail = arg_values + + # Execute the method and return the result + assert isinstance(target, str) + method = getattr(self_obj, target) + result_value = method(*args_tail, **kwarg_values) + elif kind == "call_module": + assert isinstance(target, str) + submod = self.fetch_attr(target) + + result_value = submod(*arg_values, **kwarg_values) + elif kind == "placeholder": + assert isinstance(target, str) + if target.startswith('*'): + # For a starred parameter e.g. `*args`, retrieve all + # remaining values from the args list. + result_value = list(self.args_iter) + else: + try: + result_value = next(self.args_iter) + except StopIteration as si: + raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si + elif kind == "get_attr": + assert isinstance(target, str) + result_value = self.fetch_attr(target) + elif kind == "output": + result_value = arg_values[0] + elif kind == "root": + raise NotImplementedError + else: + raise NotImplementedError + + if isinstance(result_value, fx.Proxy): + if result_value.node.name in self.registered_values: + result_value = self.registered_values[result_value.node.name] + else: + raise ValueError("Got a proxy object when running with original values.") + + if not self.fx_no_wrap: + result_proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) + self.registered_values[result_proxy.node.name] = result_value + return result_proxy + + def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]): + if self.fx_no_wrap: + return attr_val + def maybe_get_proxy_for_attr( + attr_val, collection_to_search, parameter_proxy_cache + ): + for n, p in collection_to_search: + if attr_val is p: + if n not in parameter_proxy_cache: + kwargs = {} + if ( + "proxy_factory_fn" + in inspect.signature(self.create_proxy).parameters + ): + kwargs["proxy_factory_fn"] = ( + None + if not self.param_shapes_constant + else lambda node: fx.ParameterProxy( + self, node, n, attr_val + ) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] + parameter_proxy_cache[n] = val_proxy + return parameter_proxy_cache[n] + return None + + if isinstance(attr_val, flow.nn.Parameter): + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) + if maybe_parameter_proxy is not None: + if not maybe_parameter_proxy.node.name in self.registered_values: + self.registered_values[maybe_parameter_proxy.node.name] = attr_val + return maybe_parameter_proxy + + if self.proxy_buffer_attributes and isinstance(attr_val, flow.Tensor): + maybe_buffer_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_buffers(), parameter_proxy_cache + ) + if maybe_buffer_proxy is not None: + if not maybe_buffer_proxy.node.name in self.registered_values: + self.registered_values[maybe_buffer_proxy.node.name] = attr_val + return maybe_buffer_proxy + + return attr_val + + def call_module( + self, + m: flow.nn.Module, + forward: Callable[..., Any], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> Any: # override + if self.fx_no_wrap: + return forward(*args, **kwargs) + else: + return super().call_module(m, forward, args, kwargs) + + def trace( + self, + root: Union[flow.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, + ) -> fx.Graph: # override + self.module = root + return super().trace(root, concrete_args) + + def fetch_attr(self, target : str): + target_atoms = target.split('.') + attr_itr = self.module + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") + attr_itr = getattr(attr_itr, atom) + if not isinstance(attr_itr, fx.Proxy): + return attr_itr + if attr_itr.node.name in self.registered_values: + return self.registered_values[attr_itr.node.name] + + raise ValueError(f"No attr <{target}> was found.") + + +def customized_symbolic_trace( + root: Union[flow.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, + input_args=None +) -> fx.GraphModule: + tracer = CustomiziedTracer(input_args=input_args) + graph = tracer.trace(root, concrete_args) + name = ( + root.__class__.__name__ if isinstance(root, flow.nn.Module) else root.__name__ + ) + return fx.GraphModule(tracer.root, graph, name) class AutoPlacementInterpreter(fx.Interpreter): - def __init__(self, mod : flow.nn.Module): - gm = fx.symbolic_trace(mod) + def __init__(self, mod : flow.nn.Module, concrete_args=None, input_args=None): + gm = customized_symbolic_trace(mod, concrete_args=concrete_args, input_args=input_args) super().__init__(gm) self.global_infos : Dict[int, Dict[int, Any]] = {} @@ -258,12 +487,11 @@ def run_node(self, n : fx.Node) -> Any: return return_val -def add_auto_placement(model: flow.nn.Module, global_info_dict: Dict[int, Dict[int, List[int]]]) -> flow.nn.Module: +def add_auto_placement(model: flow.nn.Module, global_info_dict: Dict[int, Dict[int, List[int]]], concrete_args=None, input_args=None) -> flow.nn.Module: model = copy.deepcopy(model) - fx_model: fx.GraphModule = fx.symbolic_trace(model) + fx_model: fx.GraphModule = customized_symbolic_trace(model, concrete_args=concrete_args, input_args=input_args) for node_id, node in enumerate(fx_model.graph.nodes): - print(node_id, " ", node.op) if not node_id in global_info_dict: continue @@ -277,14 +505,23 @@ def add_auto_placement(model: flow.nn.Module, global_info_dict: Dict[int, Dict[i fx_model.graph.lint() fx_model.recompile() - return fx_model - -def compile_auto_placement(model: flow.nn.Module, input_x: flow.Tensor): - assert input_x.is_global - interpret = AutoPlacementInterpreter(model) - interpret.run(input_x) - model = add_auto_placement(model, interpret.global_infos) - return model + return fx.Interpreter(fx_model) + +fx.wrap(len) +def compile_auto_placement(model: flow.nn.Module, concrete_args=None, **kwargs): + with fx.global_wrap([dist.get_nd_sbp, dist.same_sbp], dist): + with fx.global_wrap([flow.finfo], flow): + if concrete_args is None: + all_args = inspect.signature(model.forward).parameters + concrete_args = {} + for arg_name, param in all_args.items(): + if not arg_name in kwargs and param.default != inspect._empty: + concrete_args.update({arg_name:param.default}) + + interpret = AutoPlacementInterpreter(model, concrete_args=concrete_args, input_args=list(kwargs.values()) + list(concrete_args.values())) + interpret.run(*(kwargs.values())) + model = add_auto_placement(model, interpret.global_infos, concrete_args, input_args=list(kwargs.values()) + list(concrete_args.values())) + return model # b = flow.ones( # (2,2),