diff --git a/pyproject.toml b/pyproject.toml index 5064f715299..1baf708022b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,8 @@ files = [ "src/nncf/api", "src/nncf/data", "src/nncf/common", + "src/nncf/torch/hardware/fused_patterns.py", + "src/nncf/torch/quantization/ignored_patterns.py", "src/nncf/torch/function_hook", "src/nncf/torch/engine.py", "src/nncf/torch/functions.py", @@ -94,6 +96,10 @@ files = [ "src/nncf/experimental/torch/gptqmodel/", "src/nncf/quantization/algorithms/weight_compression/config.py", "src/nncf/quantization/algorithms/weight_compression/constants.py", + "src/nncf/onnx/hardware/fused_patterns.py", + "src/nncf/onnx/quantization/ignored_patterns.py", + "src/nncf/openvino/hardware/fused_patterns.py", + "src/nncf/openvino/quantization/ignored_patterns.py", "examples/llm_compression/torch/gptq_model_convertor/main.py", ] disable_error_code = ["import-untyped"] diff --git a/src/nncf/common/graph/operator_metatypes.py b/src/nncf/common/graph/operator_metatypes.py index 073cbd833c5..ed6bac12777 100644 --- a/src/nncf/common/graph/operator_metatypes.py +++ b/src/nncf/common/graph/operator_metatypes.py @@ -65,10 +65,10 @@ def subtype_check(cls, metatype: type["OperatorMetatype"]) -> bool: return any(subtype.subtype_check(metatype) for subtype in subtypes) -TOpClass = TypeVar("TOpClass", bound=type[OperatorMetatype]) +TRegisterObject = TypeVar("TRegisterObject", bound=type[OperatorMetatype]) -class OperatorMetatypeRegistry(Registry): +class OperatorMetatypeRegistry(Registry[str, type[OperatorMetatype]]): """ Operator Metatypes Registry. """ @@ -82,18 +82,19 @@ def __init__(self, name: str): super().__init__(name) self._op_name_to_op_meta_dict: dict[str, type[OperatorMetatype]] = {} - def register(self, name: str | None = None, is_subtype: bool = False) -> Callable[[TOpClass], TOpClass]: + def register( # type: ignore[override] + self, name: str | None = None + ) -> Callable[[TRegisterObject], TRegisterObject]: """ Decorator for registering operator metatypes. :param name: The registration name. - :param is_subtype: Whether the decorated metatype is a subtype of another registered operator. :return: The inner function for registering operator metatypes. """ name_ = name super_register = super()._register - def wrap(obj: TOpClass) -> TOpClass: + def wrap(obj: TRegisterObject) -> TRegisterObject: """ Inner function for registering operator metatypes. @@ -104,16 +105,16 @@ def wrap(obj: TOpClass) -> TOpClass: if cls_name is None: cls_name = obj.__name__ super_register(obj, cls_name) - if not is_subtype: - op_names = obj.get_all_aliases() - for name in op_names: - if name in self._op_name_to_op_meta_dict: - msg = ( - "Inconsistent operator metatype registry - single patched " - f"op name `{name}` maps to multiple metatypes!" - ) - raise nncf.InternalError(msg) - self._op_name_to_op_meta_dict[name] = obj + op_names = obj.get_all_aliases() + for name in op_names: + if name in self._op_name_to_op_meta_dict and not obj.subtype_check(self._op_name_to_op_meta_dict[name]): + msg = ( + "Inconsistent operator metatype registry - single patched " + f"op name `{name}` maps to multiple metatypes!" + ) + raise nncf.InternalError(msg) + + self._op_name_to_op_meta_dict[name] = obj return obj return wrap @@ -130,10 +131,10 @@ def get_operator_metatype_by_op_name(self, op_name: str) -> type[OperatorMetatyp return self._op_name_to_op_meta_dict[op_name] -NOOP_METATYPES = Registry("noop_metatypes") -INPUT_NOOP_METATYPES = Registry("input_noop_metatypes") -OUTPUT_NOOP_METATYPES = Registry("output_noop_metatypes") -CONST_NOOP_METATYPES = Registry("const_noop_metatypes") +NOOP_METATYPES = Registry[str, type[OperatorMetatype]]("noop_metatypes") +INPUT_NOOP_METATYPES = Registry[str, type[OperatorMetatype]]("input_noop_metatypes") +OUTPUT_NOOP_METATYPES = Registry[str, type[OperatorMetatype]]("output_noop_metatypes") +CONST_NOOP_METATYPES = Registry[str, type[OperatorMetatype]]("const_noop_metatypes") class UnknownMetatype(OperatorMetatype): diff --git a/src/nncf/common/graph/patterns/manager.py b/src/nncf/common/graph/patterns/manager.py index d2657e7941e..e6df1379130 100644 --- a/src/nncf/common/graph/patterns/manager.py +++ b/src/nncf/common/graph/patterns/manager.py @@ -34,24 +34,18 @@ def _get_backend_hw_patterns_map(backend: BackendType) -> dict[HWFusedPatternNam :param backend: BackendType instance. :return: Dictionary with the HWFusedPatternNames instance as keys and creator function as a value. """ - registry: dict[HWFusedPatternNames, Callable[[], GraphPattern]] = {} if backend == BackendType.ONNX: from nncf.onnx.hardware.fused_patterns import ONNX_HW_FUSED_PATTERNS - registry = cast(dict[HWFusedPatternNames, Callable[[], GraphPattern]], ONNX_HW_FUSED_PATTERNS.registry_dict) - return registry + return ONNX_HW_FUSED_PATTERNS.registry_dict if backend == BackendType.OPENVINO: from nncf.openvino.hardware.fused_patterns import OPENVINO_HW_FUSED_PATTERNS - registry = cast( - dict[HWFusedPatternNames, Callable[[], GraphPattern]], OPENVINO_HW_FUSED_PATTERNS.registry_dict - ) - return registry + return OPENVINO_HW_FUSED_PATTERNS.registry_dict if backend in (BackendType.TORCH, BackendType.TORCH_FX): from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS - registry = cast(dict[HWFusedPatternNames, Callable[[], GraphPattern]], PT_HW_FUSED_PATTERNS.registry_dict) - return registry + return PT_HW_FUSED_PATTERNS.registry_dict msg = f"Hardware-fused patterns not implemented for {backend} backend." raise ValueError(msg) @@ -65,24 +59,18 @@ def _get_backend_ignored_patterns_map( :param backend: BackendType instance. :return: Dictionary with the HWFusedPatternNames instance as keys and creator function as a value. """ - registry: dict[IgnoredPatternNames, Callable[[], GraphPattern]] = {} if backend == BackendType.ONNX: from nncf.onnx.quantization.ignored_patterns import ONNX_IGNORED_PATTERNS - registry = cast(dict[IgnoredPatternNames, Callable[[], GraphPattern]], ONNX_IGNORED_PATTERNS.registry_dict) - return registry + return ONNX_IGNORED_PATTERNS.registry_dict if backend == BackendType.OPENVINO: from nncf.openvino.quantization.ignored_patterns import OPENVINO_IGNORED_PATTERNS - registry = cast( - dict[IgnoredPatternNames, Callable[[], GraphPattern]], OPENVINO_IGNORED_PATTERNS.registry_dict - ) - return registry + return OPENVINO_IGNORED_PATTERNS.registry_dict if backend in (BackendType.TORCH, BackendType.TORCH_FX): from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS - registry = cast(dict[IgnoredPatternNames, Callable[[], GraphPattern]], PT_IGNORED_PATTERNS.registry_dict) - return registry + return PT_IGNORED_PATTERNS.registry_dict msg = f"Ignored patterns not implemented for {backend} backend." raise ValueError(msg) diff --git a/src/nncf/common/utils/registry.py b/src/nncf/common/utils/registry.py index 2f432bd15a3..68d2e7cb7cf 100644 --- a/src/nncf/common/utils/registry.py +++ b/src/nncf/common/utils/registry.py @@ -9,52 +9,91 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, TypeVar +from typing import Callable, Generic, TypeVar, ValuesView -TClass = TypeVar("TClass", bound=type) +TKey = TypeVar("TKey") +TObject = TypeVar("TObject") +TRegisterObject = TypeVar("TRegisterObject") -class Registry: - REGISTERED_NAME_ATTR = "_registered_name" +class Registry(Generic[TKey, TObject]): + """ + Generic key-to-object registry. - def __init__(self, name: str, add_name_as_attr: bool = False): + Stores objects by key and provides a decorator-based registration API. + """ + + def __init__(self, name: str): + """ + Initialize a registry. + + :param name: Human-readable registry name used in error messages. + """ self._name = name - self._registry_dict: dict[str, Any] = {} - self._add_name_as_attr = add_name_as_attr + self._registry_dict: dict[TKey, TObject] = {} @property - def registry_dict(self) -> dict[str, Any]: + def registry_dict(self) -> dict[TKey, TObject]: + """ + Return the underlying registry mapping. + + :return: Dictionary with registered objects. + """ return self._registry_dict - def values(self) -> Any: + def values(self) -> ValuesView[TObject]: + """ + Return registered object values. + + :return: View over registered objects. + """ return self._registry_dict.values() - def _register(self, obj: Any, name: str) -> None: + def _register(self, obj: TObject, name: TKey) -> None: if name in self._registry_dict: msg = f"{name} is already registered in {self._name}" raise KeyError(msg) self._registry_dict[name] = obj - def register(self, name: str = None) -> Callable[[TClass], TClass]: - def wrap(obj: TClass) -> TClass: + def register(self, name: TKey | None = None) -> Callable[[TRegisterObject], TRegisterObject]: + """ + Create a decorator that registers an object in the registry. + + If `name` is not provided, `obj.__name__` is used. + + :param name: Explicit key for registration. + :return: Decorator that registers and returns the input object. + """ + + def wrap(obj: TRegisterObject) -> TRegisterObject: cls_name = name if cls_name is None: - cls_name = obj.__name__ - if self._add_name_as_attr: - setattr(obj, self.REGISTERED_NAME_ATTR, name) - self._register(obj, cls_name) + cls_name = obj.__name__ # type: ignore[attr-defined] + self._register(obj, cls_name) # type: ignore[arg-type] return obj return wrap - def get(self, name: str) -> Any: + def get(self, name: TKey) -> TObject: + """ + Get a registered object by key. + + :param name: Registry key. + :return: Registered object associated with `name`. + """ if name not in self._registry_dict: self._key_not_found(name) return self._registry_dict[name] - def _key_not_found(self, name: str) -> None: + def _key_not_found(self, name: TKey) -> None: msg = f"{name} is unknown type of {self._name} " raise KeyError(msg) - def __contains__(self, item: Any) -> bool: + def __contains__(self, item: TObject) -> bool: + """ + Check whether an object instance is present in registered values. + + :param item: Object to check. + :return: `True` if object is registered, otherwise `False`. + """ return item in self._registry_dict.values() diff --git a/src/nncf/onnx/graph/metatypes/onnx_metatypes.py b/src/nncf/onnx/graph/metatypes/onnx_metatypes.py index fcb6e0bc288..8ae545e928a 100644 --- a/src/nncf/onnx/graph/metatypes/onnx_metatypes.py +++ b/src/nncf/onnx/graph/metatypes/onnx_metatypes.py @@ -71,7 +71,7 @@ class ONNXOpWithWeightsMetatype(ONNXOpMetatype): possible_weight_ports: list[int] = [] -@ONNX_OPERATION_METATYPES.register(is_subtype=True) +@ONNX_OPERATION_METATYPES.register() class ONNXDepthwiseConvolutionMetatype(ONNXOpWithWeightsMetatype): name = "DepthwiseConvOp" op_names = ["Conv"] @@ -86,7 +86,7 @@ def matches(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> bool: return _is_depthwise_conv(model, node) -@ONNX_OPERATION_METATYPES.register(is_subtype=True) +@ONNX_OPERATION_METATYPES.register() class ONNXGroupConvolutionMetatype(ONNXOpWithWeightsMetatype): name = "GroupConvOp" op_names = ["Conv"] @@ -450,7 +450,7 @@ class ONNXReciprocalMetatype(ONNXOpMetatype): hw_config_names = [HWOpName.POWER] -@ONNX_OPERATION_METATYPES.register(is_subtype=True) +@ONNX_OPERATION_METATYPES.register() class ONNXEmbeddingMetatype(ONNXOpWithWeightsMetatype): name = "EmbeddingOp" hw_config_names = [HWOpName.EMBEDDING] diff --git a/src/nncf/onnx/hardware/fused_patterns.py b/src/nncf/onnx/hardware/fused_patterns.py index cf95f372c8c..453c1d8e3bf 100644 --- a/src/nncf/onnx/hardware/fused_patterns.py +++ b/src/nncf/onnx/hardware/fused_patterns.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable + from nncf.common.graph.operator_metatypes import InputNoopMetatype from nncf.common.graph.patterns import GraphPattern from nncf.common.graph.patterns import HWFusedPatternNames @@ -19,7 +21,7 @@ from nncf.onnx.graph.metatypes.groups import BATCH_NORMALIZATION_OPERATIONS from nncf.onnx.graph.metatypes.groups import LINEAR_OPERATIONS -ONNX_HW_FUSED_PATTERNS = Registry("onnx") +ONNX_HW_FUSED_PATTERNS = Registry[HWFusedPatternNames, Callable[[], GraphPattern]]("onnx_hw_fused_patterns") # BLOCK PATTERNS diff --git a/src/nncf/onnx/quantization/ignored_patterns.py b/src/nncf/onnx/quantization/ignored_patterns.py index e28cd918234..e183fca2b0e 100644 --- a/src/nncf/onnx/quantization/ignored_patterns.py +++ b/src/nncf/onnx/quantization/ignored_patterns.py @@ -8,6 +8,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable + from nncf.common.graph.patterns.patterns import GraphPattern from nncf.common.graph.patterns.patterns import IgnoredPatternNames from nncf.common.utils.registry import Registry @@ -15,7 +17,7 @@ from nncf.onnx.graph.metatypes.groups import MATMUL_METATYPES from nncf.onnx.hardware.fused_patterns import atomic_activations_operations -ONNX_IGNORED_PATTERNS = Registry("IGNORED_PATTERNS") +ONNX_IGNORED_PATTERNS = Registry[IgnoredPatternNames, Callable[[], GraphPattern]]("onnx_ignored_patterns") def _add_softmax_matmul(pattern: GraphPattern) -> None: diff --git a/src/nncf/openvino/graph/metatypes/openvino_metatypes.py b/src/nncf/openvino/graph/metatypes/openvino_metatypes.py index 3868a5337ae..47288ea5b51 100644 --- a/src/nncf/openvino/graph/metatypes/openvino_metatypes.py +++ b/src/nncf/openvino/graph/metatypes/openvino_metatypes.py @@ -70,7 +70,7 @@ class OVConvolutionBackpropDataMetatype(OVOpMetatype): output_channel_axis = 1 -@OV_OPERATOR_METATYPES.register(is_subtype=True) +@OV_OPERATOR_METATYPES.register() class OVDepthwiseConvolutionMetatype(OVOpMetatype): name = "DepthwiseConvolutionOp" op_names = ["GroupConvolution"] @@ -424,7 +424,7 @@ class OVLogicalXorMetatype(OVOpMetatype): hw_config_names = [HWOpName.LOGICAL_XOR] -@OV_OPERATOR_METATYPES.register(is_subtype=True) +@OV_OPERATOR_METATYPES.register() class OVEmbeddingMetatype(OVOpMetatype): name = "EmbeddingOp" hw_config_names = [HWOpName.EMBEDDING] diff --git a/src/nncf/openvino/hardware/fused_patterns.py b/src/nncf/openvino/hardware/fused_patterns.py index 74c5e7d2c76..c2f3cea6528 100644 --- a/src/nncf/openvino/hardware/fused_patterns.py +++ b/src/nncf/openvino/hardware/fused_patterns.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable + from nncf.common.graph.patterns import GraphPattern from nncf.common.graph.patterns import HWFusedPatternNames from nncf.common.utils.registry import Registry @@ -19,7 +21,7 @@ from nncf.openvino.graph.metatypes.groups import ELEMENTWISE_OPERATIONS from nncf.openvino.graph.metatypes.groups import LINEAR_OPERATIONS -OPENVINO_HW_FUSED_PATTERNS = Registry("openvino") +OPENVINO_HW_FUSED_PATTERNS = Registry[HWFusedPatternNames, Callable[[], GraphPattern]]("openvino_hw_fused_patterns") # BLOCK PATTERNS @@ -596,7 +598,7 @@ def create_mvn_scale_shift_activations() -> GraphPattern: @OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.LINEAR_ACTIVATIONS_UNSQUEEZE_BN_SQUEEZE) -def create_linear_activations_unsqueeze_bn_squeeze(): +def create_linear_activations_unsqueeze_bn_squeeze() -> GraphPattern: linear_biased = create_biased_op() activations = atomic_activations_operations() unsqueeze_op = unsqueeze_operation() diff --git a/src/nncf/openvino/quantization/ignored_patterns.py b/src/nncf/openvino/quantization/ignored_patterns.py index 713e7c149b2..689d88e2548 100644 --- a/src/nncf/openvino/quantization/ignored_patterns.py +++ b/src/nncf/openvino/quantization/ignored_patterns.py @@ -8,16 +8,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable + +from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.patterns.patterns import GraphPattern from nncf.common.graph.patterns.patterns import IgnoredPatternNames from nncf.common.utils.registry import Registry from nncf.openvino.graph.metatypes import openvino_metatypes as om from nncf.openvino.graph.metatypes.groups import LINEAR_OPERATIONS -OPENVINO_IGNORED_PATTERNS = Registry("IGNORED_PATTERNS") +OPENVINO_IGNORED_PATTERNS = Registry[IgnoredPatternNames, Callable[[], GraphPattern]]("openvino_ignored_patterns") -def _add_softmax_matmul(pattern: GraphPattern, branch_matmul_nodes: list[om.OperatorMetatype]) -> None: +def _add_softmax_matmul(pattern: GraphPattern, branch_matmul_nodes: list[type[OperatorMetatype]]) -> None: # SOFTMAX READVALUE||RESHAPE||TRANSPOSE||GATHER||SQUEEZE||CONCAT # \ / # \ / @@ -37,7 +40,7 @@ def _add_softmax_matmul(pattern: GraphPattern, branch_matmul_nodes: list[om.Oper pattern.add_edge(matmul_branch_nodes, matmul) -def _add_softmax_reshape_matmul(pattern: GraphPattern, branch_matmul_nodes: list[om.OperatorMetatype]) -> None: +def _add_softmax_reshape_matmul(pattern: GraphPattern, branch_matmul_nodes: list[type[OperatorMetatype]]) -> None: # SOFTMAX # \ # \ diff --git a/src/nncf/torch/graph/operator_metatypes.py b/src/nncf/torch/graph/operator_metatypes.py index e800825ffba..ce8807bb378 100644 --- a/src/nncf/torch/graph/operator_metatypes.py +++ b/src/nncf/torch/graph/operator_metatypes.py @@ -178,7 +178,7 @@ class PTNoopMetatype(PTOperatorMetatype): } -@PT_OPERATOR_METATYPES.register(is_subtype=True) +@PT_OPERATOR_METATYPES.register() class PTDepthwiseConv1dSubtype(PTDepthwiseConvOperatorSubtype): name = "Conv1DOp" hw_config_name = [HWOpName.DEPTHWISE_CONVOLUTION] @@ -201,7 +201,7 @@ class PTConv1dMetatype(PTOperatorMetatype): bias_port_id = 2 -@PT_OPERATOR_METATYPES.register(is_subtype=True) +@PT_OPERATOR_METATYPES.register() class PTDepthwiseConv2dSubtype(PTDepthwiseConvOperatorSubtype): name = "Conv2DOp" hw_config_names = [HWOpName.DEPTHWISE_CONVOLUTION] @@ -224,7 +224,7 @@ class PTConv2dMetatype(PTOperatorMetatype): bias_port_id = 2 -@PT_OPERATOR_METATYPES.register(is_subtype=True) +@PT_OPERATOR_METATYPES.register() class PTDepthwiseConv3dSubtype(PTDepthwiseConvOperatorSubtype): name = "Conv3DOp" hw_config_names = [HWOpName.DEPTHWISE_CONVOLUTION] diff --git a/src/nncf/torch/hardware/fused_patterns.py b/src/nncf/torch/hardware/fused_patterns.py index ed9d93789eb..19ea444c0d2 100644 --- a/src/nncf/torch/hardware/fused_patterns.py +++ b/src/nncf/torch/hardware/fused_patterns.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable + from nncf.common.graph.patterns import GraphPattern from nncf.common.graph.patterns import HWFusedPatternNames from nncf.common.utils.registry import Registry @@ -21,7 +23,7 @@ from nncf.torch.graph.pattern_operations import LINEAR_OPERATIONS from nncf.torch.graph.pattern_operations import RELU_OPERATIONS -PT_HW_FUSED_PATTERNS = Registry("torch") +PT_HW_FUSED_PATTERNS = Registry[HWFusedPatternNames, Callable[[], GraphPattern]]("torch_hw_fused_patterns") # ATOMIC OPERATIONS diff --git a/src/nncf/torch/quantization/ignored_patterns.py b/src/nncf/torch/quantization/ignored_patterns.py index 6211d6a65d7..3f5ee79b962 100644 --- a/src/nncf/torch/quantization/ignored_patterns.py +++ b/src/nncf/torch/quantization/ignored_patterns.py @@ -8,6 +8,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable + +from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.patterns.patterns import GraphPattern from nncf.common.graph.patterns.patterns import IgnoredPatternNames from nncf.common.utils.registry import Registry @@ -15,16 +18,16 @@ from nncf.torch.graph.pattern_operations import ATOMIC_ACTIVATIONS_OPERATIONS from nncf.torch.graph.pattern_operations import LINEAR_OPERATIONS -PT_IGNORED_PATTERNS = Registry("IGNORED_PATTERNS") +PT_IGNORED_PATTERNS = Registry[IgnoredPatternNames, Callable[[], GraphPattern]]("pt_ignored_patterns") def _add_softmax_matmul( pattern: GraphPattern, - matmul_metatypes, - reshape_squeeze_metatypes, - gather_metatypes, - transpose_metatypes, - concat_metatypes, + matmul_metatypes: list[type[OperatorMetatype]], + reshape_squeeze_metatypes: list[type[OperatorMetatype]], + gather_metatypes: list[type[OperatorMetatype]], + transpose_metatypes: list[type[OperatorMetatype]], + concat_metatypes: list[type[OperatorMetatype]], ) -> None: # SOFTMAX RESHAPE||TRANSPOSE||GATHER||SQUEEZE||CONCAT # \ / @@ -45,11 +48,11 @@ def _add_softmax_matmul( def _add_softmax_reshape_matmul( pattern: GraphPattern, - matmul_metatypes, - reshape_squeeze_metatypes, - gather_metatypes, - transpose_metatypes, - concat_metatypes, + matmul_metatypes: list[type[OperatorMetatype]], + reshape_squeeze_metatypes: list[type[OperatorMetatype]], + gather_metatypes: list[type[OperatorMetatype]], + transpose_metatypes: list[type[OperatorMetatype]], + concat_metatypes: list[type[OperatorMetatype]], ) -> None: # SOFTMAX # \ @@ -78,16 +81,19 @@ def _add_softmax_reshape_matmul( pattern.add_edge(softmax, reshape) pattern.add_edge(reshape, matmul) pattern.add_edge(matmul_branch_nodes, matmul) - return pattern @PT_IGNORED_PATTERNS.register(IgnoredPatternNames.MULTIHEAD_ATTENTION_OUTPUT) def create_multihead_attention_output() -> GraphPattern: - matmul_metatypes = [om.PTLinearMetatype, om.PTAddmmMetatype, om.PTMatMulMetatype] - reshape_squeeze_metatypes = [om.PTReshapeMetatype, om.PTSqueezeMetatype, om.PTSplitMetatype] - gather_metatypes = [om.PTGatherMetatype] - transpose_metatypes = [om.PTTransposeMetatype] - concat_metatypes = [om.PTCatMetatype] + matmul_metatypes: list[type[OperatorMetatype]] = [om.PTLinearMetatype, om.PTAddmmMetatype, om.PTMatMulMetatype] + gather_metatypes: list[type[OperatorMetatype]] = [om.PTGatherMetatype] + transpose_metatypes: list[type[OperatorMetatype]] = [om.PTTransposeMetatype] + concat_metatypes: list[type[OperatorMetatype]] = [om.PTCatMetatype] + reshape_squeeze_metatypes: list[type[OperatorMetatype]] = [ + om.PTReshapeMetatype, + om.PTSqueezeMetatype, + om.PTSplitMetatype, + ] pattern = GraphPattern() _add_softmax_matmul(