[quantization] Introduce QuantConv3dDecomposed wrapper for Conv3d #516

Merged
mhs4670go merged 1 commit into Samsung:main from dvsav:quant_conv3d_decomposed
Mar 6, 2026

Conversation

@dvsav
Contributor

@dvsav dvsav commented Feb 24, 2026

Why?

Conv3d is decomposed into Conv2d+Add operations during conversion to Circle (tico.convert) by the ConvertConv3dToConv2d pass. The Conv2d and Add operations introduced by that pass do not participate in quantization (their outputs are neither calibrated nor quantized), so they remain floating-point, which defeats the purpose of quantizing the model (#489). We therefore need to decompose Conv3d at the Quant wrapper level, before calibration/quantization, so that observers can be injected for the Conv2d and Add outputs.
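The identity the decomposition relies on can be sanity-checked outside TICO. Below is a minimal single-channel NumPy sketch (illustrative names, stride 1, no padding — not the wrapper's actual code) showing that a Conv3d equals per-temporal-slice Conv2d results combined with Add:

```python
import numpy as np

def corr2d(x, k):
    """Naive 'valid' 2-D cross-correlation, single channel."""
    H, W = x.shape
    kh, kw = k.shape
    out = np.zeros((H - kh + 1, W - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * k)
    return out

def conv3d_direct(x, w):
    """Reference single-channel Conv3d (stride 1, no padding) over full 3-D windows."""
    T, H, W = x.shape
    kT, kh, kw = w.shape
    out = np.zeros((T - kT + 1, H - kh + 1, W - kw + 1))
    for t in range(out.shape[0]):
        for i in range(out.shape[1]):
            for j in range(out.shape[2]):
                out[t, i, j] = np.sum(x[t:t + kT, i:i + kh, j:j + kw] * w)
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8, 8))   # (T, H, W)
w = rng.standard_normal((2, 3, 3))   # (kT, kh, kw)

# Decomposed form: one Conv2d per (t_out, kt) pair, accumulated with Add —
# exactly the intermediate outputs that need observers during calibration.
decomposed = np.stack([
    sum(corr2d(x[t + kt], w[kt]) for kt in range(w.shape[0]))
    for t in range(x.shape[0] - w.shape[0] + 1)
])
assert np.allclose(decomposed, conv3d_direct(x, w))
```

Each of those intermediate Conv2d results and running sums is a tensor whose range must be calibrated, which is why the decomposition has to happen before observers are placed.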

What

This change introduces:

  • Class QuantConv3dDecomposed (tico/quantization/wrapq/wrappers/nn/quant_conv3d_decomposed.py).
  • Unit tests: class TestQuantConv3dDecomposed (test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py).
  • New entry in _CORE_MODULES (tico/quantization/wrapq/wrappers/registry.py).
  • Examples of Conv3d quantization and conversion to Circle:
    • tico/quantization/wrapq/examples/nn/quantize_conv3d.py
    • tico/quantization/wrapq/examples/nn/quantize_conv3d_special_case.py

Unit Tests

$ coverage run -m pytest test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py -v
======================================================================================= test session starts ========================================================================================
platform linux -- Python 3.10.12, pytest-8.4.0, pluggy-1.6.0 -- /home/d.savchenkov/myenv/bin/python
cachedir: .pytest_cache
rootdir: /home/d.savchenkov/TICO
configfile: pyproject.toml
plugins: anyio-4.12.0, mock-3.15.1, xdist-3.7.0, cov-6.2.1
collected 20 items                                                                                                                                                                                 

test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_activation_stats_collected         PASSED                                               [  5%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_all_observers_yielded              PASSED                                               [ 10%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_conv3d_without_bias                PASSED                                               [ 15%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_decomposition_correctness_no_quant PASSED                                               [ 20%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_decomposition_various_shapes       PASSED                                               [ 25%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_different_kernel_sizes             PASSED                                               [ 30%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_different_padding                  PASSED                                               [ 35%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_different_strides                  PASSED                                               [ 40%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_dilation                           PASSED                                               [ 45%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_dtype_override                     PASSED                                               [ 50%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_dynamic_activation_stats_collected PASSED                                               [ 55%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_dynamic_observers_created          PASSED                                               [ 60%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_mode_transitions                   PASSED                                               [ 65%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_multiple_calibration_cycles        PASSED                                               [ 70%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_observers_reused_across_calls      PASSED                                               [ 75%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_output_shape_correctness           PASSED                                               [ 80%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_per_channel_weight_quantization    PASSED                                               [ 85%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_quantized_output_close             PASSED                                               [ 90%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_registration_in_registry           PASSED                                               [ 95%]
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_temporal_padding                   PASSED                                               [100%]

================================================================================== 20 passed, 2 warnings in 5.25s ==================================================================================

Coverage info (irrelevant files skipped):

$ coverage report -m
Name                                                                    Stmts   Miss  Cover   Missing
-----------------------------------------------------------------------------------------------------
tico/quantization/wrapq/wrappers/nn/quant_conv3d_decomposed.py            115      7    94%   116, 123-127, 250, 275
-----------------------------------------------------------------------------------------------------
TOTAL                                                                   10445   6695    36%

The uncovered lines correspond to the exceptions raised for invalid padding schemes in QuantConv3dDecomposed._parse_padding. They are not covered because constructing a Conv3d with an invalid padding scheme fails before QuantConv3dDecomposed._parse_padding is ever reached.

Example Script (quantize Conv3d and convert to Circle)

$ python tico/quantization/wrapq/examples/nn/quantize_conv3d.py
Input channels:  3
Output channels: 1024
Kernel size:     (2, 3, 3)
Stride:          (1, 1, 1)
Padding:         (0, 1, 1)
Input shape:              torch.Size([1, 3, 4, 8, 8])
Output shape (FP32):      torch.Size([1, 1024, 3, 8, 8])
Output shape (Quantized): torch.Size([1, 1024, 3, 8, 8])
┌───────────── Quantization Error Summary ─────────────
│ Mean |diff|: 0.010323
│ PEIR       : 0.989106 %
└──────────────────────────────────────────────────────
    ┌────────────────────────────────────────────┐
 2.9┤                                            │
    │                                         •  │
    │                                      ••    │
 1.9┤                                    •••     │
    │                                  •••       │
    │                                •••         │
    │                              •••           │
 1.0┤                           ••••             │
    │                         ••••               │
    │                       ••••                 │
-0.0┤                     ••••                   │
    │                   ••••                     │
    │                 ••••                       │
    │               ••••                         │
-1.0┤             ••••                           │
    │           •••                              │
    │         •••                                │
-2.0┤       •••                                  │
    │     •••                                    │
    │    ••                                      │
    │  •                                         │
-3.0┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -3.0       -1.5       -0.0       1.5       2.9 


Converting to Circle format...
Circle model saved as 'quantized_conv3d.circle'

1. Note For Reviewers

I've decided to introduce a new QuantConv3dDecomposed wrapper rather than modify the existing QuantConv3d because:

  • QuantConv3d is much simpler.
  • QuantConv3d might become useful in the future (e.g. when Circle starts supporting Conv3d).

2. Note For Reviewers

The decomposition applies Conv2d in a loop over the temporal slices of the input tensor and the temporal slices of the Conv3d kernel. The loop takes place in the QuantConv3dDecomposed.forward method, and the number of loop iterations depends on:

  • kernel_depth (kT) - fixed in __init__
  • T_out - depends on the input size, only known during forward
  • N, C_in - depend on the input shape, only known during forward

The number of loop iterations determines the number of Conv2d+Add operations and hence the number of observers to be injected. Therefore the observers cannot be created until QuantConv3dDecomposed.forward is called for the first time in CALIB mode; they are created inside QuantConv3dDecomposed.forward.
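As a rough illustration of that constraint (plain Python with hypothetical names — not the wrapper's actual API), the observer table can only be sized on the first forward, once T_out is known, and must be reused on later calls:

```python
class LazyObserverSketch:
    """Toy model of lazy observer creation for a decomposed Conv3d."""

    def __init__(self, kernel_depth, stride_t=1):
        self.kT = kernel_depth
        self.stride_t = stride_t
        # One observer per Conv2d output; cannot be allocated here
        # because T_out is unknown until the first input arrives.
        self.conv2d_obs = {}

    def forward(self, T_in):
        T_out = (T_in - self.kT) // self.stride_t + 1
        for t in range(T_out):
            for kt in range(self.kT):
                key = (t, kt)
                if key not in self.conv2d_obs:  # reuse across calls
                    self.conv2d_obs[key] = f"conv2d_obs_{t}_{kt}"
        return T_out

m = LazyObserverSketch(kernel_depth=2)
t_out = m.forward(T_in=4)        # T_out = 3, creates 3 * 2 = 6 observers
m.forward(T_in=4)                # second call reuses the same observers
```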

Example Script: Special Case (kernel_size = input_size = stride)

$ python3 tico/quantization/wrapq/examples/nn/quantize_conv3d_special_case.py
Input channels:  3
Output channels: 1024
Kernel size:     (2, 16, 16)
Stride:          (2, 16, 16)
Padding:         (0, 0, 0)
Input shape:              torch.Size([1, 3, 2, 16, 16])
Output shape (FP32):      torch.Size([1, 1024, 1, 1, 1])
Output shape (Quantized): torch.Size([1, 1024, 1, 1, 1])
┌───────────── Quantization Error Summary ─────────────
│ Mean |diff|: 0.006116
│ PEIR       : 0.571971 %
└──────────────────────────────────────────────────────
     ┌───────────────────────────────────────────┐
 2.04┤                                           │
     │                                        •  │
     │                                     ••    │
 1.37┤                                   •••     │
     │                                 •••       │
     │                               •••         │
     │                             •••           │
 0.69┤                           •••             │
     │                         •••               │
     │                       •••                 │
 0.02┤                     •••                   │
     │                   •••                     │
     │                 •••                       │
     │               •••                         │
-0.65┤             •••                           │
     │           •••                             │
     │         •••                               │
-1.33┤       •••                                 │
     │     •••                                   │
     │   • •                                     │
     │  •                                        │
-2.00┤                                           │
     └┬──────────┬─────────┬──────────┬─────────┬┘
    -2.0       -1.0       0.0        1.0      2.0 


Converting to Circle format...
Circle model saved as 'quantized_conv3d.circle'

@dvsav dvsav force-pushed the quant_conv3d_decomposed branch 6 times, most recently from 35e7ff4 to c066948 on February 25, 2026 13:13
@dvsav dvsav marked this pull request as ready for review February 25, 2026 13:22
@dayo09
Contributor

dayo09 commented Feb 26, 2026

Good! Now I can see that all operators are quantized well. However, we could also remove the remaining 'add' operation.

Patch embed operation from this PR

[screenshot: patch embed graph from this PR]

Our final optimized patch embed suggestion

However, the patch embed Conv3d can be lowered to a Linear projection when the kernel size perfectly fits the stride size, as below; this optimization is done in PR #518.

class Conv3dWithPerfectFitKernel(torch.nn.Module):
    """Conv3D with perfect fitting kernel"""

    def __init__(self):
        super().__init__()
        self.conv3d = torch.nn.Conv3d(
            in_channels=3,
            out_channels=1024,
            kernel_size=(2, 16, 16),
            stride=(2, 16, 16),
            padding=(0, 0, 0),
        )

    def forward(self, input):
        return self.conv3d(input)

    def get_example_inputs(self):
        return (torch.randn(5, 3, 2, 16, 16),), {}

[screenshot: optimized patch embed graph]

This Conv3dToConv2d optimization logic is implemented in a TICO legalization pass, so it also needs to be performed in the quantization wrapper. Could you please perform this further optimization, as the PR above does?

To grab the full context, see:

Hi. I want to update this conv3d conversion workflow.

As @mhs4670go demonstrated, Qwen3-VL's patch embedding layer has stride equal to kernel size, and thus can be lowered to conv2d. Not only that, it's actually a Linear projection.

After figuring out that fact, we also assessed the impacts of both Conv3d-to-Conv2d conversion and Conv3d-to-linear conversion with @llFreetimell and @parjong.

  • Converting it into linear may harm performance as some reshape operations could be added.
  • Thus, we shortly decided to perform conv2d with (1, 1, -1, D * H * W) as the input shape comes as (Batch, D * H * W) - @llFreetimell is still inspecting the previous layer to ensure this.

This proposal is based on the above condition. I will add this to the issue.

Copied from #430 (comment)

@dvsav dvsav force-pushed the quant_conv3d_decomposed branch 2 times, most recently from b69449f to 5b32757 on March 2, 2026 07:06
@dvsav dvsav marked this pull request as draft March 2, 2026 07:28
@dvsav dvsav force-pushed the quant_conv3d_decomposed branch 2 times, most recently from 14b659f to 3527997 on March 2, 2026 10:41
@dvsav dvsav marked this pull request as ready for review March 2, 2026 11:01
@dvsav
Contributor Author

dvsav commented Mar 2, 2026

Could you please perform more optimization likewise above PR does?

Hi @dayo09 , thanks for your feedback!

  • I've updated the QuantConv3dDecomposed.forward method to handle the special case where the Conv3d kernel perfectly fits the stride and input tensor.
  • I've added two test cases for the special case (test_special_case_optimization, test_special_case_without_bias).
  • I've also added an example script (tico/quantization/wrapq/examples/nn/quantize_conv3d_special_case.py). You can see its output in the first comment of this PR.

@dayo09
Contributor

dayo09 commented Mar 3, 2026

$ python3 tico/quantization/wrapq/examples/nn/quantize_conv3d_special_case.py
$ netron quantized_conv3d.circle

[screenshot: Netron view of quantized_conv3d.circle]

Thanks!

dayo09
dayo09 previously approved these changes Mar 3, 2026
Contributor

@dayo09 dayo09 left a comment


LGTM

@mhs4670go mhs4670go self-requested a review March 3, 2026 05:32
obs_name = f"{obs_name_prefix}{dict_key}"
obs = self._make_obs(obs_name)
obs_dictionary[dict_key] = obs
self.add_module(obs_name, obs)
Contributor

Could you let me know the full warning message? Because even though I comment out this line, I got no such warnings.

Contributor

@dvsav Could you check this comment?

Contributor Author

@mhs4670go Could you please clarify what warning message you are referring to?
Do you mean the 2 warnings in the unit tests (20 passed, 2 warnings in 5.25s)?

Contributor Author

If it's about the 2 warnings in the unit tests, here's the full context.

Warning Message

$ python -m pytest test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py -v
...
...
...
test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_registration_in_registry
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

test/quantization/wrapq/wrappers/nn/test_quant_conv3d_decomposed.py::TestQuantConv3dDecomposed::test_registration_in_registry
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute
...
...
...

Root Cause

The warnings are triggered during the test execution when lookup(nn.Conv3d) is called, which invokes _lazy_init() in the registry. This lazy initialization imports all core wrapper modules, which in turn import PyTorch components.

The Issue

The warnings come from PyTorch's Swig-generated C++ types (SwigPyPacked and SwigPyObject) that don't have a __module__ attribute. Python's importlib checks for this attribute during imports, and Python 3.12+ is stricter about this check, emitting DeprecationWarning.

Why It Happens

  1. Test calls lookup(nn.Conv3d) → triggers _lazy_init()
  2. _lazy_init() imports all _CORE_MODULES including PyTorch-dependent modules
  3. During import, PyTorch loads its C++ extensions (Swig-wrapped)
  4. Swig types SwigPyPacked and SwigPyObject don't have __module__ attribute
  5. Python's importlib detects this and emits deprecation warnings

I've Just Fixed It

By adding these lines to test_registration_in_registry:

        # Suppress warnings from PyTorch's Swig-generated types
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="builtin type SwigPyPacked has no __module__ attribute")
            warnings.filterwarnings("ignore", message="builtin type SwigPyObject has no __module__ attribute")
            ...  # the registry lookup that triggers the imports runs inside this block
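For reference, the message-based filtering this fix relies on can be demonstrated standalone (stdlib only; the warning texts mirror the Swig ones, the helper name is illustrative): inside catch_warnings(), a matching ignore filter swallows the targeted DeprecationWarning while unrelated warnings still propagate.

```python
import warnings

def noisy():
    # Stand-in for the Swig-triggered warning plus an unrelated one.
    warnings.warn("builtin type SwigPyPacked has no __module__ attribute",
                  DeprecationWarning)
    warnings.warn("something else", UserWarning)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Filters are prepended, so this ignore rule is consulted first.
    warnings.filterwarnings(
        "ignore",
        message="builtin type SwigPyPacked has no __module__ attribute")
    noisy()

messages = [str(w.message) for w in caught]
# Only the unmatched UserWarning is recorded.
```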

Contributor

Ah, no. What I was referring to is the self.add_module(obs_name, obs) call. The comments below, from L100, explain why the call is needed, but I can't see the warnings when I run the test.

Contributor Author

@dvsav dvsav Mar 5, 2026

Oh, I see now.
Here's the full warning message generated by the example script tico/quantization/wrapq/examples/nn/quantize_conv3d.py when calling tico.convert, unless we add the dynamic observers as submodules via self.add_module(obs_name, obs):

torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  getattr_node = gm.graph.get_attr(lifted_node)
torch/fx/graph.py:1801: UserWarning: Node l__self___wrapped__input_slice_obs_0__cached_scale target L__self___wrapped__input_slice_obs_0__cached_scale L__self___wrapped__input_slice_obs_0__cached_scale of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(
torch/fx/graph.py:1801: UserWarning: Node l__self___wrapped__input_slice_obs_0__cached_zp target L__self___wrapped__input_slice_obs_0__cached_zp L__self___wrapped__input_slice_obs_0__cached_zp of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(
torch/fx/graph.py:1801: UserWarning: Node l__self___wrapped__conv2d_obs_0__cached_scale target L__self___wrapped__conv2d_obs_0__cached_scale L__self___wrapped__conv2d_obs_0__cached_scale of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(
torch/fx/graph.py:1801: UserWarning: Node l__self___wrapped__conv2d_obs_0__cached_zp target L__self___wrapped__conv2d_obs_0__cached_zp L__self___wrapped__conv2d_obs_0__cached_zp of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(
torch/fx/graph.py:1801: UserWarning: Node l__self___wrapped__input_slice_obs_1__cached_scale target L__self___wrapped__input_slice_obs_1__cached_scale L__self___wrapped__input_slice_obs_1__cached_scale of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(
torch/fx/graph.py:1810: UserWarning: Additional 9 warnings suppressed about get_attr references
  warnings.warn(

Contributor

@dvsav

I got no warnings. Could you let me know the versions for transformers and torch? I'll reproduce the warnings with them.

Anyway, the code makes sense to me. I'll approve this.

Contributor Author

Hi @mhs4670go
Sure. Here we go:

$ pip freeze | grep -e torch -e transformers
torch==2.6.0
torchvision==0.21.0
transformers==5.0.0

I've double-checked that the warnings appear if I comment out one line in tico/quantization/wrapq/wrappers/nn/quant_conv3d_decomposed.py:

        def create_observer(obs_name_prefix, obs_dictionary, dict_key):
            obs_name = f"{obs_name_prefix}{dict_key}"
            obs = self._make_obs(obs_name)
            obs_dictionary[dict_key] = obs
            #self.add_module(obs_name, obs)

Make sure you're running tico/quantization/wrapq/examples/nn/quantize_conv3d.py to check that:

$ python tico/quantization/wrapq/examples/nn/quantize_conv3d.py

Contributor

Ah, thanks for the information. Turns out that I ran the tico/quantization/wrapq/examples/nn/quantize_conv3d_special_case.py script instead.


# Convert to Circle format
example_input = torch.randn(1, in_channels, depth, height, width)
circle_model = tico.convert(quantized_model, (example_input,))
Contributor

Suggested change
circle_model = tico.convert(quantized_model, (example_input,))
circle_model = tico.convert(quantized_model.eval(), (example_input,))

Contributor Author

👍 done

# Convert to Circle format
print("\nConverting to Circle format...")
example_input = torch.randn(1, in_channels, depth, height, width)
circle_model = tico.convert(quantized_model, (example_input,))
Contributor

Suggested change
circle_model = tico.convert(quantized_model, (example_input,))
circle_model = tico.convert(quantized_model.eval(), (example_input,))

Contributor Author

👍 done

@mhs4670go
Contributor

Additionally, there's a higher PEIR from the patch embed script. This should be investigated.

python tico/quantization/wrapq/examples/qwen/quantize_vision_patch_embed.py 
┌───────────── Quantization Error Summary ─────────────
│ Mean |diff|: 0.650984
│ PEIR       : 73.866314 %
└──────────────────────────────────────────────────────
    ┌────────────────────────────────────────────┐
 3.2┤                                            │
    │                   ••                       │
    │                • ••••••••                  │
 2.2┤          •  •••••••••••••••• •             │
    │         •• •••••••••••••••••••••••         │
    │       •••••••••••••••••••••••••••••        │
    │      •••••••••••••••••••••••••••••••       │
 1.1┤     ••••••••••••••••••••••••••••••••••     │
    │    •••••••••••••••••••••••••••••••••••     │
    │  •••••••••••••••••••••••••••••••••••••     │
 0.0┤  • •••••••••••••••••••••••••••••••••••• •  │
    │  • •••••••••••••••••••••••••••••••••••     │
    │   ••••••••••••••••••••••••••••••••••••     │
    │     ••••••••••••••••••••••••••••••••• • •  │
-1.0┤    •••••••••••••••••••••••••••••••••• •    │
    │     • •••••••••••••••••••••••••••••• •     │
    │       •••••••••••••••••••••••••••••   •    │
-2.1┤       •• •••••••••••••••••••••••••         │
    │           •••••••••••••••••••• •           │
    │              •• •••••••••••     •          │
    │                 •  ••   •                  │
-3.2┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -3.2       -1.6        0.0       1.6       3.2 

@dvsav
Contributor Author

dvsav commented Mar 4, 2026

Additionally, there's a higher PEIR from the patch embed script. This should be investigated.

@mhs4670go thanks for catching this 👍 . Indeed there was a bug in the special case handling code in QuantConv3dDecomposed.forward (tico/quantization/wrapq/wrappers/nn/quant_conv3d_decomposed.py). I've fixed it now:

-            # Reshape input: (N, C_in, T_in, H_in, W_in) -> (1, 1, N, C_in*T_in*H_in*W_in)
-            x_q = x_q.reshape(1, 1, -1, C_in * T_in * H_in * W_in)

+            # Reshape input: (N, C_in, T_in, H_in, W_in) -> (N, 1, 1, C_in*T_in*H_in*W_in)
+            x_q = x_q.reshape(N, 1, 1, -1)
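The fixed reshape can be sanity-checked numerically. A NumPy sketch (illustrative shapes, not the wrapper's code) showing that a Conv3d whose kernel exactly covers the input reduces to a per-sample Linear projection — and why the batch dimension must stay out of the flattened axis, which was the bug:

```python
import numpy as np

rng = np.random.default_rng(1)
N, C_in, T, H, W = 2, 3, 2, 4, 4
C_out = 5
x = rng.standard_normal((N, C_in, T, H, W))
w = rng.standard_normal((C_out, C_in, T, H, W))  # kernel_size == stride == input size
b = rng.standard_normal(C_out)

# Direct Conv3d with a perfectly fitting kernel: exactly one window per
# sample, so each output channel is a dot product over the whole sample.
direct = np.einsum('ncthw,octhw->no', x, w) + b      # (N, C_out)

# Linear-projection view: flatten each sample (keeping N as the batch
# axis!) and matmul with the flattened kernel.
linear = x.reshape(N, -1) @ w.reshape(C_out, -1).T + b
assert np.allclose(direct, linear)
```

Flattening the batch into the feature axis, as the buggy (1, 1, -1, ...) reshape did, would mix values across samples, which is consistent with the large PEIR observed before the fix.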

I've also modified example script tico/quantization/wrapq/examples/nn/quantize_conv3d_special_case.py to catch this bug. Here's the PEIR for patch embed after the fix:

$ python tico/quantization/wrapq/examples/qwen/quantize_vision_patch_embed.py 
┌───────────── Quantization Error Summary ─────────────
│ Mean |diff|: 0.008683
│ PEIR       : 0.744628 %
└──────────────────────────────────────────────────────
    ┌────────────────────────────────────────────┐
 3.8┤                                            │
    │                                         •  │
    │                                            │
 2.7┤                                     •      │
    │                                  •••       │
    │                                •••         │
    │                              •••           │
 1.5┤                            •••             │
    │                         ••••               │
    │                       ••••                 │
 0.3┤                     ••••                   │
    │                   ••••                     │
    │                 ••••                       │
    │               ••••                         │
-0.8┤             •••                            │
    │           ••••                             │
    │         •••                                │
-2.0┤       •••                                  │
    │     •••                                    │
    │   •••                                      │
    │  ••                                        │
-3.2┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -3.2       -1.4        0.3       2.1       3.8 

Circle model saved as 'quantized_vision_patch_embed.circle'

@dvsav dvsav requested review from dayo09 and mhs4670go March 4, 2026 12:11
dayo09
dayo09 previously approved these changes Mar 5, 2026
Contributor

@dayo09 dayo09 left a comment

LGTM, thanks!

This change introduces QuantConv3dDecomposed wrapper to support post-training quantization of Conv3d operation that uses Conv2d and Add operations internally.

TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>
Contributor

@mhs4670go mhs4670go left a comment

LGTM

@mhs4670go mhs4670go requested a review from dayo09 March 6, 2026 06:30
@mhs4670go mhs4670go merged commit 83fabdf into Samsung:main Mar 6, 2026
7 checks passed
@dvsav dvsav deleted the quant_conv3d_decomposed branch March 6, 2026 07:22