[not yet for review] migrate pt2e from torch.ao to torchao #10294

Open · wants to merge 1 commit into base: main
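The change itself is a mechanical import migration: symbols previously imported from torch.ao.quantization.* now come from torchao.quantization.pt2e.*, and the edge-dialect affine ops move from the pt2e_quant namespace to torchao. As a minimal sketch of the workflow these imports serve (the quantizer argument is a placeholder; any of the backend quantizers touched below would do):

    import torch
    from torch.export import export_for_training
    from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

    def quantize_with_pt2e(model: torch.nn.Module, example_inputs, quantizer):
        # Capture a training-time graph, annotate it via the quantizer,
        # run a calibration pass, then convert to the quantized representation.
        exported = export_for_training(model, example_inputs).module()
        prepared = prepare_pt2e(exported, quantizer)
        prepared(*example_inputs)  # calibration
        return convert_pt2e(prepared)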
4 changes: 2 additions & 2 deletions backends/apple/coreml/test/test_coreml_quantizer.py
@@ -15,12 +15,12 @@
)

from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
from torch.ao.quantization.quantize_pt2e import (
from torch.export import export_for_training
from torchao.quantization.pt2e.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
prepare_qat_pt2e,
)
from torch.export import export_for_training


class TestCoreMLQuantizer:
12 changes: 6 additions & 6 deletions backends/arm/quantizer/arm_quantizer.py
@@ -34,25 +34,25 @@
is_ethosu,
) # usort: skip
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.ao.quantization.fake_quantize import (
from torch.fx import GraphModule, Node
from torchao.quantization.pt2e import _ObserverOrFakeQuantizeConstructor

Check failure on line 38 in backends/arm/quantizer/arm_quantizer.py (GitHub Actions lintrunner, MYPY import-untyped): Skipping analyzing "torchao.quantization.pt2e": module is installed, but missing library stubs or py.typed marker. To disable, use `# type: ignore[import-untyped]`.
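Following the mypy hint, the new import could be silenced inline, for example:

    from torchao.quantization.pt2e import _ObserverOrFakeQuantizeConstructor  # type: ignore[import-untyped]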
Contributor: Correct me if I am wrong, but torchao isn't a mandatory dependency today; with this change it would be?

Contributor (author): How do we define mandatory dependencies? Isn't it installed by the install_requirements script?

Contributor: Seems like we pull in torchao from source (url = https://github.com/pytorch/ao.git), so the submodule is already updated, given the tests are passing here. Two things to check: (1) whether we run tests on ExecuTorch wheels that exercise quantization, and (2) if we do, whether they pass for this diff.
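Not part of this PR, just for the dependency discussion: if torchao were treated as optional rather than mandatory, the usual pattern would be a guarded import along these lines (the flag name is illustrative):

    try:
        from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

        HAS_TORCHAO_PT2E = True
    except ImportError:
        # torchao is not installed; pt2e quantization paths are unavailable.
        HAS_TORCHAO_PT2E = False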

from torchao.quantization.pt2e.fake_quantize import (

Check failure on line 39 in backends/arm/quantizer/arm_quantizer.py (GitHub Actions lintrunner, MYPY import-untyped): Skipping analyzing "torchao.quantization.pt2e.fake_quantize": module is installed, but missing library stubs or py.typed marker. To disable, use `# type: ignore[import-untyped]`.
FakeQuantize,
FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import (
from torchao.quantization.pt2e.observer import (
HistogramObserver,
MinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
PlaceholderObserver,
)
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.utils import (
from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer
from torchao.quantization.pt2e.quantizer.utils import (
_annotate_input_qspec_map,
_annotate_output_qspec,
)
from torch.fx import GraphModule, Node

__all__ = [
"TOSAQuantizer",
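For reference, the objects built from these imports are unchanged by the move; a spec constructed from the new paths would look roughly like this (field names assumed to match the pre-migration torch.ao.quantization API, values illustrative rather than this PR's defaults):

    import torch
    from torchao.quantization.pt2e.observer import HistogramObserver
    from torchao.quantization.pt2e.quantizer import QuantizationSpec

    # Example 8-bit activation spec built from the new import roots.
    act_qspec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
    )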
4 changes: 2 additions & 2 deletions backends/arm/quantizer/arm_quantizer_utils.py
@@ -15,10 +15,10 @@

import torch
from torch._subclasses import FakeTensor

from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.fx import GraphModule, Node

from torchao.quantization.pt2e.quantizer import QuantizationAnnotation

Check failure on line 20 in backends/arm/quantizer/arm_quantizer_utils.py (GitHub Actions lintrunner, MYPY import-untyped): Skipping analyzing "torchao.quantization.pt2e.quantizer": module is installed, but missing library stubs or py.typed marker. To disable, use `# type: ignore[import-untyped]`.


def is_annotated(node: Node) -> bool:
"""Given a node return whether the node is annotated."""
9 changes: 6 additions & 3 deletions backends/arm/quantizer/quantization_annotator.py
@@ -13,12 +13,15 @@
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from executorch.backends.arm.tosa_utils import get_node_debug_info
from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
from torch.ao.quantization.quantizer.utils import (
from torch.fx import Node
from torchao.quantization.pt2e.quantizer import (

Check failure on line 17 in backends/arm/quantizer/quantization_annotator.py (GitHub Actions lintrunner, MYPY import-untyped): Skipping analyzing "torchao.quantization.pt2e.quantizer": module is installed, but missing library stubs or py.typed marker. To disable, use `# type: ignore[import-untyped]`.
QuantizationSpecBase,
SharedQuantizationSpec,
)
from torchao.quantization.pt2e.quantizer.utils import (

Check failure on line 21 in backends/arm/quantizer/quantization_annotator.py (GitHub Actions lintrunner, MYPY import-untyped): Skipping analyzing "torchao.quantization.pt2e.quantizer.utils": module is installed, but missing library stubs or py.typed marker. To disable, use `# type: ignore[import-untyped]`.
_annotate_input_qspec_map,
_annotate_output_qspec,
)
from torch.fx import Node

logger = logging.getLogger(__name__)

4 changes: 2 additions & 2 deletions backends/arm/quantizer/quantization_config.py
@@ -9,9 +9,9 @@
from dataclasses import dataclass

import torch
from torch.ao.quantization import ObserverOrFakeQuantize
from torchao.quantization.pt2e import ObserverOrFakeQuantize

Check failure on line 12 in backends/arm/quantizer/quantization_config.py (GitHub Actions lintrunner, MYPY import-untyped): Skipping analyzing "torchao.quantization.pt2e": module is installed, but missing library stubs or py.typed marker. To disable, use `# type: ignore[import-untyped]`.

from torch.ao.quantization.quantizer import (
from torchao.quantization.pt2e.quantizer import (

Check failure on line 14 in backends/arm/quantizer/quantization_config.py (GitHub Actions lintrunner, MYPY import-untyped): Skipping analyzing "torchao.quantization.pt2e.quantizer": module is installed, but missing library stubs or py.typed marker. To disable, use `# type: ignore[import-untyped]`.
DerivedQuantizationSpec,
FixedQParamsQuantizationSpec,
QuantizationSpec,
4 changes: 2 additions & 2 deletions backends/arm/test/ops/test_add.py
@@ -18,8 +18,8 @@
TosaPipelineMI,
)
from executorch.backends.xnnpack.test.tester import Quantize
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer import QuantizationSpec
from torchao.quantization.pt2e.observer import HistogramObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec


aten_op = "torch.ops.aten.add.Tensor"
4 changes: 2 additions & 2 deletions backends/arm/test/ops/test_sigmoid_16bit.py
@@ -18,8 +18,8 @@
TosaPipelineBI,
)
from executorch.backends.xnnpack.test.tester import Quantize
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer import QuantizationSpec
from torchao.quantization.pt2e.observer import HistogramObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec


def _get_16_bit_quant_config():
4 changes: 2 additions & 2 deletions backends/arm/test/ops/test_sigmoid_32bit.py
@@ -14,8 +14,8 @@
TosaPipelineBI,
)
from executorch.backends.xnnpack.test.tester import Quantize
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer import QuantizationSpec
from torchao.quantization.pt2e.observer import HistogramObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec


def _get_16_bit_quant_config():
2 changes: 1 addition & 1 deletion backends/cadence/aot/compiler.py
@@ -37,10 +37,10 @@
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
from executorch.exir.program._program import to_edge_with_preserved_ops
from torch._inductor.decomposition import remove_decompositions
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from torch.export import export
from torch.export.exported_program import ExportedProgram
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

from .passes import get_cadence_passes

2 changes: 1 addition & 1 deletion backends/cadence/aot/quantizer/patterns.py
@@ -15,7 +15,7 @@

from torch import fx
from torch._ops import OpOverload
from torch.ao.quantization.quantizer import (
from torchao.quantization.pt2e.quantizer import (
DerivedQuantizationSpec,
SharedQuantizationSpec,
)
6 changes: 3 additions & 3 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -38,9 +38,9 @@

from torch import fx

from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torchao.quantization.pt2e.observer import HistogramObserver, MinMaxObserver
from torchao.quantization.pt2e.quantizer import DerivedQuantizationSpec, Quantizer
from torchao.quantization.pt2e.quantizer.composable_quantizer import ComposableQuantizer


act_qspec_asym8s = QuantizationSpec(
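ComposableQuantizer keeps its role here; under the new import root it composes backend quantizers the same way. A sketch, with placeholder quantizer instances:

    from torchao.quantization.pt2e.quantizer.composable_quantizer import (
        ComposableQuantizer,
    )

    # quantizer_a / quantizer_b: any Quantizer instances (placeholders here).
    # Their annotate passes are applied in the order they are listed.
    composed = ComposableQuantizer([quantizer_a, quantizer_b])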
2 changes: 1 addition & 1 deletion backends/cadence/aot/quantizer/utils.py
@@ -14,13 +14,13 @@
import torch
from torch import fx
from torch._ops import OpOverload
from torch.ao.quantization import ObserverOrFakeQuantize

from torch.fx import GraphModule
from torch.fx.passes.utils.source_matcher_utils import (
check_subgraphs_connected,
SourcePartition,
)
from torchao.quantization.pt2e import ObserverOrFakeQuantize


def quantize_tensor_multiplier(
4 changes: 2 additions & 2 deletions backends/cadence/aot/tests/test_remove_ops_passes.py
@@ -41,11 +41,11 @@
from parameterized.parameterized import parameterized
from pyre_extensions import none_throws

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from torch.export import export_for_training
from torch.fx.passes.infra.pass_base import PassResult

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class TestRemoveOpsPasses(unittest.TestCase):
@parameterized.expand(
@@ -11,7 +11,7 @@
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dim_order_utils import get_dim_order
from executorch.exir.pass_base import ExportPass, PassResult
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torchao.quantization.pt2e.pt2e.graph_utils import find_sequential_partitions


class PermuteMemoryFormatsPass(ExportPass):
2 changes: 1 addition & 1 deletion backends/example/example_operators/utils.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
from torchao.quantization.pt2e.quantizer.quantizer import QuantizationAnnotation


def _nodes_are_annotated(node_list):
2 changes: 1 addition & 1 deletion backends/example/example_partitioner.py
@@ -19,9 +19,9 @@
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.export import ExportedProgram
from torch.fx.passes.operator_support import OperatorSupportBase
from torchao.quantization.pt2e.pt2e.graph_utils import find_sequential_partitions


@final
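find_sequential_partitions is used by the example partitioner above and the example quantizer below; assuming it keeps its torch.ao signature under the path used in this diff, a call looks like this (gm is an exported GraphModule, assumed to exist):

    import torch.nn as nn
    from torchao.quantization.pt2e.pt2e.graph_utils import find_sequential_partitions

    # Returns a list of matches; each match is a list of SourcePartition objects,
    # one per element of the requested pattern (here Conv2d followed by ReLU).
    conv_relu_partitions = find_sequential_partitions(gm, [nn.Conv2d, nn.ReLU])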
6 changes: 3 additions & 3 deletions backends/example/example_quantizer.py
@@ -11,9 +11,9 @@
from executorch.backends.example.example_operators.ops import module_to_annotator
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import OperatorConfig
from torch import fx
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torchao.quantization.pt2e.observer import HistogramObserver, MinMaxObserver
from torchao.quantization.pt2e.pt2e.graph_utils import find_sequential_partitions
from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer


def get_uint8_tensor_spec(observer_or_fake_quant_ctr):
4 changes: 2 additions & 2 deletions backends/example/test_example_delegate.py
@@ -17,10 +17,10 @@
DuplicateDequantNodePass,
)
from executorch.exir.delegate import executorch_call_delegate

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.export import export

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

from torchvision.models.quantization import mobilenet_v2


12 changes: 6 additions & 6 deletions backends/mediatek/quantizer/annotator.py
@@ -10,18 +10,18 @@
from torch._ops import OpOverload
from torch._subclasses import FakeTensor

from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.ao.quantization.quantizer.utils import (
_annotate_input_qspec_map,
_annotate_output_qspec,
)

from torch.export import export_for_training
from torch.fx import Graph, Node
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
SubgraphMatcherWithNameNodeMap,
)

from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
from torchao.quantization.pt2e.quantizer.utils import (
_annotate_input_qspec_map,
_annotate_output_qspec,
)

from .qconfig import QuantizationConfig


6 changes: 3 additions & 3 deletions backends/mediatek/quantizer/qconfig.py
@@ -10,9 +10,9 @@

import torch

from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec
from torchao.quantization.pt2e.fake_quantize import FakeQuantize
from torchao.quantization.pt2e.observer import MinMaxObserver, PerChannelMinMaxObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec


@unique
2 changes: 1 addition & 1 deletion backends/mediatek/quantizer/quantizer.py
@@ -4,8 +4,8 @@
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

from torch.ao.quantization.quantizer import Quantizer
from torch.fx import GraphModule
from torchao.quantization.pt2e.quantizer import Quantizer

from .._passes.decompose_scaled_dot_product_attention import (
DecomposeScaledDotProductAttention,
6 changes: 3 additions & 3 deletions backends/nxp/quantizer/neutron_quantizer.py
@@ -35,9 +35,9 @@
QuantizationSpec,
)
from torch import fx
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torchao.quantization.pt2e.observer import HistogramObserver, MinMaxObserver
from torchao.quantization.pt2e.quantizer import DerivedQuantizationSpec, Quantizer
from torchao.quantization.pt2e.quantizer.composable_quantizer import ComposableQuantizer


class NeutronAtenQuantizer(Quantizer):
2 changes: 1 addition & 1 deletion backends/nxp/quantizer/patterns.py
@@ -14,7 +14,7 @@
from executorch.backends.nxp.quantizer.utils import get_bias_qparams
from torch import fx
from torch._ops import OpOverload
from torch.ao.quantization.quantizer import (
from torchao.quantization.pt2e.quantizer import (
DerivedQuantizationSpec,
FixedQParamsQuantizationSpec,
SharedQuantizationSpec,
2 changes: 1 addition & 1 deletion backends/nxp/quantizer/utils.py
@@ -14,11 +14,11 @@
import torch
from torch import fx
from torch._ops import OpOverload
from torch.ao.quantization import ObserverOrFakeQuantize
from torch.fx.passes.utils.source_matcher_utils import (
check_subgraphs_connected,
SourcePartition,
)
from torchao.quantization.pt2e import ObserverOrFakeQuantize


def is_annotated(nodes: List[fx.Node]) -> bool:
2 changes: 1 addition & 1 deletion backends/nxp/tests/test_quantizer.py
@@ -8,7 +8,7 @@
import executorch.backends.nxp.tests.models as models
import torch
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


def _get_target_name(node):
4 changes: 2 additions & 2 deletions backends/openvino/quantizer/quantizer.py
@@ -17,12 +17,12 @@
import torch.fx

from nncf.common.graph.graph import NNCFGraph # type: ignore[import-untyped]
from torch.ao.quantization.observer import (
from torchao.quantization.pt2e.observer import (

Check failure on line 20 in backends/openvino/quantizer/quantizer.py (GitHub Actions lintrunner, MYPY import-untyped): Skipping analyzing "torchao.quantization.pt2e.observer": module is installed, but missing library stubs or py.typed marker. To disable, use `# type: ignore[import-untyped]`.
HistogramObserver,
PerChannelMinMaxObserver,
UniformQuantizationObserverBase,
)
from torch.ao.quantization.quantizer.quantizer import (
from torchao.quantization.pt2e.quantizer.quantizer import (

Check failure on line 25 in backends/openvino/quantizer/quantizer.py (GitHub Actions lintrunner, MYPY import-untyped): Skipping analyzing "torchao.quantization.pt2e.quantizer.quantizer": module is installed, but missing library stubs or py.typed marker. To disable, use `# type: ignore[import-untyped]`.
EdgeOrNode,
QuantizationAnnotation,
QuantizationSpec,
4 changes: 2 additions & 2 deletions backends/qualcomm/_passes/annotate_quant_attrs.py
@@ -124,9 +124,9 @@ def _dequant_fold_params(self, n, quant_attrs, param):
offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis)
param = param.sub(offsets).mul(scales).to(torch.float32).contiguous()
elif quant_attrs[QCOM_ENCODING] in [
exir_ops.edge.pt2e_quant.dequantize_affine.default
exir_ops.edge.torchao.dequantize_affine.default
]:
param = torch.ops.pt2e_quant.dequantize_affine(
param = torch.ops.torchao.dequantize_affine(
param,
block_size=quant_attrs[QCOM_BLOCK_SIZE],
scale=quant_attrs[QCOM_SCALE],
4 changes: 2 additions & 2 deletions backends/qualcomm/_passes/qnn_pass_manager.py
@@ -131,8 +131,8 @@ def get_to_edge_transform_passes(
from executorch.backends.qualcomm._passes import utils
from executorch.exir.dialects._ops import ops as exir_ops

utils.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default)
utils.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default)
utils.q_ops.add(exir_ops.edge.torchao.quantize_affine.default)
utils.dq_ops.add(exir_ops.edge.torchao.dequantize_affine.default)

passes_job = (
passes_job if passes_job is not None else get_capture_program_passes()
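Besides the import moves, the Qualcomm passes and builders rename the affine quantization custom ops from the pt2e_quant namespace to torchao. A quick sanity check that the renamed ops are registered (illustrative only; assumes importing torchao's quantization module performs the op registration):

    import torch
    import torchao.quantization  # op registration is assumed to happen on import

    assert hasattr(torch.ops.torchao, "quantize_affine")
    assert hasattr(torch.ops.torchao, "dequantize_affine")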
6 changes: 3 additions & 3 deletions backends/qualcomm/builders/node_visitor.py
@@ -242,8 +242,8 @@ def get_quant_encoding_conf(
)
# TODO: refactor this when target could be correctly detected
per_block_encoding = {
exir_ops.edge.pt2e_quant.quantize_affine.default,
exir_ops.edge.pt2e_quant.dequantize_affine.default,
exir_ops.edge.torchao.quantize_affine.default,
exir_ops.edge.torchao.dequantize_affine.default,
}
if quant_attrs[QCOM_ENCODING] in per_block_encoding:
return self.make_qnn_per_block_config(node, quant_attrs)
@@ -271,7 +271,7 @@ def get_quant_tensor_value(
axis_order.index(x) for x in range(len(axis_order))
)
tensor = tensor.permute(origin_order)
tensor = torch.ops.pt2e_quant.quantize_affine(
tensor = torch.ops.torchao.quantize_affine(
tensor,
block_size=quant_attrs[QCOM_BLOCK_SIZE],
scale=quant_attrs[QCOM_SCALE],