Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ files = [
"src/nncf/api",
"src/nncf/data",
"src/nncf/common",
"src/nncf/torch/hardware/fused_patterns.py",
"src/nncf/torch/quantization/ignored_patterns.py",
"src/nncf/torch/function_hook",
"src/nncf/torch/engine.py",
"src/nncf/torch/functions.py",
Expand All @@ -94,6 +96,10 @@ files = [
"src/nncf/experimental/torch/gptqmodel/",
"src/nncf/quantization/algorithms/weight_compression/config.py",
"src/nncf/quantization/algorithms/weight_compression/constants.py",
"src/nncf/onnx/hardware/fused_patterns.py",
"src/nncf/onnx/quantization/ignored_patterns.py",
"src/nncf/openvino/hardware/fused_patterns.py",
"src/nncf/openvino/quantization/ignored_patterns.py",
"examples/llm_compression/torch/gptq_model_convertor/main.py",
]
disable_error_code = ["import-untyped"]
Expand Down
39 changes: 20 additions & 19 deletions src/nncf/common/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ def subtype_check(cls, metatype: type["OperatorMetatype"]) -> bool:
return any(subtype.subtype_check(metatype) for subtype in subtypes)


TOpClass = TypeVar("TOpClass", bound=type[OperatorMetatype])
TRegisterObject = TypeVar("TRegisterObject", bound=type[OperatorMetatype])


class OperatorMetatypeRegistry(Registry):
class OperatorMetatypeRegistry(Registry[str, type[OperatorMetatype]]):
"""
Operator Metatypes Registry.
"""
Expand All @@ -82,18 +82,19 @@ def __init__(self, name: str):
super().__init__(name)
self._op_name_to_op_meta_dict: dict[str, type[OperatorMetatype]] = {}

def register(self, name: str | None = None, is_subtype: bool = False) -> Callable[[TOpClass], TOpClass]:
def register( # type: ignore[override]
self, name: str | None = None
) -> Callable[[TRegisterObject], TRegisterObject]:
"""
Decorator for registering operator metatypes.

:param name: The registration name.
:param is_subtype: Whether the decorated metatype is a subtype of another registered operator.
:return: The inner function for registering operator metatypes.
"""
name_ = name
super_register = super()._register

def wrap(obj: TOpClass) -> TOpClass:
def wrap(obj: TRegisterObject) -> TRegisterObject:
"""
Inner function for registering operator metatypes.

Expand All @@ -104,16 +105,16 @@ def wrap(obj: TOpClass) -> TOpClass:
if cls_name is None:
cls_name = obj.__name__
super_register(obj, cls_name)
if not is_subtype:
op_names = obj.get_all_aliases()
for name in op_names:
if name in self._op_name_to_op_meta_dict:
msg = (
"Inconsistent operator metatype registry - single patched "
f"op name `{name}` maps to multiple metatypes!"
)
raise nncf.InternalError(msg)
self._op_name_to_op_meta_dict[name] = obj
op_names = obj.get_all_aliases()
for name in op_names:
if name in self._op_name_to_op_meta_dict and not obj.subtype_check(self._op_name_to_op_meta_dict[name]):
msg = (
"Inconsistent operator metatype registry - single patched "
f"op name `{name}` maps to multiple metatypes!"
)
raise nncf.InternalError(msg)

self._op_name_to_op_meta_dict[name] = obj
return obj

return wrap
Expand All @@ -130,10 +131,10 @@ def get_operator_metatype_by_op_name(self, op_name: str) -> type[OperatorMetatyp
return self._op_name_to_op_meta_dict[op_name]


NOOP_METATYPES = Registry("noop_metatypes")
INPUT_NOOP_METATYPES = Registry("input_noop_metatypes")
OUTPUT_NOOP_METATYPES = Registry("output_noop_metatypes")
CONST_NOOP_METATYPES = Registry("const_noop_metatypes")
NOOP_METATYPES = Registry[str, type[OperatorMetatype]]("noop_metatypes")
INPUT_NOOP_METATYPES = Registry[str, type[OperatorMetatype]]("input_noop_metatypes")
OUTPUT_NOOP_METATYPES = Registry[str, type[OperatorMetatype]]("output_noop_metatypes")
CONST_NOOP_METATYPES = Registry[str, type[OperatorMetatype]]("const_noop_metatypes")


class UnknownMetatype(OperatorMetatype):
Expand Down
24 changes: 6 additions & 18 deletions src/nncf/common/graph/patterns/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,18 @@ def _get_backend_hw_patterns_map(backend: BackendType) -> dict[HWFusedPatternNam
:param backend: BackendType instance.
:return: Dictionary mapping HWFusedPatternNames members to their pattern creator functions.
"""
registry: dict[HWFusedPatternNames, Callable[[], GraphPattern]] = {}
if backend == BackendType.ONNX:
from nncf.onnx.hardware.fused_patterns import ONNX_HW_FUSED_PATTERNS

registry = cast(dict[HWFusedPatternNames, Callable[[], GraphPattern]], ONNX_HW_FUSED_PATTERNS.registry_dict)
return registry
return ONNX_HW_FUSED_PATTERNS.registry_dict
if backend == BackendType.OPENVINO:
from nncf.openvino.hardware.fused_patterns import OPENVINO_HW_FUSED_PATTERNS

registry = cast(
dict[HWFusedPatternNames, Callable[[], GraphPattern]], OPENVINO_HW_FUSED_PATTERNS.registry_dict
)
return registry
return OPENVINO_HW_FUSED_PATTERNS.registry_dict
if backend in (BackendType.TORCH, BackendType.TORCH_FX):
from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS

registry = cast(dict[HWFusedPatternNames, Callable[[], GraphPattern]], PT_HW_FUSED_PATTERNS.registry_dict)
return registry
return PT_HW_FUSED_PATTERNS.registry_dict
msg = f"Hardware-fused patterns not implemented for {backend} backend."
raise ValueError(msg)

Expand All @@ -65,24 +59,18 @@ def _get_backend_ignored_patterns_map(
:param backend: BackendType instance.
:return: Dictionary mapping IgnoredPatternNames members to their pattern creator functions.
"""
registry: dict[IgnoredPatternNames, Callable[[], GraphPattern]] = {}
if backend == BackendType.ONNX:
from nncf.onnx.quantization.ignored_patterns import ONNX_IGNORED_PATTERNS

registry = cast(dict[IgnoredPatternNames, Callable[[], GraphPattern]], ONNX_IGNORED_PATTERNS.registry_dict)
return registry
return ONNX_IGNORED_PATTERNS.registry_dict
if backend == BackendType.OPENVINO:
from nncf.openvino.quantization.ignored_patterns import OPENVINO_IGNORED_PATTERNS

registry = cast(
dict[IgnoredPatternNames, Callable[[], GraphPattern]], OPENVINO_IGNORED_PATTERNS.registry_dict
)
return registry
return OPENVINO_IGNORED_PATTERNS.registry_dict
if backend in (BackendType.TORCH, BackendType.TORCH_FX):
from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS

registry = cast(dict[IgnoredPatternNames, Callable[[], GraphPattern]], PT_IGNORED_PATTERNS.registry_dict)
return registry
return PT_IGNORED_PATTERNS.registry_dict
msg = f"Ignored patterns not implemented for {backend} backend."
raise ValueError(msg)

Expand Down
77 changes: 58 additions & 19 deletions src/nncf/common/utils/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,52 +9,91 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, TypeVar
from typing import Callable, Generic, TypeVar, ValuesView

TClass = TypeVar("TClass", bound=type)
TKey = TypeVar("TKey")
TObject = TypeVar("TObject")
TRegisterObject = TypeVar("TRegisterObject")


class Registry:
REGISTERED_NAME_ATTR = "_registered_name"
class Registry(Generic[TKey, TObject]):
"""
Generic key-to-object registry.

def __init__(self, name: str, add_name_as_attr: bool = False):
Stores objects by key and provides a decorator-based registration API.
"""

def __init__(self, name: str):
"""
Initialize a registry.

:param name: Human-readable registry name used in error messages.
"""
self._name = name
self._registry_dict: dict[str, Any] = {}
self._add_name_as_attr = add_name_as_attr
self._registry_dict: dict[TKey, TObject] = {}

@property
def registry_dict(self) -> dict[str, Any]:
def registry_dict(self) -> dict[TKey, TObject]:
"""
Return the underlying registry mapping.

:return: Dictionary with registered objects.
"""
return self._registry_dict

def values(self) -> Any:
def values(self) -> ValuesView[TObject]:
"""
Return registered object values.

:return: View over registered objects.
"""
return self._registry_dict.values()

def _register(self, obj: Any, name: str) -> None:
def _register(self, obj: TObject, name: TKey) -> None:
if name in self._registry_dict:
msg = f"{name} is already registered in {self._name}"
raise KeyError(msg)
self._registry_dict[name] = obj

def register(self, name: str = None) -> Callable[[TClass], TClass]:
def wrap(obj: TClass) -> TClass:
def register(self, name: TKey | None = None) -> Callable[[TRegisterObject], TRegisterObject]:
"""
Create a decorator that registers an object in the registry.

If `name` is not provided, `obj.__name__` is used.

:param name: Explicit key for registration.
:return: Decorator that registers and returns the input object.
"""

def wrap(obj: TRegisterObject) -> TRegisterObject:
cls_name = name
if cls_name is None:
cls_name = obj.__name__
if self._add_name_as_attr:
setattr(obj, self.REGISTERED_NAME_ATTR, name)
self._register(obj, cls_name)
cls_name = obj.__name__ # type: ignore[attr-defined]
self._register(obj, cls_name) # type: ignore[arg-type]
Comment thread
AlexanderDokuchaev marked this conversation as resolved.
return obj

return wrap

def get(self, name: str) -> Any:
def get(self, name: TKey) -> TObject:
"""
Get a registered object by key.

:param name: Registry key.
:return: Registered object associated with `name`.
"""
if name not in self._registry_dict:
self._key_not_found(name)
return self._registry_dict[name]

def _key_not_found(self, name: str) -> None:
def _key_not_found(self, name: TKey) -> None:
msg = f"{name} is unknown type of {self._name} "
raise KeyError(msg)

def __contains__(self, item: Any) -> bool:
def __contains__(self, item: TObject) -> bool:
"""
Check whether an object instance is present in registered values.

:param item: Object to check.
:return: `True` if object is registered, otherwise `False`.
"""
return item in self._registry_dict.values()
6 changes: 3 additions & 3 deletions src/nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class ONNXOpWithWeightsMetatype(ONNXOpMetatype):
possible_weight_ports: list[int] = []


@ONNX_OPERATION_METATYPES.register(is_subtype=True)
@ONNX_OPERATION_METATYPES.register()
class ONNXDepthwiseConvolutionMetatype(ONNXOpWithWeightsMetatype):
name = "DepthwiseConvOp"
op_names = ["Conv"]
Expand All @@ -86,7 +86,7 @@ def matches(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
return _is_depthwise_conv(model, node)


@ONNX_OPERATION_METATYPES.register(is_subtype=True)
@ONNX_OPERATION_METATYPES.register()
class ONNXGroupConvolutionMetatype(ONNXOpWithWeightsMetatype):
name = "GroupConvOp"
op_names = ["Conv"]
Expand Down Expand Up @@ -450,7 +450,7 @@ class ONNXReciprocalMetatype(ONNXOpMetatype):
hw_config_names = [HWOpName.POWER]


@ONNX_OPERATION_METATYPES.register(is_subtype=True)
@ONNX_OPERATION_METATYPES.register()
class ONNXEmbeddingMetatype(ONNXOpWithWeightsMetatype):
name = "EmbeddingOp"
hw_config_names = [HWOpName.EMBEDDING]
Expand Down
4 changes: 3 additions & 1 deletion src/nncf/onnx/hardware/fused_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable

from nncf.common.graph.operator_metatypes import InputNoopMetatype
from nncf.common.graph.patterns import GraphPattern
from nncf.common.graph.patterns import HWFusedPatternNames
Expand All @@ -19,7 +21,7 @@
from nncf.onnx.graph.metatypes.groups import BATCH_NORMALIZATION_OPERATIONS
from nncf.onnx.graph.metatypes.groups import LINEAR_OPERATIONS

ONNX_HW_FUSED_PATTERNS = Registry("onnx")
ONNX_HW_FUSED_PATTERNS = Registry[HWFusedPatternNames, Callable[[], GraphPattern]]("onnx_hw_fused_patterns")
Comment thread
AlexanderDokuchaev marked this conversation as resolved.

# BLOCK PATTERNS

Expand Down
4 changes: 3 additions & 1 deletion src/nncf/onnx/quantization/ignored_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable

from nncf.common.graph.patterns.patterns import GraphPattern
from nncf.common.graph.patterns.patterns import IgnoredPatternNames
from nncf.common.utils.registry import Registry
from nncf.onnx.graph.metatypes import onnx_metatypes as om
from nncf.onnx.graph.metatypes.groups import MATMUL_METATYPES
from nncf.onnx.hardware.fused_patterns import atomic_activations_operations

ONNX_IGNORED_PATTERNS = Registry("IGNORED_PATTERNS")
ONNX_IGNORED_PATTERNS = Registry[IgnoredPatternNames, Callable[[], GraphPattern]]("onnx_ignored_patterns")


def _add_softmax_matmul(pattern: GraphPattern) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/nncf/openvino/graph/metatypes/openvino_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class OVConvolutionBackpropDataMetatype(OVOpMetatype):
output_channel_axis = 1


@OV_OPERATOR_METATYPES.register(is_subtype=True)
@OV_OPERATOR_METATYPES.register()
class OVDepthwiseConvolutionMetatype(OVOpMetatype):
name = "DepthwiseConvolutionOp"
op_names = ["GroupConvolution"]
Expand Down Expand Up @@ -424,7 +424,7 @@ class OVLogicalXorMetatype(OVOpMetatype):
hw_config_names = [HWOpName.LOGICAL_XOR]


@OV_OPERATOR_METATYPES.register(is_subtype=True)
@OV_OPERATOR_METATYPES.register()
class OVEmbeddingMetatype(OVOpMetatype):
name = "EmbeddingOp"
hw_config_names = [HWOpName.EMBEDDING]
Expand Down
6 changes: 4 additions & 2 deletions src/nncf/openvino/hardware/fused_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable

from nncf.common.graph.patterns import GraphPattern
from nncf.common.graph.patterns import HWFusedPatternNames
from nncf.common.utils.registry import Registry
Expand All @@ -19,7 +21,7 @@
from nncf.openvino.graph.metatypes.groups import ELEMENTWISE_OPERATIONS
from nncf.openvino.graph.metatypes.groups import LINEAR_OPERATIONS

OPENVINO_HW_FUSED_PATTERNS = Registry("openvino")
OPENVINO_HW_FUSED_PATTERNS = Registry[HWFusedPatternNames, Callable[[], GraphPattern]]("openvino_hw_fused_patterns")

# BLOCK PATTERNS

Expand Down Expand Up @@ -596,7 +598,7 @@ def create_mvn_scale_shift_activations() -> GraphPattern:


@OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.LINEAR_ACTIVATIONS_UNSQUEEZE_BN_SQUEEZE)
def create_linear_activations_unsqueeze_bn_squeeze():
def create_linear_activations_unsqueeze_bn_squeeze() -> GraphPattern:
linear_biased = create_biased_op()
activations = atomic_activations_operations()
unsqueeze_op = unsqueeze_operation()
Expand Down
Loading