From daadb76a9fd87c273348246f3a1699b828da3728 Mon Sep 17 00:00:00 2001
From: Scott Roy <161522778+metascroy@users.noreply.github.com>
Date: Thu, 17 Apr 2025 19:42:05 -0700
Subject: [PATCH] init

---
 .../coreml/test/test_coreml_quantizer.py | 4 +-
 backends/arm/quantizer/arm_quantizer.py | 12 ++---
 backends/arm/quantizer/arm_quantizer_utils.py | 4 +-
 .../arm/quantizer/quantization_annotator.py | 9 ++--
 backends/arm/quantizer/quantization_config.py | 4 +-
 backends/arm/test/ops/test_add.py | 4 +-
 backends/arm/test/ops/test_sigmoid_16bit.py | 4 +-
 backends/arm/test/ops/test_sigmoid_32bit.py | 4 +-
 backends/cadence/aot/compiler.py | 2 +-
 backends/cadence/aot/quantizer/patterns.py | 2 +-
 backends/cadence/aot/quantizer/quantizer.py | 6 +--
 backends/cadence/aot/quantizer/utils.py | 2 +-
 .../aot/tests/test_remove_ops_passes.py | 4 +-
 .../permute_memory_formats_pass.py | 2 +-
 backends/example/example_operators/utils.py | 2 +-
 backends/example/example_partitioner.py | 2 +-
 backends/example/example_quantizer.py | 6 +--
 backends/example/test_example_delegate.py | 4 +-
 backends/mediatek/quantizer/annotator.py | 12 ++---
 backends/mediatek/quantizer/qconfig.py | 6 +--
 backends/mediatek/quantizer/quantizer.py | 2 +-
 backends/nxp/quantizer/neutron_quantizer.py | 6 +--
 backends/nxp/quantizer/patterns.py | 2 +-
 backends/nxp/quantizer/utils.py | 2 +-
 backends/nxp/tests/test_quantizer.py | 2 +-
 backends/openvino/quantizer/quantizer.py | 4 +-
 .../qualcomm/_passes/annotate_quant_attrs.py | 4 +-
 backends/qualcomm/_passes/qnn_pass_manager.py | 4 +-
 backends/qualcomm/builders/node_visitor.py | 6 +--
 backends/qualcomm/partition/utils.py | 4 +-
 backends/qualcomm/quantizer/annotators.py | 10 ++---
 .../qualcomm/quantizer/custom_annotation.py | 6 +--
 .../observers/per_block_param_observer.py | 4 +-
 .../observers/per_channel_param_observer.py | 2 +-
 backends/qualcomm/quantizer/qconfig.py | 11 +++--
 backends/qualcomm/quantizer/quantizer.py | 5 ++-
 backends/qualcomm/tests/utils.py | 13 +++---
 backends/qualcomm/utils/utils.py | 2 +-
 .../duplicate_dynamic_quant_chain.py | 8 ++--
 .../test_duplicate_dynamic_quant_chain.py | 2 +-
 backends/vulkan/quantizer/vulkan_quantizer.py | 6 +--
 .../xnnpack/quantizer/xnnpack_quantizer.py | 10 ++---
 .../quantizer/xnnpack_quantizer_utils.py | 18 ++++----
 .../test/ops/test_check_quant_params.py | 2 +-
 .../test/quantizer/test_pt2e_quantization.py | 45 ++++++++++---------
 .../test/quantizer/test_representation.py | 4 +-
 .../test/quantizer/test_xnnpack_quantizer.py | 10 ++---
 backends/xnnpack/test/test_xnnpack_utils.py | 18 ++++----
 backends/xnnpack/test/tester/tester.py | 4 +-
 docs/source/backends-coreml.md | 2 +-
 docs/source/backends-xnnpack.md | 10 ++---
 docs/source/llm/getting-started.md | 2 +-
 .../tutorial-xnnpack-delegate-lowering.md | 2 +-
 .../export-to-executorch-tutorial.py | 2 +-
 examples/arm/aot_arm_compiler.py | 4 +-
 examples/arm/ethos_u_minimal_example.ipynb | 22 ++++-----
 .../mediatek/aot_utils/oss_utils/utils.py | 2 +-
 .../mediatek/model_export_scripts/llama.py | 2 +-
 examples/models/moshi/mimi/test_mimi.py | 2 +-
 .../models/phi-3-mini/export_phi-3-mini.py | 2 +-
 examples/qualcomm/oss_scripts/llama/llama.py | 4 +-
 examples/qualcomm/oss_scripts/moshi/mimi.py | 2 +-
 examples/qualcomm/scripts/export_example.py | 2 +-
 examples/qualcomm/utils.py | 15 ++++---
 examples/xnnpack/quantization/example.py | 2 +-
 examples/xnnpack/quantization/utils.py | 2 +-
 .../test/demos/test_xnnpack_qnnpack.py | 8 ++--
 exir/tests/test_memory_planning.py | 8 ++--
 exir/tests/test_passes.py | 8 ++--
exir/tests/test_quantization.py | 11 +++-- exir/tests/test_quantize_io_pass.py | 2 +- extension/llm/export/builder.py | 6 +-- extension/llm/export/quantizer_lib.py | 6 +-- 73 files changed, 223 insertions(+), 216 deletions(-) diff --git a/backends/apple/coreml/test/test_coreml_quantizer.py b/backends/apple/coreml/test/test_coreml_quantizer.py index db75631dbc8..d5754328796 100644 --- a/backends/apple/coreml/test/test_coreml_quantizer.py +++ b/backends/apple/coreml/test/test_coreml_quantizer.py @@ -15,12 +15,12 @@ ) from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer -from torch.ao.quantization.quantize_pt2e import ( +from torch.export import export_for_training +from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, prepare_pt2e, prepare_qat_pt2e, ) -from torch.export import export_for_training class TestCoreMLQuantizer: diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index ee08f8e9eec..1cf24d70a30 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -34,11 +34,13 @@ is_ethosu, ) # usort: skip from executorch.exir.backend.compile_spec_schema import CompileSpec -from torch.ao.quantization.fake_quantize import ( +from torch.fx import GraphModule, Node +from torchao.quantization.pt2e import _ObserverOrFakeQuantizeConstructor +from torchao.quantization.pt2e.fake_quantize import ( FakeQuantize, FusedMovingAvgObsFakeQuantize, ) -from torch.ao.quantization.observer import ( +from torchao.quantization.pt2e.observer import ( HistogramObserver, MinMaxObserver, MovingAverageMinMaxObserver, @@ -46,13 +48,11 @@ PerChannelMinMaxObserver, PlaceholderObserver, ) -from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer -from torch.ao.quantization.quantizer.utils import ( +from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer +from torchao.quantization.pt2e.quantizer.utils import ( _annotate_input_qspec_map, _annotate_output_qspec, ) -from torch.fx import GraphModule, Node __all__ = [ "TOSAQuantizer", diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index 0ce11b620a6..d6eb72f1148 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -15,10 +15,10 @@ import torch from torch._subclasses import FakeTensor - -from torch.ao.quantization.quantizer import QuantizationAnnotation from torch.fx import GraphModule, Node +from torchao.quantization.pt2e.quantizer import QuantizationAnnotation + def is_annotated(node: Node) -> bool: """Given a node return whether the node is annotated.""" diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 5398101fd9a..f379813dedd 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -13,12 +13,15 @@ from executorch.backends.arm.quantizer import arm_quantizer_utils from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from executorch.backends.arm.tosa_utils import get_node_debug_info -from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec -from torch.ao.quantization.quantizer.utils import ( +from torch.fx import Node +from torchao.quantization.pt2e.quantizer import ( + QuantizationSpecBase, + SharedQuantizationSpec, +) +from 
torchao.quantization.pt2e.quantizer.utils import ( _annotate_input_qspec_map, _annotate_output_qspec, ) -from torch.fx import Node logger = logging.getLogger(__name__) diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py index 65435ac7c63..54698d058c4 100644 --- a/backends/arm/quantizer/quantization_config.py +++ b/backends/arm/quantizer/quantization_config.py @@ -9,9 +9,9 @@ from dataclasses import dataclass import torch -from torch.ao.quantization import ObserverOrFakeQuantize +from torchao.quantization.pt2e import ObserverOrFakeQuantize -from torch.ao.quantization.quantizer import ( +from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, FixedQParamsQuantizationSpec, QuantizationSpec, diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 486e53c5f03..0a729d0799a 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -18,8 +18,8 @@ TosaPipelineMI, ) from executorch.backends.xnnpack.test.tester import Quantize -from torch.ao.quantization.observer import HistogramObserver -from torch.ao.quantization.quantizer import QuantizationSpec +from torchao.quantization.pt2e.observer import HistogramObserver +from torchao.quantization.pt2e.quantizer import QuantizationSpec aten_op = "torch.ops.aten.add.Tensor" diff --git a/backends/arm/test/ops/test_sigmoid_16bit.py b/backends/arm/test/ops/test_sigmoid_16bit.py index 240000e6973..fac6a1e06a7 100644 --- a/backends/arm/test/ops/test_sigmoid_16bit.py +++ b/backends/arm/test/ops/test_sigmoid_16bit.py @@ -18,8 +18,8 @@ TosaPipelineBI, ) from executorch.backends.xnnpack.test.tester import Quantize -from torch.ao.quantization.observer import HistogramObserver -from torch.ao.quantization.quantizer import QuantizationSpec +from torchao.quantization.pt2e.observer import HistogramObserver +from torchao.quantization.pt2e.quantizer import QuantizationSpec def _get_16_bit_quant_config(): diff --git a/backends/arm/test/ops/test_sigmoid_32bit.py b/backends/arm/test/ops/test_sigmoid_32bit.py index 14808eedaf9..4ff0e9852a4 100644 --- a/backends/arm/test/ops/test_sigmoid_32bit.py +++ b/backends/arm/test/ops/test_sigmoid_32bit.py @@ -14,8 +14,8 @@ TosaPipelineBI, ) from executorch.backends.xnnpack.test.tester import Quantize -from torch.ao.quantization.observer import HistogramObserver -from torch.ao.quantization.quantizer import QuantizationSpec +from torchao.quantization.pt2e.observer import HistogramObserver +from torchao.quantization.pt2e.quantizer import QuantizationSpec def _get_16_bit_quant_config(): diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 32a4427278b..75b67899c08 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -37,10 +37,10 @@ from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass from executorch.exir.program._program import to_edge_with_preserved_ops from torch._inductor.decomposition import remove_decompositions -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export from torch.export.exported_program import ExportedProgram +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from .passes import get_cadence_passes diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 66f6772d942..cd6a7287793 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ 
b/backends/cadence/aot/quantizer/patterns.py @@ -15,7 +15,7 @@ from torch import fx from torch._ops import OpOverload -from torch.ao.quantization.quantizer import ( +from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, SharedQuantizationSpec, ) diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 761b2bf8d31..3fbe1bcc0fd 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -38,9 +38,9 @@ from torch import fx -from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver -from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer -from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer +from torchao.quantization.pt2e.observer import HistogramObserver, MinMaxObserver +from torchao.quantization.pt2e.quantizer import DerivedQuantizationSpec, Quantizer +from torchao.quantization.pt2e.quantizer.composable_quantizer import ComposableQuantizer act_qspec_asym8s = QuantizationSpec( diff --git a/backends/cadence/aot/quantizer/utils.py b/backends/cadence/aot/quantizer/utils.py index 0f9c9399780..fad5ca41e22 100644 --- a/backends/cadence/aot/quantizer/utils.py +++ b/backends/cadence/aot/quantizer/utils.py @@ -14,13 +14,13 @@ import torch from torch import fx from torch._ops import OpOverload -from torch.ao.quantization import ObserverOrFakeQuantize from torch.fx import GraphModule from torch.fx.passes.utils.source_matcher_utils import ( check_subgraphs_connected, SourcePartition, ) +from torchao.quantization.pt2e import ObserverOrFakeQuantize def quantize_tensor_multiplier( diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py index dba4f711864..7d47ff122f9 100644 --- a/backends/cadence/aot/tests/test_remove_ops_passes.py +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -41,11 +41,11 @@ from parameterized.parameterized import parameterized from pyre_extensions import none_throws -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e - from torch.export import export_for_training from torch.fx.passes.infra.pass_base import PassResult +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + class TestRemoveOpsPasses(unittest.TestCase): @parameterized.expand( diff --git a/backends/example/example_backend_delegate_passes/permute_memory_formats_pass.py b/backends/example/example_backend_delegate_passes/permute_memory_formats_pass.py index 8e857f32376..b681a3d4c39 100644 --- a/backends/example/example_backend_delegate_passes/permute_memory_formats_pass.py +++ b/backends/example/example_backend_delegate_passes/permute_memory_formats_pass.py @@ -11,7 +11,7 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dim_order_utils import get_dim_order from executorch.exir.pass_base import ExportPass, PassResult -from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions +from torchao.quantization.pt2e.pt2e.graph_utils import find_sequential_partitions class PermuteMemoryFormatsPass(ExportPass): diff --git a/backends/example/example_operators/utils.py b/backends/example/example_operators/utils.py index 7dca2a3be6a..d9b3a436840 100644 --- a/backends/example/example_operators/utils.py +++ b/backends/example/example_operators/utils.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root 
directory of this source tree. -from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation +from torchao.quantization.pt2e.quantizer.quantizer import QuantizationAnnotation def _nodes_are_annotated(node_list): diff --git a/backends/example/example_partitioner.py b/backends/example/example_partitioner.py index 5e9102e999b..5862d0ca162 100644 --- a/backends/example/example_partitioner.py +++ b/backends/example/example_partitioner.py @@ -19,9 +19,9 @@ ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.graph_module import get_control_flow_submodules -from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions from torch.export import ExportedProgram from torch.fx.passes.operator_support import OperatorSupportBase +from torchao.quantization.pt2e.pt2e.graph_utils import find_sequential_partitions @final diff --git a/backends/example/example_quantizer.py b/backends/example/example_quantizer.py index 74a0057ba4a..2e37fcf0bb7 100644 --- a/backends/example/example_quantizer.py +++ b/backends/example/example_quantizer.py @@ -11,9 +11,9 @@ from executorch.backends.example.example_operators.ops import module_to_annotator from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import OperatorConfig from torch import fx -from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver -from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer +from torchao.quantization.pt2e.observer import HistogramObserver, MinMaxObserver +from torchao.quantization.pt2e.pt2e.graph_utils import find_sequential_partitions +from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer def get_uint8_tensor_spec(observer_or_fake_quant_ctr): diff --git a/backends/example/test_example_delegate.py b/backends/example/test_example_delegate.py index a382273af07..bc6ad4d7e4c 100644 --- a/backends/example/test_example_delegate.py +++ b/backends/example/test_example_delegate.py @@ -17,10 +17,10 @@ DuplicateDequantNodePass, ) from executorch.exir.delegate import executorch_call_delegate - -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + from torchvision.models.quantization import mobilenet_v2 diff --git a/backends/mediatek/quantizer/annotator.py b/backends/mediatek/quantizer/annotator.py index d250b774af8..d77fad387e4 100644 --- a/backends/mediatek/quantizer/annotator.py +++ b/backends/mediatek/quantizer/annotator.py @@ -10,18 +10,18 @@ from torch._ops import OpOverload from torch._subclasses import FakeTensor -from torch.ao.quantization.quantizer import QuantizationAnnotation -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, -) - from torch.export import export_for_training from torch.fx import Graph, Node from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( SubgraphMatcherWithNameNodeMap, ) +from torchao.quantization.pt2e.quantizer import QuantizationAnnotation +from torchao.quantization.pt2e.quantizer.utils import ( + _annotate_input_qspec_map, + _annotate_output_qspec, +) + from .qconfig import QuantizationConfig diff --git a/backends/mediatek/quantizer/qconfig.py b/backends/mediatek/quantizer/qconfig.py index e16f5e936cb..d9f105cd0d0 100644 --- a/backends/mediatek/quantizer/qconfig.py +++ 
b/backends/mediatek/quantizer/qconfig.py @@ -10,9 +10,9 @@ import torch -from torch.ao.quantization.fake_quantize import FakeQuantize -from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver -from torch.ao.quantization.quantizer import QuantizationSpec +from torchao.quantization.pt2e.fake_quantize import FakeQuantize +from torchao.quantization.pt2e.observer import MinMaxObserver, PerChannelMinMaxObserver +from torchao.quantization.pt2e.quantizer import QuantizationSpec @unique diff --git a/backends/mediatek/quantizer/quantizer.py b/backends/mediatek/quantizer/quantizer.py index 4e78d6dff1a..f9babdec997 100644 --- a/backends/mediatek/quantizer/quantizer.py +++ b/backends/mediatek/quantizer/quantizer.py @@ -4,8 +4,8 @@ # except in compliance with the License. See the license file in the root # directory of this source tree for more details. -from torch.ao.quantization.quantizer import Quantizer from torch.fx import GraphModule +from torchao.quantization.pt2e.quantizer import Quantizer from .._passes.decompose_scaled_dot_product_attention import ( DecomposeScaledDotProductAttention, diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py index eff7f513cb9..8a64170b632 100644 --- a/backends/nxp/quantizer/neutron_quantizer.py +++ b/backends/nxp/quantizer/neutron_quantizer.py @@ -35,9 +35,9 @@ QuantizationSpec, ) from torch import fx -from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver -from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer -from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer +from torchao.quantization.pt2e.observer import HistogramObserver, MinMaxObserver +from torchao.quantization.pt2e.quantizer import DerivedQuantizationSpec, Quantizer +from torchao.quantization.pt2e.quantizer.composable_quantizer import ComposableQuantizer class NeutronAtenQuantizer(Quantizer): diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py index 6797447c50c..b71f0621002 100644 --- a/backends/nxp/quantizer/patterns.py +++ b/backends/nxp/quantizer/patterns.py @@ -14,7 +14,7 @@ from executorch.backends.nxp.quantizer.utils import get_bias_qparams from torch import fx from torch._ops import OpOverload -from torch.ao.quantization.quantizer import ( +from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, FixedQParamsQuantizationSpec, SharedQuantizationSpec, diff --git a/backends/nxp/quantizer/utils.py b/backends/nxp/quantizer/utils.py index 1effcdff25a..1b941f6e632 100644 --- a/backends/nxp/quantizer/utils.py +++ b/backends/nxp/quantizer/utils.py @@ -14,11 +14,11 @@ import torch from torch import fx from torch._ops import OpOverload -from torch.ao.quantization import ObserverOrFakeQuantize from torch.fx.passes.utils.source_matcher_utils import ( check_subgraphs_connected, SourcePartition, ) +from torchao.quantization.pt2e import ObserverOrFakeQuantize def is_annotated(nodes: List[fx.Node]) -> bool: diff --git a/backends/nxp/tests/test_quantizer.py b/backends/nxp/tests/test_quantizer.py index 868a94059b5..dd1b691a18f 100644 --- a/backends/nxp/tests/test_quantizer.py +++ b/backends/nxp/tests/test_quantizer.py @@ -8,7 +8,7 @@ import executorch.backends.nxp.tests.models as models import torch from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import 
convert_pt2e, prepare_pt2e def _get_target_name(node): diff --git a/backends/openvino/quantizer/quantizer.py b/backends/openvino/quantizer/quantizer.py index 5532235f573..134ed5cb4ac 100644 --- a/backends/openvino/quantizer/quantizer.py +++ b/backends/openvino/quantizer/quantizer.py @@ -17,12 +17,12 @@ import torch.fx from nncf.common.graph.graph import NNCFGraph # type: ignore[import-untyped] -from torch.ao.quantization.observer import ( +from torchao.quantization.pt2e.observer import ( HistogramObserver, PerChannelMinMaxObserver, UniformQuantizationObserverBase, ) -from torch.ao.quantization.quantizer.quantizer import ( +from torchao.quantization.pt2e.quantizer.quantizer import ( EdgeOrNode, QuantizationAnnotation, QuantizationSpec, diff --git a/backends/qualcomm/_passes/annotate_quant_attrs.py b/backends/qualcomm/_passes/annotate_quant_attrs.py index ed19a54b7e7..f5f385f6dc6 100644 --- a/backends/qualcomm/_passes/annotate_quant_attrs.py +++ b/backends/qualcomm/_passes/annotate_quant_attrs.py @@ -124,9 +124,9 @@ def _dequant_fold_params(self, n, quant_attrs, param): offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis) param = param.sub(offsets).mul(scales).to(torch.float32).contiguous() elif quant_attrs[QCOM_ENCODING] in [ - exir_ops.edge.pt2e_quant.dequantize_affine.default + exir_ops.edge.torchao.dequantize_affine.default ]: - param = torch.ops.pt2e_quant.dequantize_affine( + param = torch.ops.torchao.dequantize_affine( param, block_size=quant_attrs[QCOM_BLOCK_SIZE], scale=quant_attrs[QCOM_SCALE], diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index c98f27db120..c4b730dc5b2 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -131,8 +131,8 @@ def get_to_edge_transform_passes( from executorch.backends.qualcomm._passes import utils from executorch.exir.dialects._ops import ops as exir_ops - utils.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default) - utils.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default) + utils.q_ops.add(exir_ops.edge.torchao.quantize_affine.default) + utils.dq_ops.add(exir_ops.edge.torchao.dequantize_affine.default) passes_job = ( passes_job if passes_job is not None else get_capture_program_passes() diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 7965a30caea..e99d6b2e620 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -242,8 +242,8 @@ def get_quant_encoding_conf( ) # TODO: refactor this when target could be correctly detected per_block_encoding = { - exir_ops.edge.pt2e_quant.quantize_affine.default, - exir_ops.edge.pt2e_quant.dequantize_affine.default, + exir_ops.edge.torchao.quantize_affine.default, + exir_ops.edge.torchao.dequantize_affine.default, } if quant_attrs[QCOM_ENCODING] in per_block_encoding: return self.make_qnn_per_block_config(node, quant_attrs) @@ -271,7 +271,7 @@ def get_quant_tensor_value( axis_order.index(x) for x in range(len(axis_order)) ) tensor = tensor.permute(origin_order) - tensor = torch.ops.pt2e_quant.quantize_affine( + tensor = torch.ops.torchao.quantize_affine( tensor, block_size=quant_attrs[QCOM_BLOCK_SIZE], scale=quant_attrs[QCOM_SCALE], diff --git a/backends/qualcomm/partition/utils.py b/backends/qualcomm/partition/utils.py index 816d1ac1d9b..05bbd1ff970 100644 --- a/backends/qualcomm/partition/utils.py +++ b/backends/qualcomm/partition/utils.py @@ -57,7 +57,7 @@ def 
get_skip_decomp_table() -> List[torch._ops.OperatorBase]: torch.ops.aten.upsample_bicubic2d.vec, # This request is ignored because it is in a blocklist. Refer to exir/program/_program.py torch.ops.aten.unbind.int, - torch.ops.pt2e_quant.quantize_affine.default, - torch.ops.pt2e_quant.dequantize_affine.default, + torch.ops.torchao.quantize_affine.default, + torch.ops.torchao.dequantize_affine.default, ] return do_not_decompose diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index 469a801feeb..01212044d95 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -12,20 +12,20 @@ from torch._ops import OpOverload from torch._subclasses import FakeTensor -from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize +from torch.fx import Node +from torchao.quantization.pt2e.fake_quantize import FixedQParamsFakeQuantize -from torch.ao.quantization.observer import FixedQParamsObserver -from torch.ao.quantization.quantizer import ( +from torchao.quantization.pt2e.observer import FixedQParamsObserver +from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, QuantizationAnnotation, QuantizationSpec, SharedQuantizationSpec, ) -from torch.ao.quantization.quantizer.utils import ( +from torchao.quantization.pt2e.quantizer.utils import ( _annotate_input_qspec_map, _annotate_output_qspec, ) -from torch.fx import Node from .qconfig import ( get_16a16w_qnn_ptq_config, diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index bda91609f1c..2771645c3e0 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -17,13 +17,13 @@ QuantizationConfig, ) from executorch.exir.dialects._ops import ops as exir_ops -from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver -from torch.ao.quantization.quantizer import ( +from torch.fx import Node +from torchao.quantization.pt2e.observer import FixedQParamsObserver, MinMaxObserver +from torchao.quantization.pt2e.quantizer import ( QuantizationAnnotation, QuantizationSpec, SharedQuantizationSpec, ) -from torch.fx import Node def annotate_mimi_decoder(gm: torch.fx.GraphModule): diff --git a/backends/qualcomm/quantizer/observers/per_block_param_observer.py b/backends/qualcomm/quantizer/observers/per_block_param_observer.py index e60f15c6d9c..bf89c46a499 100644 --- a/backends/qualcomm/quantizer/observers/per_block_param_observer.py +++ b/backends/qualcomm/quantizer/observers/per_block_param_observer.py @@ -7,8 +7,8 @@ from typing import Tuple import torch -from torch.ao.quantization.observer import MappingType, PerBlock -from torch.ao.quantization.pt2e._affine_quantization import ( +from torchao.quantization.pt2e.observer import MappingType, PerBlock +from torchao.quantization.pt2e.pt2e._affine_quantization import ( _get_reduction_params, AffineQuantizedMinMaxObserver, choose_qparams_affine_with_min_max, diff --git a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py index 0bba4d5ffeb..222acd32690 100644 --- a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py +++ b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import torch -from torch.ao.quantization.observer import UniformQuantizationObserverBase +from torchao.quantization.pt2e.observer import UniformQuantizationObserverBase # TODO move to torch/ao/quantization/observer.py. diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py index 67968363eb6..e0a838cc32e 100644 --- a/backends/qualcomm/quantizer/qconfig.py +++ b/backends/qualcomm/quantizer/qconfig.py @@ -7,18 +7,21 @@ PerBlockParamObserver, ) from torch import Tensor -from torch.ao.quantization.fake_quantize import ( +from torch.fx import Node +from torchao.quantization.pt2e.fake_quantize import ( FakeQuantize, FusedMovingAvgObsFakeQuantize, ) -from torch.ao.quantization.observer import ( +from torchao.quantization.pt2e.observer import ( MinMaxObserver, MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver, PerChannelMinMaxObserver, ) -from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec -from torch.fx import Node +from torchao.quantization.pt2e.quantizer import ( + DerivedQuantizationSpec, + QuantizationSpec, +) @dataclass(eq=True) diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 8e65607dd84..4a1bca70add 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -9,11 +9,12 @@ from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple import torch +import torchao from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager from torch._ops import OpOverload -from torch.ao.quantization.quantizer import Quantizer from torch.fx import GraphModule +from torchao.quantization.pt2e.quantizer import Quantizer from .annotators import OP_ANNOTATOR @@ -131,7 +132,7 @@ class ModuleQConfig: is_conv_per_channel: bool = False is_linear_per_channel: bool = False act_observer: Optional[ - torch.ao.quantization.observer.UniformQuantizationObserverBase + torchao.quantization.pt2e.observer.UniformQuantizationObserverBase ] = None def __post_init__(self): diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 71d3b9e7ec2..9472fece3a0 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -13,6 +13,7 @@ import numpy as np import torch +import torchao from executorch import exir from executorch.backends.qualcomm.qnn_preprocess import QnnBackend @@ -43,12 +44,12 @@ from executorch.exir.pass_base import ExportPass from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from executorch.exir.program import ExecutorchProgram, ExecutorchProgramManager -from torch.ao.quantization.quantize_pt2e import ( +from torch.fx.passes.infra.pass_base import PassResult +from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, prepare_pt2e, prepare_qat_pt2e, ) -from torch.fx.passes.infra.pass_base import PassResult def generate_context_binary( @@ -536,8 +537,8 @@ def get_qdq_module( torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_channel.default, torch.ops.quantized_decomposed.dequantize_per_channel.default, - torch.ops.pt2e_quant.quantize_affine.default, - torch.ops.pt2e_quant.dequantize_affine.default, + torch.ops.torchao.quantize_affine.default, + torch.ops.torchao.dequantize_affine.default, } if not bypass_check: self.assertTrue(nodes.intersection(q_and_dq)) @@ -568,7 +569,7 @@ def get_prepared_qat_module( quantizer.set_submodule_qconfig_list(submodule_qconfig_list) 
prepared = prepare_qat_pt2e(m, quantizer) - return torch.ao.quantization.move_exported_model_to_train(prepared) + return torchao.quantization.pt2e.move_exported_model_to_train(prepared) def get_converted_sgd_trained_module( self, @@ -583,7 +584,7 @@ def get_converted_sgd_trained_module( optimizer.zero_grad() loss.backward() optimizer.step() - return torch.ao.quantization.quantize_pt2e.convert_pt2e(prepared) + return torchao.quantization.pt2e.quantize_pt2e.convert_pt2e(prepared) def split_graph(self, division: int): class SplitGraph(ExportPass): diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index e0ebc5beebe..cf9a28dec63 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -601,8 +601,8 @@ def skip_annotation( from executorch.backends.qualcomm.serialization.qc_schema_serialize import ( flatbuffer_to_option, ) - from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner + from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def prepare_subgm(subgm, subgm_name): # prepare current submodule for quantization annotation diff --git a/backends/transforms/duplicate_dynamic_quant_chain.py b/backends/transforms/duplicate_dynamic_quant_chain.py index 2ca65eec45f..aac5eec7024 100644 --- a/backends/transforms/duplicate_dynamic_quant_chain.py +++ b/backends/transforms/duplicate_dynamic_quant_chain.py @@ -9,14 +9,14 @@ import torch -from torch.ao.quantization.pt2e.utils import ( +from torch.fx.node import map_arg +from torch.fx.passes.infra.pass_base import PassBase, PassResult + +from torchao.quantization.pt2e.pt2e.utils import ( _filter_sym_size_users, _is_valid_annotation, ) -from torch.fx.node import map_arg -from torch.fx.passes.infra.pass_base import PassBase, PassResult - logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) diff --git a/backends/transforms/test/test_duplicate_dynamic_quant_chain.py b/backends/transforms/test/test_duplicate_dynamic_quant_chain.py index ab965dd347d..79bc56f8780 100644 --- a/backends/transforms/test/test_duplicate_dynamic_quant_chain.py +++ b/backends/transforms/test/test_duplicate_dynamic_quant_chain.py @@ -15,7 +15,6 @@ get_symmetric_quantization_config, XNNPACKQuantizer, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e # TODO: Move away from using torch's internal testing utils from torch.testing._internal.common_quantization import ( @@ -23,6 +22,7 @@ QuantizationTestCase, TestHelperModules, ) +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e class MyTestHelperModules: diff --git a/backends/vulkan/quantizer/vulkan_quantizer.py b/backends/vulkan/quantizer/vulkan_quantizer.py index 2ea3e321dc3..2748a12fb12 100644 --- a/backends/vulkan/quantizer/vulkan_quantizer.py +++ b/backends/vulkan/quantizer/vulkan_quantizer.py @@ -18,10 +18,10 @@ propagate_annotation, QuantizationConfig, ) -from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver -from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer from torch.fx import Node +from torchao.quantization.pt2e import _ObserverOrFakeQuantizeConstructor +from torchao.quantization.pt2e.observer import MinMaxObserver, PerChannelMinMaxObserver +from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer __all__ = [ diff --git 
a/backends/xnnpack/quantizer/xnnpack_quantizer.py b/backends/xnnpack/quantizer/xnnpack_quantizer.py index 0ddee53a41a..1a0c4fbc007 100644 --- a/backends/xnnpack/quantizer/xnnpack_quantizer.py +++ b/backends/xnnpack/quantizer/xnnpack_quantizer.py @@ -16,11 +16,11 @@ propagate_annotation, QuantizationConfig, ) -from torch.ao.quantization.fake_quantize import ( +from torchao.quantization.pt2e.fake_quantize import ( FakeQuantize, FusedMovingAvgObsFakeQuantize, ) -from torch.ao.quantization.observer import ( +from torchao.quantization.pt2e.observer import ( HistogramObserver, MinMaxObserver, MovingAverageMinMaxObserver, @@ -28,13 +28,13 @@ PerChannelMinMaxObserver, PlaceholderObserver, ) -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer -from torch.ao.quantization.quantizer.utils import _get_module_name_filter +from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer +from torchao.quantization.pt2e.quantizer.utils import _get_module_name_filter if TYPE_CHECKING: - from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor from torch.fx import Node + from torchao.quantization.pt2e import _ObserverOrFakeQuantizeConstructor __all__ = [ diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py index ce459806c6e..0014c1c3536 100644 --- a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py +++ b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py @@ -8,26 +8,26 @@ import torch.nn.functional as F from torch._subclasses import FakeTensor from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix -from torch.ao.quantization.pt2e.export_utils import _WrapperModule -from torch.ao.quantization.pt2e.utils import ( +from torch.fx import Node +from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( + SubgraphMatcherWithNameNodeMap, +) +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions +from torchao.quantization.pt2e.pt2e.export_utils import _WrapperModule +from torchao.quantization.pt2e.pt2e.utils import ( _get_aten_graph_module_for_pattern, _is_conv_node, _is_conv_transpose_node, ) -from torch.ao.quantization.quantizer import ( +from torchao.quantization.pt2e.quantizer import ( QuantizationAnnotation, QuantizationSpec, SharedQuantizationSpec, ) -from torch.ao.quantization.quantizer.utils import ( +from torchao.quantization.pt2e.quantizer.utils import ( _annotate_input_qspec_map, _annotate_output_qspec, ) -from torch.fx import Node -from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( - SubgraphMatcherWithNameNodeMap, -) -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions __all__ = [ diff --git a/backends/xnnpack/test/ops/test_check_quant_params.py b/backends/xnnpack/test/ops/test_check_quant_params.py index d05b1fce540..8be59aab50e 100644 --- a/backends/xnnpack/test/ops/test_check_quant_params.py +++ b/backends/xnnpack/test/ops/test_check_quant_params.py @@ -9,8 +9,8 @@ ) from executorch.backends.xnnpack.utils.utils import get_param_tensor from executorch.exir import to_edge_transform_and_lower -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export_for_training +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e class TestCheckQuantParams(unittest.TestCase): diff --git a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py index 
34b6f745044..f4bdbd7beb3 100644 --- a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py +++ b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py @@ -10,21 +10,12 @@ from typing import Dict, Tuple import torch +import torchao from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) -from torch.ao.quantization import ( - compare_results, - CUSTOM_KEY, - default_per_channel_symmetric_qnnpack_qconfig, - extract_results_from_loggers, - generate_numeric_debug_handle, - NUMERIC_DEBUG_HANDLE_KEY, - observer, - prepare_for_propagation_comparison, -) -from torch.ao.quantization.pt2e.graph_utils import bfs_trace_with_node_process +from torch.ao.quantization import default_per_channel_symmetric_qnnpack_qconfig from torch.ao.quantization.qconfig import ( float_qparams_weight_only_qconfig, per_channel_weight_observer_range_neg_127_to_127, @@ -32,14 +23,6 @@ weight_observer_range_neg_127_to_127, ) from torch.ao.quantization.qconfig_mapping import QConfigMapping -from torch.ao.quantization.quantize_pt2e import ( - convert_pt2e, - prepare_pt2e, - prepare_qat_pt2e, -) -from torch.ao.quantization.quantizer import Quantizer -from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer -from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, @@ -51,6 +34,24 @@ TemporaryFileName, TestCase, ) +from torchao.quantization.pt2e import ( + compare_results, + CUSTOM_KEY, + extract_results_from_loggers, + generate_numeric_debug_handle, + NUMERIC_DEBUG_HANDLE_KEY, + observer, + prepare_for_propagation_comparison, +) +from torchao.quantization.pt2e.pt2e.graph_utils import bfs_trace_with_node_process +from torchao.quantization.pt2e.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) +from torchao.quantization.pt2e.quantizer import Quantizer +from torchao.quantization.pt2e.quantizer.composable_quantizer import ComposableQuantizer +from torchao.quantization.pt2e.quantizer.embedding_quantizer import EmbeddingQuantizer class TestQuantizePT2E(PT2EQuantizationTestCase): @@ -396,7 +397,7 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None: m.train() # After wrapping: does not error and swaps the ops accordingly - torch.ao.quantization.allow_exported_model_train_eval(m) # pyre-ignore[6] + torchao.quantization.pt2e.allow_exported_model_train_eval(m) # pyre-ignore[6] m.eval() _assert_ops_are_correct(m, train=False) # pyre-ignore[6] m.train() @@ -411,7 +412,7 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None: m.train() # After prepare and after wrapping: does not error and swaps the ops accordingly - torch.ao.quantization.allow_exported_model_train_eval(m) + torchao.quantization.pt2e.allow_exported_model_train_eval(m) m.eval() _assert_ops_are_correct(m, train=False) m.train() @@ -425,7 +426,7 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None: m.train() # After convert and after wrapping: does not error and swaps the ops accordingly - torch.ao.quantization.allow_exported_model_train_eval(m) + torchao.quantization.pt2e.allow_exported_model_train_eval(m) m.eval() _assert_ops_are_correct(m, train=False) m.train() diff --git a/backends/xnnpack/test/quantizer/test_representation.py b/backends/xnnpack/test/quantizer/test_representation.py index e52bbbd7ae7..817f7f9e368 100644 --- 
a/backends/xnnpack/test/quantizer/test_representation.py +++ b/backends/xnnpack/test/quantizer/test_representation.py @@ -8,8 +8,6 @@ XNNPACKQuantizer, ) from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401 -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer import Quantizer from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, @@ -17,6 +15,8 @@ skipIfNoQNNPACK, TestHelperModules, ) +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import Quantizer @skipIfNoQNNPACK diff --git a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py index 856030755af..9a1191e51f6 100644 --- a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py +++ b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py @@ -9,13 +9,7 @@ XNNPACKQuantizer, ) from torch.ao.ns.fx.utils import compute_sqnr -from torch.ao.quantization import ( - default_dynamic_fake_quant, - default_dynamic_qconfig, - observer, - QConfig, - QConfigMapping, -) +from torch.ao.quantization import default_dynamic_qconfig, QConfig, QConfigMapping from torch.ao.quantization.backend_config import get_qnnpack_backend_config from torch.ao.quantization.qconfig import ( default_per_channel_symmetric_qnnpack_qconfig, @@ -38,6 +32,8 @@ TestHelperModules, ) from torch.testing._internal.common_quantized import override_quantized_engine +from torchao.quantization.pt2e import observer +from torchao.quantization.pt2e.fake_quantize import default_dynamic_fake_quant @skipIfNoQNNPACK diff --git a/backends/xnnpack/test/test_xnnpack_utils.py b/backends/xnnpack/test/test_xnnpack_utils.py index 3ff2f0e4c1e..e6c97545d82 100644 --- a/backends/xnnpack/test/test_xnnpack_utils.py +++ b/backends/xnnpack/test/test_xnnpack_utils.py @@ -47,7 +47,6 @@ from torch.ao.quantization import ( # @manual default_per_channel_symmetric_qnnpack_qconfig, - PlaceholderObserver, QConfig, QConfigMapping, ) @@ -55,12 +54,6 @@ from torch.ao.quantization.backend_config.executorch import ( get_executorch_backend_config, ) - -from torch.ao.quantization.observer import ( - per_channel_weight_observer_range_neg_127_to_127, - # default_weight_observer, - weight_observer_range_neg_127_to_127, -) from torch.ao.quantization.qconfig_mapping import ( _get_default_qconfig_mapping_with_default_qconfig, _get_symmetric_qnnpack_qconfig_mapping, @@ -70,11 +63,18 @@ _convert_to_reference_decomposed_fx, prepare_fx, ) - -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export_for_training from torch.testing import FileCheck +from torchao.quantization.pt2e import PlaceholderObserver + +from torchao.quantization.pt2e.observer import ( + per_channel_weight_observer_range_neg_127_to_127, + # default_weight_observer, + weight_observer_range_neg_127_to_127, +) + +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def randomize_bn(num_features: int, dimensionality: int = 2) -> torch.nn.Module: diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index cbce817cf4b..fd48837bd72 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -55,11 +55,11 @@ ) from executorch.exir.program._program import _transform from torch._export.pass_base import PassType -from 
torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer.quantizer import Quantizer from torch.export import export, ExportedProgram from torch.testing import FileCheck from torch.utils._pytree import tree_flatten +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer.quantizer import Quantizer class Stage(ABC): diff --git a/docs/source/backends-coreml.md b/docs/source/backends-coreml.md index 2d5e4256bcc..c7dbd91fb03 100644 --- a/docs/source/backends-coreml.md +++ b/docs/source/backends-coreml.md @@ -104,7 +104,7 @@ import torchvision.models as models from torchvision.models.mobilenetv2 import MobileNet_V2_Weights from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer from executorch.backends.apple.coreml.partition import CoreMLPartitioner -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from executorch.exir import to_edge_transform_and_lower from executorch.backends.apple.coreml.compiler import CoreMLBackend diff --git a/docs/source/backends-xnnpack.md b/docs/source/backends-xnnpack.md index db1c055dc9c..85919952988 100644 --- a/docs/source/backends-xnnpack.md +++ b/docs/source/backends-xnnpack.md @@ -1,6 +1,6 @@ # XNNPACK Backend -The XNNPACK delegate is the ExecuTorch solution for CPU execution on mobile CPUs. [XNNPACK](https://github.com/google/XNNPACK/tree/master) is a library that provides optimized kernels for machine learning operators on Arm and x86 CPUs. +The XNNPACK delegate is the ExecuTorch solution for CPU execution on mobile CPUs. [XNNPACK](https://github.com/google/XNNPACK/tree/master) is a library that provides optimized kernels for machine learning operators on Arm and x86 CPUs. ## Features @@ -18,7 +18,7 @@ The XNNPACK delegate is the ExecuTorch solution for CPU execution on mobile CPUs ## Development Requirements -The XNNPACK delegate does not introduce any development system requirements beyond those required by +The XNNPACK delegate does not introduce any development system requirements beyond those required by the core ExecuTorch runtime. ---- @@ -63,7 +63,7 @@ After generating the XNNPACK-delegated .pte, the model can be tested from Python ## Quantization -The XNNPACK delegate can also be used as a backend to execute symmetrically quantized models. To quantize a PyTorch model for the XNNPACK backend, use the `XNNPACKQuantizer`. `Quantizers` are backend specific, which means the `XNNPACKQuantizer` is configured to quantize models to leverage the quantized operators offered by the XNNPACK Library. +The XNNPACK delegate can also be used as a backend to execute symmetrically quantized models. To quantize a PyTorch model for the XNNPACK backend, use the `XNNPACKQuantizer`. `Quantizers` are backend specific, which means the `XNNPACKQuantizer` is configured to quantize models to leverage the quantized operators offered by the XNNPACK Library. 
### Supported Quantization Schemes The XNNPACK delegate supports the following quantization schemes: @@ -94,8 +94,8 @@ from torchvision.models.mobilenetv2 import MobileNet_V2_Weights from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import XNNPACKQuantizer from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner from executorch.exir import to_edge_transform_and_lower -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer.xnnpack_quantizer import get_symmetric_quantization_config +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import get_symmetric_quantization_config model = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval() sample_inputs = (torch.randn(1, 3, 224, 224), ) diff --git a/docs/source/llm/getting-started.md b/docs/source/llm/getting-started.md index ea90f346279..9198cd324dc 100644 --- a/docs/source/llm/getting-started.md +++ b/docs/source/llm/getting-started.md @@ -619,7 +619,7 @@ from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e ``` ```python diff --git a/docs/source/tutorial-xnnpack-delegate-lowering.md b/docs/source/tutorial-xnnpack-delegate-lowering.md index 7fc97476ef7..64c997d3cee 100644 --- a/docs/source/tutorial-xnnpack-delegate-lowering.md +++ b/docs/source/tutorial-xnnpack-delegate-lowering.md @@ -85,7 +85,7 @@ sample_inputs = (torch.randn(1, 3, 224, 224), ) mobilenet_v2 = export_for_training(mobilenet_v2, sample_inputs).module() # 2-stage export for quantization path -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, diff --git a/docs/source/tutorials_source/export-to-executorch-tutorial.py b/docs/source/tutorials_source/export-to-executorch-tutorial.py index de42cb51bce..2ca6a207d17 100644 --- a/docs/source/tutorials_source/export-to-executorch-tutorial.py +++ b/docs/source/tutorials_source/export-to-executorch-tutorial.py @@ -200,7 +200,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: get_symmetric_quantization_config, XNNPACKQuantizer, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer) # type: ignore[arg-type] diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 446d1a4eca4..57e74662328 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -47,10 +47,10 @@ from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.extension.export_util.utils import save_pte_program from tabulate import tabulate +from torch.utils.data import DataLoader # Quantize model if required using the standard export quantizaion flow. 
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.utils.data import DataLoader +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from ..models import MODEL_NAME_TO_MODEL from ..models.model_factory import EagerModelFactory diff --git a/examples/arm/ethos_u_minimal_example.ipynb b/examples/arm/ethos_u_minimal_example.ipynb index d73695e9d48..79a3d9d4c95 100644 --- a/examples/arm/ethos_u_minimal_example.ipynb +++ b/examples/arm/ethos_u_minimal_example.ipynb @@ -84,12 +84,12 @@ " EthosUQuantizer,\n", " get_symmetric_quantization_config,\n", ")\n", - "from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e\n", + "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n", "\n", "target = \"ethos-u55-128\"\n", "\n", "# Create a compilation spec describing the target for configuring the quantizer\n", - "# Some args are used by the Arm Vela graph compiler later in the example. Refer to Arm Vela documentation for an \n", + "# Some args are used by the Arm Vela graph compiler later in the example. Refer to Arm Vela documentation for an\n", "# explanation of its flags: https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/blob/main/OPTIONS.md\n", "spec_builder = ArmCompileSpecBuilder().ethosu_compile_spec(\n", " target,\n", @@ -100,12 +100,12 @@ "compile_spec = spec_builder.build()\n", "\n", "# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n", - "quantizer = EthosUQuantizer(compile_spec) \n", + "quantizer = EthosUQuantizer(compile_spec)\n", "operator_config = get_symmetric_quantization_config(is_per_channel=False)\n", "quantizer.set_global(operator_config)\n", "\n", "# Post training quantization\n", - "quantized_graph_module = prepare_pt2e(graph_module, quantizer) \n", + "quantized_graph_module = prepare_pt2e(graph_module, quantizer)\n", "quantized_graph_module(*example_inputs) # Calibrate the graph module with the example input\n", "quantized_graph_module = convert_pt2e(quantized_graph_module)\n", "\n", @@ -128,8 +128,8 @@ "metadata": {}, "outputs": [], "source": [ - "import subprocess \n", - "import os \n", + "import subprocess\n", + "import os\n", "\n", "# Setup paths\n", "cwd_dir = os.getcwd()\n", @@ -170,9 +170,9 @@ " to_edge_transform_and_lower,\n", ")\n", "from executorch.extension.export_util.utils import save_pte_program\n", - "import platform \n", + "import platform\n", "\n", - "# Create partitioner from compile spec \n", + "# Create partitioner from compile spec\n", "partitioner = EthosUPartitioner(compile_spec)\n", "\n", "# Lower the exported program to the Ethos-U backend\n", @@ -185,8 +185,8 @@ " )\n", "\n", "# Load quantization ops library\n", - "os_aot_lib_names = {\"Darwin\" : \"libquantized_ops_aot_lib.dylib\", \n", - " \"Linux\" : \"libquantized_ops_aot_lib.so\", \n", + "os_aot_lib_names = {\"Darwin\" : \"libquantized_ops_aot_lib.dylib\",\n", + " \"Linux\" : \"libquantized_ops_aot_lib.so\",\n", " \"Windows\": \"libquantized_ops_aot_lib.dll\"}\n", "aot_lib_name = os_aot_lib_names[platform.system()]\n", "\n", @@ -226,7 +226,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Build executorch \n", + "# Build executorch\n", "subprocess.run(os.path.join(script_dir, \"build_executorch.sh\"), shell=True, cwd=et_dir)\n", "\n", "# Build portable kernels\n", diff --git a/examples/mediatek/aot_utils/oss_utils/utils.py b/examples/mediatek/aot_utils/oss_utils/utils.py index 25362788e31..60b8f2fe294 100755 --- 
a/examples/mediatek/aot_utils/oss_utils/utils.py +++ b/examples/mediatek/aot_utils/oss_utils/utils.py @@ -14,7 +14,7 @@ NeuropilotQuantizer, Precision, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def build_executorch_binary( diff --git a/examples/mediatek/model_export_scripts/llama.py b/examples/mediatek/model_export_scripts/llama.py index 413df21d5cc..a258569b2d7 100644 --- a/examples/mediatek/model_export_scripts/llama.py +++ b/examples/mediatek/model_export_scripts/llama.py @@ -43,7 +43,7 @@ Precision, ) from executorch.exir.backend.backend_details import CompileSpec -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from tqdm import tqdm warnings.filterwarnings("ignore") diff --git a/examples/models/moshi/mimi/test_mimi.py b/examples/models/moshi/mimi/test_mimi.py index 7e2cfb14c49..be3c075913d 100644 --- a/examples/models/moshi/mimi/test_mimi.py +++ b/examples/models/moshi/mimi/test_mimi.py @@ -19,9 +19,9 @@ from huggingface_hub import hf_hub_download from moshi.models import loaders -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export, ExportedProgram from torch.utils._pytree import tree_flatten +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e proxies = { "http": "http://fwdproxy:8080", diff --git a/examples/models/phi-3-mini/export_phi-3-mini.py b/examples/models/phi-3-mini/export_phi-3-mini.py index 11c2f3834eb..246b3ccd6c6 100644 --- a/examples/models/phi-3-mini/export_phi-3-mini.py +++ b/examples/models/phi-3-mini/export_phi-3-mini.py @@ -20,8 +20,8 @@ ) from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config from executorch.exir import to_edge -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export_for_training +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from transformers import Phi3ForCausalLM diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 375edf9fb6c..42f41672780 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -81,8 +81,8 @@ from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer -from torch.ao.quantization.observer import MinMaxObserver -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.observer import MinMaxObserver +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e sys.setrecursionlimit(4096) FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" diff --git a/examples/qualcomm/oss_scripts/moshi/mimi.py b/examples/qualcomm/oss_scripts/moshi/mimi.py index 6b59a71ae64..d67d378f0ce 100644 --- a/examples/qualcomm/oss_scripts/moshi/mimi.py +++ b/examples/qualcomm/oss_scripts/moshi/mimi.py @@ -37,7 +37,7 @@ from huggingface_hub import hf_hub_download from moshi.models import loaders -from torch.ao.quantization.observer import MinMaxObserver +from torchao.quantization.pt2e.observer import MinMaxObserver def seed_all(seed): diff --git a/examples/qualcomm/scripts/export_example.py b/examples/qualcomm/scripts/export_example.py index 
acf8a9ab468..515fdda8b41 100644 --- a/examples/qualcomm/scripts/export_example.py +++ b/examples/qualcomm/scripts/export_example.py @@ -16,7 +16,7 @@ from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.extension.export_util.utils import save_pte_program -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def main() -> None: diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 242170712e1..ad2df4f53ba 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -19,6 +19,7 @@ import numpy as np import torch +import torchao from executorch.backends.qualcomm.quantizer.quantizer import ( ModuleQConfig, QnnQuantizer, @@ -33,8 +34,8 @@ ) from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass -from torch.ao.quantization.observer import MovingAverageMinMaxObserver -from torch.ao.quantization.quantize_pt2e import ( +from torchao.quantization.pt2e.observer import MovingAverageMinMaxObserver +from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, prepare_pt2e, prepare_qat_pt2e, @@ -230,7 +231,7 @@ def ptq_calibrate(captured_model, quantizer, dataset): def qat_train(ori_model, captured_model, quantizer, dataset): data, targets = dataset - annotated_model = torch.ao.quantization.move_exported_model_to_train( + annotated_model = torchao.quantization.pt2e.move_exported_model_to_train( prepare_qat_pt2e(captured_model, quantizer) ) optimizer = torch.optim.SGD(annotated_model.parameters(), lr=0.00001) @@ -239,7 +240,9 @@ def qat_train(ori_model, captured_model, quantizer, dataset): print(f"Epoch {i}") if i > 3: # Freeze quantizer parameters - annotated_model.apply(torch.ao.quantization.disable_observer) + annotated_model.apply( + torchao.quantization.pt2e.fake_quantize.disable_observer + ) if i > 2: # Freeze batch norm mean and variance estimates annotated_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) @@ -250,8 +253,8 @@ def qat_train(ori_model, captured_model, quantizer, dataset): loss.backward() optimizer.step() - return torch.ao.quantization.quantize_pt2e.convert_pt2e( - torch.ao.quantization.move_exported_model_to_eval(annotated_model) + return torchao.quantization.pt2e.quantize_pt2e.convert_pt2e( + torchao.quantization.pt2e.move_exported_model_to_eval(annotated_model) ) diff --git a/examples/xnnpack/quantization/example.py b/examples/xnnpack/quantization/example.py index 90a6b94d02b..93831ab8252 100644 --- a/examples/xnnpack/quantization/example.py +++ b/examples/xnnpack/quantization/example.py @@ -29,7 +29,7 @@ _convert_to_reference_decomposed_fx, prepare_fx, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from ...models import MODEL_NAME_TO_MODEL from ...models.model_factory import EagerModelFactory diff --git a/examples/xnnpack/quantization/utils.py b/examples/xnnpack/quantization/utils.py index 9e49f15a99d..d7648daf5da 100644 --- a/examples/xnnpack/quantization/utils.py +++ b/examples/xnnpack/quantization/utils.py @@ -11,7 +11,7 @@ XNNPACKQuantizer, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from ..
import QuantType diff --git a/exir/backend/test/demos/test_xnnpack_qnnpack.py b/exir/backend/test/demos/test_xnnpack_qnnpack.py index 5cbd7f7f659..46b665c03df 100644 --- a/exir/backend/test/demos/test_xnnpack_qnnpack.py +++ b/exir/backend/test/demos/test_xnnpack_qnnpack.py @@ -31,15 +31,15 @@ from torch.ao.quantization.backend_config.executorch import ( get_executorch_backend_config, ) -from torch.ao.quantization.observer import ( - default_dynamic_quant_observer, - default_per_channel_weight_observer, -) from torch.ao.quantization.qconfig_mapping import QConfig, QConfigMapping from torch.ao.quantization.quantize_fx import ( _convert_to_reference_decomposed_fx, prepare_fx, ) +from torchao.quantization.pt2e.observer import ( + default_dynamic_quant_observer, + default_per_channel_weight_observer, +) class TestXnnQnnBackends(unittest.TestCase): diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py index b87ae2dfb58..4f7bccbbf92 100644 --- a/exir/tests/test_memory_planning.py +++ b/exir/tests/test_memory_planning.py @@ -41,10 +41,6 @@ from torch.ao.quantization.backend_config.executorch import ( get_executorch_backend_config, ) -from torch.ao.quantization.observer import ( - default_dynamic_quant_observer, - default_per_channel_weight_observer, -) from torch.ao.quantization.qconfig_mapping import QConfig, QConfigMapping from torch.ao.quantization.quantize_fx import ( _convert_to_reference_decomposed_fx, @@ -55,6 +51,10 @@ from torch.export.exported_program import ExportGraphSignature from torch.fx import Graph, GraphModule, Node from torch.nn import functional as F +from torchao.quantization.pt2e.observer import ( + default_dynamic_quant_observer, + default_per_channel_weight_observer, +) torch.ops.load_library("//executorch/kernels/portable:custom_ops_generated_lib") diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 887ca39864a..caec41106e2 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -71,9 +71,6 @@ from functorch.experimental import control_flow from torch import nn - -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer import QuantizationSpec from torch.export import export from torch.export.graph_signature import InputKind, InputSpec, TensorArgument from torch.fx import GraphModule, subgraph_rewriter @@ -82,6 +79,9 @@ from torch.testing import FileCheck from torch.utils import _pytree as pytree +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import QuantizationSpec + # pyre-ignore def collect_ops(gm: torch.fx.GraphModule): @@ -1169,7 +1169,7 @@ def forward(self, query, key, value): ).module() # 8w16a quantization - from torch.ao.quantization.observer import ( + from torchao.quantization.pt2e.observer import ( MinMaxObserver, PerChannelMinMaxObserver, ) diff --git a/exir/tests/test_quantization.py b/exir/tests/test_quantization.py index 0a0a85077bb..fbd015c1e7a 100644 --- a/exir/tests/test_quantization.py +++ b/exir/tests/test_quantization.py @@ -19,18 +19,17 @@ from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.spec_prop_pass import SpecPropPass from torch.ao.ns.fx.utils import compute_sqnr -from torch.ao.quantization import QConfigMapping # @manual +from torch.ao.quantization import ( # @manual + _convert_to_reference_decomposed_fx, + QConfigMapping, +) from torch.ao.quantization.backend_config import 
get_executorch_backend_config from torch.ao.quantization.qconfig import default_per_channel_symmetric_qnnpack_qconfig from torch.ao.quantization.quantize_fx import prepare_fx -from torch.ao.quantization.quantize_pt2e import ( - _convert_to_reference_decomposed_fx, - convert_pt2e, - prepare_pt2e, -) from torch.export import export from torch.testing import FileCheck from torch.testing._internal.common_quantized import override_quantized_engine +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e # load executorch out variant ops torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib") diff --git a/exir/tests/test_quantize_io_pass.py b/exir/tests/test_quantize_io_pass.py index ddc0294ba68..f670594616a 100644 --- a/exir/tests/test_quantize_io_pass.py +++ b/exir/tests/test_quantize_io_pass.py @@ -20,8 +20,8 @@ QuantizeOutputs, ) from executorch.exir.tensor import get_scalar_type -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.testing import FileCheck +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e op_str = { "q": "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default", diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 9be7bbeee4b..aa621a78931 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -36,11 +36,11 @@ from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes from pytorch_tokenizers import get_tokenizer -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer import Quantizer -from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer from torch.export import export_for_training, ExportedProgram from torch.nn.attention import SDPBackend +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import Quantizer +from torchao.quantization.pt2e.quantizer.composable_quantizer import ComposableQuantizer FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index 24c3be2e802..b5b887324fa 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -16,8 +16,8 @@ XNNPACKQuantizer, ) -from torch.ao.quantization.quantizer import Quantizer -from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer +from torchao.quantization.pt2e.quantizer import Quantizer +from torchao.quantization.pt2e.quantizer.embedding_quantizer import EmbeddingQuantizer FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) @@ -154,7 +154,7 @@ def get_qnn_quantizer( QnnQuantizer, QuantDtype, ) - from torch.ao.quantization.observer import MinMaxObserver + from torchao.quantization.pt2e.observer import MinMaxObserver except ImportError: raise ImportError(