From ae6f7a9cff9219c894bdc4e7c694cdccbb285169 Mon Sep 17 00:00:00 2001 From: Sourcery AI <> Date: Mon, 13 Sep 2021 02:27:01 +0000 Subject: [PATCH] 'Refactored by Sourcery' --- pccm/core/__init__.py | 167 ++++++++++++++------------------ pccm/core/buildmeta.py | 12 +-- pccm/core/codegen.py | 4 +- pccm/graph/__init__.py | 21 ++-- pccm/middlewares/expose_main.py | 8 +- pccm/middlewares/pybind.py | 109 +++++++++------------ pccm/targets/cuda_ptx.py | 46 ++++----- pccm/test_data/mod.py | 3 +- 8 files changed, 155 insertions(+), 215 deletions(-) diff --git a/pccm/core/__init__.py b/pccm/core/__init__.py index a80260f..8ab7e4a 100644 --- a/pccm/core/__init__.py +++ b/pccm/core/__init__.py @@ -313,9 +313,7 @@ def __init__(self, self.scoped = scoped def to_string(self) -> str: - scoped_str = "" - if self.scoped: - scoped_str = "class" + scoped_str = "class" if self.scoped else "" prefix = "enum {} {} {{".format(scoped_str, self.name) if self.base_type: prefix = "enum {} {}: {} {{".format(scoped_str, self.name, @@ -362,13 +360,12 @@ def to_string(self) -> str: else: return "{}{} {} = {};".format(doc, self.type_str, self.name, self.default) + elif self.default is None: + return "{}{} {}{};".format(doc, self.type_str, self.name, + self.array) else: - if self.default is None: - return "{}{} {}{};".format(doc, self.type_str, self.name, - self.array) - else: - return "{}{} {}{} = {};".format(doc, self.type_str, self.name, - self.array, self.default) + return "{}{} {}{} = {};".format(doc, self.type_str, self.name, + self.array, self.default) class TemplateTypeArgument(object): @@ -535,18 +532,10 @@ def unpack(self, args: list) -> str: return ", ".join(map(str, args)) def _clean_pre_attrs_impl(self, attrs: List[str]): - res_attrs = [] # type: List[str] - for attr in attrs: - if attr not in _HEADER_ONLY_PRE_ATTRS: - res_attrs.append(attr) - return res_attrs + return [attr for attr in attrs if attr not in _HEADER_ONLY_PRE_ATTRS] def _clean_post_attrs_impl(self, attrs: List[str]): - res_attrs = [] # type: List[str] - for attr in attrs: - if attr not in _HEADER_ONLY_POST_ATTRS: - res_attrs.append(attr) - return res_attrs + return [attr for attr in attrs if attr not in _HEADER_ONLY_POST_ATTRS] def get_sig(self, name: str, @@ -565,9 +554,8 @@ def get_sig(self, return_type = self.return_type if isinstance(meta, (ConstructorMeta, DestructorMeta)): return_type = "" - else: - if not header_only: - assert self.return_type != "auto" and self.return_type != "decltype(auto)" + elif not header_only: + assert self.return_type not in ["auto", "decltype(auto)"] fmt = "{ret_type} {name}({args})" if withpost: fmt += "{post_attrs}" @@ -650,8 +638,7 @@ def get_impl(self, name: str, meta: FunctionMeta, class_name: str = ""): post_attrs=post_attrs_str) if pre_attrs_str: prefix_fmt = pre_attrs_str + " " + prefix_fmt - blocks = [] # List[Union[Block, str]] - blocks.extend(self._blocks) + blocks = list(self._blocks) block = Block(template_fmt + prefix_fmt, blocks, "}") if meta.macro_guard is not None: block = Block("#if {}".format(meta.macro_guard), [block], "#endif") @@ -972,15 +959,18 @@ def _assign_overload_flag_to_func_decls(self): member_decl_count[cpp_func_name] += 1 for decl in self._function_decls: cpp_func_name = decl.get_function_name() - if isinstance(decl.meta, ExternalFunctionMeta): - if extend_decl_count[cpp_func_name] > 1: - decl.is_overload = True - elif isinstance(decl.meta, StaticMemberFunctionMeta): - if static_member_decl_count[cpp_func_name] > 1: - decl.is_overload = True - elif isinstance(decl.meta, MemberFunctionMeta): - if member_decl_count[cpp_func_name] > 1: - decl.is_overload = True + if ( + isinstance(decl.meta, ExternalFunctionMeta) + and extend_decl_count[cpp_func_name] > 1 + or not isinstance(decl.meta, ExternalFunctionMeta) + and isinstance(decl.meta, StaticMemberFunctionMeta) + and static_member_decl_count[cpp_func_name] > 1 + or not isinstance(decl.meta, ExternalFunctionMeta) + and not isinstance(decl.meta, StaticMemberFunctionMeta) + and isinstance(decl.meta, MemberFunctionMeta) + and member_decl_count[cpp_func_name] > 1 + ): + decl.is_overload = True def add_dependency(self, *no_param_class_cls: Type["Class"]): # TODO enable name alias for Class @@ -1134,7 +1124,7 @@ def get_includes_with_dep(self) -> List[str]: for d in self.get_common_deps()) return res - def get_parent_class(self): # -> Optional[Type["Class"]] + def get_parent_class(self): # -> Optional[Type["Class"]] """TODO find a better way to check invalid param class inherit """ if type(self) is Class: @@ -1145,24 +1135,23 @@ def get_parent_class(self): # -> Optional[Type["Class"]] base = candidates.pop() if issubclass(base, Class): cls_meta = get_class_meta(base) - if cls_meta is None: - pccm_base_types.append(base) + if cls_meta is not None and cls_meta.skip_inherit: + candidates.extend(base.__bases__) else: - if cls_meta.skip_inherit: - candidates.extend(base.__bases__) - else: - pccm_base_types.append(base) + pccm_base_types.append(base) assert len(pccm_base_types) == 1, "you can only inherit one class." pccm_base = pccm_base_types[0] - if pccm_base is not Class and base is not ParameterizedClass: - # assert not issubclass(mro[1], ParameterizedClass), "you can't inherit a param class." - if not issubclass(pccm_base, ParameterizedClass): - # you inherit a class. you must set _this_cls_type by self.set_this_class_type(__class__) - msg = ( - "you must use self.set_this_class_type(__class__) to init this class type" - " when you inherit pccm.Class") - assert self._this_cls_type is not None, msg - return pccm_base + if ( + pccm_base is not Class + and base is not ParameterizedClass + and not issubclass(pccm_base, ParameterizedClass) + ): + # you inherit a class. you must set _this_cls_type by self.set_this_class_type(__class__) + msg = ( + "you must use self.set_this_class_type(__class__) to init this class type" + " when you inherit pccm.Class") + assert self._this_cls_type is not None, msg + return pccm_base return None def get_class_deps(self) -> List[Type["Class"]]: @@ -1232,16 +1221,12 @@ def get_code_class_def( d.to_string() for d in self._members if d.cls_type is self._this_cls_type ] - parent_class_alias = None # type: Optional[str] parent = self.get_parent_class() - if parent is not None: - # TODO better way to get alias name - parent_class_alias = parent.__name__ - cdef = CodeSectionClassDef(cu_name, dep_alias, self._code_before_class, + parent_class_alias = parent.__name__ if parent is not None else None + return CodeSectionClassDef(cu_name, dep_alias, self._code_before_class, self._code_after_class, ext_decls, ec_strs, typedef_strs, sc_strs, member_func_decls, member_def_strs, parent_class_alias) - return cdef def get_common_deps(self) -> List["Class"]: assert self.graph_inited, "you must build dependency graph before generate code" @@ -1304,13 +1289,12 @@ def generate_namespace(self, namespace: str): if namespace == "": return [], [] namespace_parts = namespace.split(".") - namespace_before = [] # type: List[str] - namespace_after = [] # type: List[str] + namespace_before = ["namespace {} {{".format(p) for p in namespace_parts] + namespace_after = [ + "}} // namespace {}".format(p) for p in namespace_parts[::-1] + ] + - for p in namespace_parts: - namespace_before.append("namespace {} {{".format(p)) - for p in namespace_parts[::-1]: - namespace_after.append("}} // namespace {}".format(p)) return namespace_before, namespace_after @@ -1394,8 +1378,7 @@ def to_block(self) -> Block: prefix = code_before_cls + [ "struct {class_name} {{".format(class_name=self.class_name) ] - block = Block("\n".join(prefix), class_contents, "};") - return block + return Block("\n".join(prefix), class_contents, "};") class CodeSectionImpl(CodeSection): @@ -1438,9 +1421,9 @@ def extract_module_id_of_class( relative_path = path.relative_to(Path(root)) import_parts = list(relative_path.parts) import_parts[-1] = relative_path.stem + elif loader.locate_top_package(path) is None: + return None else: - if loader.locate_top_package(path) is None: - return None import_parts = loader.try_capture_import_parts(path, None) return ".".join(import_parts) @@ -1533,35 +1516,35 @@ def _apply_middleware_to_cus(self, uid_to_cu: Dict[str, Class]): new_uid_to_cu = OrderedDict() # type: Dict[str, Class] for middleware in self.middlewares: mw_type = type(middleware) - if isinstance(middleware, ManualClassGenerator): - for k, cu in uid_to_cu.items(): - decls_with_meta = [ - ] # type: List[Tuple[FunctionDecl, MiddlewareMeta]] - members_with_meta = [ - ] # type: List[Tuple[Member, MiddlewareMeta]] - # TODO only one meta is allowed - for decl in cu._function_decls: - for mw_meta in decl.meta.mw_metas: - if mw_meta.type is mw_type: - decls_with_meta.append((decl, mw_meta)) - for member in cu._members: - for mw_meta in member.mw_metas: - if mw_meta.type is mw_type: - members_with_meta.append((member, mw_meta)) - if not decls_with_meta and not members_with_meta: - continue - new_pcls = middleware.create_manual_class(cu) - if new_pcls.namespace is None: - new_pcls.namespace = cu.namespace + "." + middleware.subnamespace - for decl, mw_meta in decls_with_meta: - new_pcls.handle_function_decl(cu, decl, mw_meta) - for member, mw_meta in members_with_meta: - new_pcls.handle_member(cu, member, mw_meta) - uid = new_pcls.namespace + "-" + type(new_pcls).__name__ - new_uid_to_cu[uid] = new_pcls - else: + if not isinstance(middleware, ManualClassGenerator): raise NotImplementedError + for k, cu in uid_to_cu.items(): + decls_with_meta = [ + ] # type: List[Tuple[FunctionDecl, MiddlewareMeta]] + members_with_meta = [ + ] # type: List[Tuple[Member, MiddlewareMeta]] + # TODO only one meta is allowed + for decl in cu._function_decls: + for mw_meta in decl.meta.mw_metas: + if mw_meta.type is mw_type: + decls_with_meta.append((decl, mw_meta)) + for member in cu._members: + for mw_meta in member.mw_metas: + if mw_meta.type is mw_type: + members_with_meta.append((member, mw_meta)) + if not decls_with_meta and not members_with_meta: + continue + new_pcls = middleware.create_manual_class(cu) + if new_pcls.namespace is None: + new_pcls.namespace = cu.namespace + "." + middleware.subnamespace + for decl, mw_meta in decls_with_meta: + new_pcls.handle_function_decl(cu, decl, mw_meta) + for member, mw_meta in members_with_meta: + new_pcls.handle_member(cu, member, mw_meta) + uid = new_pcls.namespace + "-" + type(new_pcls).__name__ + new_uid_to_cu[uid] = new_pcls + def build_graph(self, cus: List[Union[Class, ParameterizedClass]], root: Optional[Union[str, Path]] = None, diff --git a/pccm/core/buildmeta.py b/pccm/core/buildmeta.py index 86e0fe9..ab6435f 100644 --- a/pccm/core/buildmeta.py +++ b/pccm/core/buildmeta.py @@ -12,11 +12,10 @@ def _unique_list_keep_order(seq: list): # https://www.peterbe.com/plog/fastest-way-to-uniquify-a-list-in-python-3.6 # only python 3.7 language std ensure the preserve-order dict return list(dict.fromkeys(seq)) - else: - # https://stackoverflow.com/questions/480214/how-do-you-remove-duplicates-from-a-list-whilst-preserving-order - seen = set() - seen_add = seen.add - return [x for x in seq if not (x in seen or seen_add(x))] + # https://stackoverflow.com/questions/480214/how-do-you-remove-duplicates-from-a-list-whilst-preserving-order + seen = set() + seen_add = seen.add + return [x for x in seq if not (x in seen or seen_add(x))] def _merge_compiler_to_flags(this: Dict[str, List[str]], @@ -66,11 +65,10 @@ def __add__(self, other: "BuildMeta"): merged_ldflags = _merge_compiler_to_flags(self.compiler_to_ldflags, other.compiler_to_ldflags) - res = BuildMeta( + return BuildMeta( self.includes + other.includes, self.libpaths + other.libpaths, _unique_list_keep_order(self.libraries + other.libraries), merged_cflags, merged_ldflags) - return res def __radd__(self, other: "BuildMeta"): return other.__add__(self) diff --git a/pccm/core/codegen.py b/pccm/core/codegen.py index 4728957..5278e32 100644 --- a/pccm/core/codegen.py +++ b/pccm/core/codegen.py @@ -33,9 +33,7 @@ def generate_code(block: Union[Block, str], start_col_offset: int, return [col_str + l for l in block_lines] res = [] # type: List[str] prefix = block.prefix - next_indent = indent - if block.indent is not None: - next_indent = block.indent + next_indent = block.indent if block.indent is not None else indent if prefix: prefix_lines = prefix.split("\n") res.extend([col_str + l for l in prefix_lines]) diff --git a/pccm/graph/__init__.py b/pccm/graph/__init__.py index 12bd9e5..a858da1 100644 --- a/pccm/graph/__init__.py +++ b/pccm/graph/__init__.py @@ -83,7 +83,7 @@ def postorder_traversal(node: Node, node_map: Dict[str, Node]): if namedio.key not in node_map: continue inp = node_map[namedio.key] - if not inp in visited: + if inp not in visited: next_nodes.append(inp) ready = False if ready: @@ -112,11 +112,10 @@ def _cycle_detection(node_map: Dict[str, Node], node: Node, visited: Set[str], def cycle_detection(node_map: Dict[str, Node]): visited = set() trace = set() - for node in node_map.values(): - if node.key not in visited: - if _cycle_detection(node_map, visited, trace): - return True - return False + return any( + node.key not in visited and _cycle_detection(node_map, visited, trace) + for node in node_map.values() + ) class Graph(object): @@ -177,10 +176,7 @@ def is_source_of(self, lfs: Node, rfs: Node): def get_sources_of(self, node: Node): assert not self._has_cycle, "graph must be DAG" all_sources = set() - stack = [] - for inp in node.inputs: - if inp.key in self: - stack.append(self[inp.key]) + stack = [self[inp.key] for inp in node.inputs if inp.key in self] while stack: n = stack.pop() if n in all_sources: @@ -214,10 +210,7 @@ def is_branch_node(self, node: Node): for io in n.inputs: if io.key in self: stack.append(self[io.key]) - for s in all_sources: - if s.key in visited: - return True - return False + return any(s.key in visited for s in all_sources) def create_node(key: str, *inputs: List[NamedIO]): diff --git a/pccm/middlewares/expose_main.py b/pccm/middlewares/expose_main.py index c1c2b6f..e67be86 100644 --- a/pccm/middlewares/expose_main.py +++ b/pccm/middlewares/expose_main.py @@ -63,12 +63,12 @@ def create_manual_class(self, cu: Class) -> ManualClass: return self.singleton def get_code_units(self) -> List[Class]: - if self.singleton.main_cu is not None: - self.singleton.postprocess() - return [self.singleton] - else: + if self.singleton.main_cu is None: return [] + self.singleton.postprocess() + return [self.singleton] + def mark(func=None): meta = ExposeMainMeta() diff --git a/pccm/middlewares/pybind.py b/pccm/middlewares/pybind.py index dddab7f..e17858e 100644 --- a/pccm/middlewares/pybind.py +++ b/pccm/middlewares/pybind.py @@ -188,11 +188,11 @@ def to_pyanno(self) -> str: return "Any" if self.exist_anno: return self.exist_anno - if self.name in self.NameToHandler: - pyanno_generic = self.NameToHandler[self.name](self.args) - else: - pyanno_generic = "Any" - return pyanno_generic + return ( + self.NameToHandler[self.name](self.args) + if self.name in self.NameToHandler + else "Any" + ) def _simple_template_type_parser_recursive( @@ -239,9 +239,8 @@ def _simple_template_type_parser( invalid = TemplateTypeStmt("", [], False, True) bracket_stack = [] # type: List[Tuple[str, int]] N = len(stmt) - pos = 0 bracket_pair = {} # type: Dict[int, int] - while pos < N: + for pos in range(N): val = stmt[pos] if val == "<": bracket_stack.append((val, pos)) @@ -250,7 +249,6 @@ def _simple_template_type_parser( return invalid start_val, start = bracket_stack.pop() bracket_pair[start] = pos - pos += 1 if bracket_stack: return invalid if "\"" in stmt: @@ -375,12 +373,11 @@ def __init__(self, decl: FunctionDecl, namespace: str, class_name: str, def get_overload_addr(self): addr = self.addr - arg_types = ", ".join([a.type_str for a in self.decl.code.arguments]) + arg_types = ", ".join(a.type_str for a in self.decl.code.arguments) meta = self.decl.meta addr_fmt = "pybind11::overload_cast<{}>({})" - if isinstance(meta, MemberFunctionMeta): - if meta.const: - addr_fmt = "pybind11::overload_cast<{}>({}, pybind11::const_)" + if isinstance(meta, MemberFunctionMeta) and meta.const: + addr_fmt = "pybind11::overload_cast<{}>({}, pybind11::const_)" addr = addr_fmt.format(arg_types, self.addr) return addr @@ -425,9 +422,7 @@ def get_virtual_string(self, parent_cls_name: str): if self.decl.meta.pure_virtual: fmt = "{} {} {{PYBIND11_OVERRIDE_PURE({}, {}, {}, {});}}" post_meta_attrs = self.decl.meta.get_post_attrs() - override = "" - if "override" not in post_meta_attrs: - override = "override" + override = "override" if "override" not in post_meta_attrs else "" sig_str = self.decl.code.get_sig(self.decl.get_function_name(), self.decl.meta, withpost=True, @@ -453,9 +448,7 @@ def get_prop_name(self) -> str: return self.decl.name def to_string(self) -> str: - def_stmt = "def_readwrite" - if not self.mw_meta.readwrite: - def_stmt = "def_readonly" + def_stmt = "def_readonly" if not self.mw_meta.readwrite else "def_readwrite" return ".{}(\"{}\", {})".format(def_stmt, self.mw_meta.name, self.addr) @@ -545,11 +538,10 @@ def _postprocess_class(cls_name: str, cls_namespace: str, submod: str, setter_prop_name.add(prop_name) getter_decl.setter_pybind_decl = decl - has_constructor = False - for d in method_decls: - if isinstance(d.decl.meta, ConstructorMeta): - has_constructor = True - break + has_constructor = any( + isinstance(d.decl.meta, ConstructorMeta) for d in method_decls + ) + cls_qual_name = "{}::{}".format(cls_namespace.replace(".", "::"), cls_name) cls_def_arguments = [cls_qual_name] if has_virtual: @@ -581,14 +573,12 @@ def _postprocess_class(cls_name: str, cls_namespace: str, submod: str, ec_prefix = "pybind11::enum_<{}::{}>({}, \"{}\", pybind11::arithmetic())".format( cls_name, ec.name, cls_def_name, ec.name) ec_values = [] # type: List[Union[Block, str]] - cnt = 0 - for key, value in ec.items: + for cnt, (key, value) in enumerate(ec.items): stmt = ".value(\"{key}\", {class_name}::{enum_name}::{key})".format( key=key, class_name=cls_name, enum_name=ec.name) if is_scoped and cnt == len(ec.items) - 1: stmt += ";" ec_values.append(stmt) - cnt += 1 if not is_scoped: ec_values.append(".export_values();") cls_def_stmts.append(Block(ec_prefix, ec_values, "")) @@ -609,11 +599,13 @@ def _extract_anno_default( type_str, exist_annos) try_extract_pyanno = try_extract_pyanno_res.to_pyanno() if try_extract_pyanno != "Any": - if default is None: - if cpp_default is not None: - if type_str in _AUTO_ANNO_TYPES_DEFAULT_HANDLER: - handler = _AUTO_ANNO_TYPES_DEFAULT_HANDLER[type_str] - default = handler(cpp_default) + if ( + default is None + and cpp_default is not None + and type_str in _AUTO_ANNO_TYPES_DEFAULT_HANDLER + ): + handler = _AUTO_ANNO_TYPES_DEFAULT_HANDLER[type_str] + default = handler(cpp_default) try_extract_pyanno, from_imports = python_anno_parser( try_extract_pyanno) return try_extract_pyanno, from_imports, default @@ -625,11 +617,13 @@ def _extract_anno_default( default = user_anno_type_default[1] anno, from_imports = python_anno_parser(user_anno_type) - if default is None: - if cpp_default is not None: - if type_str in _AUTO_ANNO_TYPES_DEFAULT_HANDLER: - handler = _AUTO_ANNO_TYPES_DEFAULT_HANDLER[type_str] - default = handler(cpp_default) + if ( + default is None + and cpp_default is not None + and type_str in _AUTO_ANNO_TYPES_DEFAULT_HANDLER + ): + handler = _AUTO_ANNO_TYPES_DEFAULT_HANDLER[type_str] + default = handler(cpp_default) return anno, from_imports, default @@ -648,19 +642,19 @@ def _collect_exist_annos(decls: List[Union[PybindMethodDecl, PybindPropDecl]]): prop_decls.append(decl) exist_annos = {} # type: Dict[str, str] for prop_decl in prop_decls: - prop_type = prop_decl.decl.type_str user_anno = prop_decl.decl.pyanno if user_anno is not None: user_anno_pair = user_anno.split("=") user_anno_type = user_anno_pair[0].strip() + prop_type = prop_decl.decl.type_str exist_annos[prop_type] = user_anno_type for pydecl in method_decls: user_anno = pydecl.decl.code.ret_pyanno - ret_type = pydecl.decl.code.return_type if user_anno is not None: user_anno_pair = user_anno.split("=") user_anno_type = user_anno_pair[0].strip() + ret_type = pydecl.decl.code.return_type exist_annos[ret_type] = user_anno_type for arg in pydecl.decl.code.arguments: @@ -701,13 +695,11 @@ class xxx: ) # type: Dict[str, List[PybindMethodDecl]] decl_codes = [] # type: List[Union[Block, str]] for prop_decl in prop_decls: - prop_anno = "Any" prop_type = prop_decl.decl.type_str user_anno = prop_decl.decl.pyanno anno, from_imports, default = _extract_anno_default( user_anno, prop_type, exist_annos) - if anno is not None: - prop_anno = anno + prop_anno = anno if anno is not None else "Any" imports.extend(from_imports) if prop_anno == cls_name: prop_anno = "\"{}\"".format(prop_anno) @@ -757,11 +749,10 @@ class xxx: if default is not None: default_str = " = {}".format(default) have_default = True - else: - if have_default: - msg = ("you must provide a python default anno value " - "for {} of {}. format: PythonType = Default") - raise ValueError(msg.format(arg.name, decl_bind_name)) + elif have_default: + msg = ("you must provide a python default anno value " + "for {} of {}. format: PythonType = Default") + raise ValueError(msg.format(arg.name, decl_bind_name)) if user_anno is not None: if user_anno == cls_name: user_anno = "\"{}\"".format(user_anno) @@ -789,15 +780,10 @@ class xxx: fmt.format(decorator, decl_bind_name, py_sig, res_anno, doc)) # Class EnumName: for ec in enum_classes: - ec_items = [] # type: List[Union[Block, str]] - enum_type = "EnumValue" - if ec.scoped: - enum_type = "EnumClassValue" - + enum_type = "EnumClassValue" if ec.scoped else "EnumValue" prefix = "class {}:".format(ec.name) - for key, value in ec.items: - ec_items.append("{k} = {ectype}({v}) # type: {ectype}".format( - k=key, v=value, ectype=enum_type)) + ec_items = ["{k} = {ectype}({v}) # type: {ectype}".format( + k=key, v=value, ectype=enum_type) for key, value in ec.items] def_items = ec_items.copy() def_items.append("@staticmethod") def_items.append( @@ -830,8 +816,7 @@ def __init__(self, cu: Class, file_suffix: str = ".cc"): def get_pybind_decls( self) -> List[Union[PybindMethodDecl, PybindPropDecl]]: - res = [] # type: List[Union[PybindMethodDecl, PybindPropDecl]] - res.extend(self.func_decls) + res = list(self.func_decls) res.extend(self.prop_decls) return res @@ -900,10 +885,7 @@ def postprocess(self, bind_cus: List[Pybind11SingleClassHandler]): for i in range(1, len(ns_parts) + 1): sub_name = "_".join(ns_parts[:i]) sub_ns = ".".join(ns_parts[:i]) - if i == 1: - parent_name = "m" - else: - parent_name = "m_{}".format("_".join(ns_parts[:i - 1])) + parent_name = "m" if i == 1 else "m_{}".format("_".join(ns_parts[:i - 1])) if sub_ns not in submodules: stmt = "pybind11::module_ m_{} = {}.def_submodule(\"{}\");".format( sub_name, parent_name, ns_parts[i - 1]) @@ -1008,8 +990,7 @@ def get_code_units(self) -> List[Class]: for bind in self.bind_cus: bind.postprocess() self.main_cu.postprocess(self.bind_cus) - res = [] # type: List[Class] - res.extend(self.bind_cus) + res = list(self.bind_cus) res.append(self.main_cu) return res @@ -1027,9 +1008,7 @@ def mark(func=None, keep_alives: Optional[List[Tuple[int, int]]] = None): if virtual: assert not nogil, "you can't release gil for python virtual function." - call_guard = None # type: Optional[str] - if nogil: - call_guard = "pybind11::gil_scoped_release" + call_guard = "pybind11::gil_scoped_release" if nogil else None pybind_meta = Pybind11MethodMeta(bind_name, method_type, prop_name, ret_policy, call_guard, virtual, keep_alives) diff --git a/pccm/targets/cuda_ptx.py b/pccm/targets/cuda_ptx.py index a3ff47f..b20a6a3 100644 --- a/pccm/targets/cuda_ptx.py +++ b/pccm/targets/cuda_ptx.py @@ -274,10 +274,7 @@ def get_stmt_str(self) -> str: op_strs: List[str] = [] name_to_op_str: Dict[str, str] = {} for op in self.params: - if not isinstance(op, list): - ops = [op] - else: - ops = op + ops = [op] if not isinstance(op, list) else op for op in ops: if isinstance(op, Pointer): # addr @@ -329,11 +326,10 @@ def get_stmt_str(self) -> str: # else: op_str_with_packed = "{{{}}}".format( ", ".join(op_strs_regs)) + elif not isinstance(op, (str, int)): + op_str_with_packed = name_to_op_str[op.name] else: - if not isinstance(op, (str, int)): - op_str_with_packed = name_to_op_str[op.name] - else: - op_str_with_packed = str(op) + op_str_with_packed = str(op) op_strs.append(op_str_with_packed) if self.pred_reg: return "@{} {} {};".format(self.pred_reg, self.stmt, @@ -435,11 +431,10 @@ def _ld_or_st(self, is_ld: bool, addr: Pointer, cache_op: Optional[Union[CacheOpLd, CacheOpSt]]): if not addr.is_input: addr.is_input = is_ld - if is_ld: - if cache_op is not None: + if cache_op is not None: + if is_ld: assert isinstance(cache_op, CacheOpLd) - else: - if cache_op is not None: + else: assert isinstance(cache_op, CacheOpSt) if not isinstance(regs, list): @@ -497,10 +492,7 @@ def st(self, def mov(self, dst: REG_OPERAND_TYPES, srcs: Union[List[REG_INPUT_OPERAND_TYPES], REG_INPUT_OPERAND_TYPES]): - if not isinstance(srcs, list): - srcs_check = [srcs] - else: - srcs_check = srcs + srcs_check = [srcs] if not isinstance(srcs, list) else srcs for src in srcs_check: if not isinstance(src, (str, int)): src.is_input = True @@ -568,15 +560,11 @@ def _get_str(self) -> str: asm_input_strs: List[str] = [] # TODO check all ops with same name, they must have same type. for op in all_outputs: - op_has_input = False - for op_same_name in name_to_operands[op.name]: - if op_same_name.is_input: - op_has_input = True - break - if op_has_input: - constraint_letter = "+" - else: - constraint_letter = "=" + op_has_input = any( + op_same_name.is_input for op_same_name in name_to_operands[op.name] + ) + + constraint_letter = "+" if op_has_input else "=" asm_output_strs.append("\"{}{}\"({})".format( constraint_letter, op.get_inline_asm_dtype(), op.name)) for op in all_inputs: @@ -586,9 +574,11 @@ def _get_str(self) -> str: inputs_str = ", ".join(asm_input_strs) stmt_prefix = " " asm_stmts_str = "\n".join( - [stmt_prefix + "\" {}\\n\"".format(s) for s in stmt_strs]) - if len(asm_output_strs) == 0: - if len(asm_input_strs) == 0 and len(stmt_strs) == 0: + stmt_prefix + "\" {}\\n\"".format(s) for s in stmt_strs + ) + + if not asm_output_strs: + if not asm_input_strs and not stmt_strs: return "" return """ asm volatile ( diff --git a/pccm/test_data/mod.py b/pccm/test_data/mod.py index f840033..8153bd4 100644 --- a/pccm/test_data/mod.py +++ b/pccm/test_data/mod.py @@ -103,8 +103,7 @@ def func_2(self): @pccm.pybind.mark @pccm.member_function def run_virtual_func_0(self): - code = pccm.FunctionCode("return func_0();").ret("int") - return code + return pccm.FunctionCode("return func_0();").ret("int") @pccm.pybind.mark @pccm.member_function