
Import backend modules in nncf/__init__.py #3451

Open · wants to merge 3 commits into base: develop
3 changes: 1 addition & 2 deletions README.md
@@ -242,7 +242,6 @@ Here is an example of Accuracy Aware Quantization pipeline where model weights a

```python
import nncf
-import nncf.torch
import torch
from torchvision import datasets, models

@@ -296,7 +295,7 @@ Here is an example of Accuracy Aware RB Sparsification pipeline where model weig

```python
import torch
-import nncf.torch # Important - must be imported before any other external package that depends on torch
+import nncf # Important - must be imported before any other external package that depends on torch

from nncf import NNCFConfig
from nncf.torch import create_compressed_model, register_default_init_args
@@ -22,7 +22,7 @@ The framework is designed so that the modifications to your original training co

```python
import torch
-import nncf.torch # Important - must be imported before any other external package that depends on torch
+import nncf # Important - must be imported before any other external package that depends on torch
from nncf import NNCFConfig, create_compressed_model, load_state
```
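For context, the requirement in the changed comment amounts to the following import order (a minimal sketch; torchvision merely stands in for any external package that depends on torch):

```python
# Sketch of the import order the comment asks for: nncf comes right after torch
# and before any other torch-dependent package, so NNCF's handling of torch is
# already in place when those packages are imported.
import torch
import nncf  # per the changed comment: import before other torch-dependent packages

from torchvision import models  # example third-party package that depends on torch

model = models.resnet18()
```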


@@ -81,7 +81,7 @@ The quantized model saving allows to load quantized modules to the target model
requires only example input for the target module, corresponding NNCF config and the quantized model state dict.

```python
-import nncf.torch
+import nncf

# save part
quantized_model = nncf.quantize(model, calibration_dataset)
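# Sketch beyond the shown lines (an assumption, not part of the change above):
# per the surrounding doc, the artifact to persist is the quantized model's
# state dict, which standard torch utilities can handle, assuming `import torch`
# appears earlier in the full example.
torch.save(quantized_model.state_dict(), "quantized_model_state_dict.pt")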
1 change: 0 additions & 1 deletion examples/llm_compression/torch/qat_with_lora/main.py
@@ -31,7 +31,6 @@
from whowhatbench import TextEvaluator

import nncf
-import nncf.torch
from nncf.common.logging.track_progress import track
from nncf.data.dataset import Dataset
from nncf.parameters import CompressionFormat

@@ -29,7 +29,6 @@
from torch._dynamo.exc import BackendCompilerFailed

import nncf
-import nncf.torch
from nncf.common.utils.helpers import create_table
from nncf.common.utils.os import is_windows


@@ -31,7 +31,6 @@
from torch.jit import TracerWarning

import nncf
-import nncf.torch
from nncf.common.utils.helpers import create_table

warnings.filterwarnings("ignore", category=TracerWarning)
12 changes: 12 additions & 0 deletions nncf/__init__.py
@@ -90,3 +90,15 @@
"Please install one of the supported frameworks above in order to use NNCF on top of it.\n"
"See the installation guide at https://github.com/openvinotoolkit/nncf#installation-guide for help."
)

+if _AVAILABLE_FRAMEWORKS["torch"]:
+    from nncf import torch as torch
Contributor: Do PyTorch extensions compile during this call?

Collaborator (Author): No, the extensions compile on first use.

+if _AVAILABLE_FRAMEWORKS["tensorflow"]:
+    from nncf import tensorflow as tensorflow
+
+if _AVAILABLE_FRAMEWORKS["onnx"]:
+    from nncf import onnx as onnx
+
+if _AVAILABLE_FRAMEWORKS["openvino"]:
+    from nncf import openvino as openvino
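As a quick illustration of what the conditional imports above change for users (a minimal sketch; `importlib.util.find_spec` is only used here to mirror the availability check and is not necessarily how `_AVAILABLE_FRAMEWORKS` is populated):

```python
# Sketch: after this change, `import nncf` also imports the backend sub-packages
# for every installed framework, so no explicit `import nncf.torch` is needed.
import importlib.util

import nncf

if importlib.util.find_spec("torch") is not None:
    assert hasattr(nncf, "torch")
if importlib.util.find_spec("openvino") is not None:
    assert hasattr(nncf, "openvino")
```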
1 change: 0 additions & 1 deletion nncf/experimental/torch/fx/transformations.py
@@ -19,7 +19,6 @@
from torch.quantization.fake_quantize import FakeQuantize

import nncf
-import nncf.torch
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.experimental.torch.fx.constant_folding import constant_fold
4 changes: 4 additions & 0 deletions nncf/torch/dynamic_graph/patch_pytorch.py
@@ -283,6 +283,8 @@ def __init__(self, name: str, namespace, op):
ORIGINAL_OPERATORS: List[OriginalOpInfo] = []
ORIGINAL_CALL = torch.nn.Module.__call__
_ORIG_JIT_SCRIPT = None
+_ORIG_JIT_SCRIPT_IF_TRACE = None

_ORIG_JIT_TRACE_MAKE_MODULE = None
_ORIG_TORCH_COMPILE: Union[Callable, None] = None

@@ -312,6 +314,8 @@ def patch_torch_jit():

# Patch torch.jit._script_if_tracing because it references an original
# unpatched torch.jit.script and the patching above does not affect it
+    global _ORIG_JIT_SCRIPT_IF_TRACE
+    _ORIG_JIT_SCRIPT_IF_TRACE = getattr(torch.jit, "_script_if_tracing")
    setattr(torch.jit, "_script_if_tracing", torch_jit_script_if_tracing)


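The additions above follow the usual save-then-patch idiom; a simplified sketch (not NNCF's actual replacement function):

```python
# Simplified sketch of the pattern in patch_pytorch.py: stash the original
# function in a module-level global before replacing it, so tests and any
# later unpatching can still reach the original.
import torch

_ORIG_JIT_SCRIPT_IF_TRACE = None


def torch_jit_script_if_tracing(fn):
    # Placeholder for NNCF's replacement; shown here only to make the sketch runnable.
    return fn


def patch_torch_jit():
    global _ORIG_JIT_SCRIPT_IF_TRACE
    _ORIG_JIT_SCRIPT_IF_TRACE = torch.jit._script_if_tracing
    torch.jit._script_if_tracing = torch_jit_script_if_tracing
```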
8 changes: 5 additions & 3 deletions tests/torch/conftest.py
@@ -9,6 +9,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os

+# Enable patching of torch functions
+# Should be set before any import of nncf
+os.environ["NNCF_TORCH_LEGACY_TRACING"] = "1"

import random
from pathlib import Path

@@ -122,9 +127,6 @@ def pytest_configure(config: Config):
if config.getoption(f"--regen-{regen_option}", False):
os.environ[f"NNCF_TEST_REGEN_{regen_option.upper()}"] = "1"

-    # Enable patching of torch functions
-    os.environ["NNCF_TORCH_LEGACY_TRACING"] = "1"


@pytest.fixture(scope="module")
def dataset_dir(request: FixtureRequest):
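Moving the assignment to the top of the module makes the ordering constraint explicit; conceptually (a minimal sketch):

```python
# Sketch of the constraint stated in the added comment: the variable must be in
# the environment before nncf is imported, since the comment indicates the
# legacy-tracing switch is consulted when nncf is first imported.
import os

os.environ["NNCF_TORCH_LEGACY_TRACING"] = "1"  # set first

import nncf  # noqa: E402  # imported only after the environment is prepared
```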
6 changes: 3 additions & 3 deletions tests/torch/pytorch_patch_isolated.py
@@ -40,9 +40,9 @@ def clean_source_code(code_source):
@pytest.mark.skipif(ISOLATION_RUN_ENV_VAR not in os.environ, reason="Should be run via isolation proxy")
def test_jit_if_tracing_script_source_equals():
# Get original torch.jit._script_if_tracing source
-    torch_source = clean_source_code(inspect.getsource(torch.jit._script_if_tracing))
+    from nncf.torch.dynamic_graph.patch_pytorch import _ORIG_JIT_SCRIPT_IF_TRACE

-    import nncf.torch # noqa: F401
+    torch_source = clean_source_code(inspect.getsource(_ORIG_JIT_SCRIPT_IF_TRACE))

# Get torch.jit._script_if_tracing source after patching was performed
nncf_source = clean_source_code(inspect.getsource(torch.jit._script_if_tracing))
@@ -109,7 +109,7 @@ def forward(self, x):
def test_compile():
compile_forward = os.environ.get("COMPILE_FORWARD", None) == "1"
before_nncf = compile_and_run_test_model(compile_forward)
-    import nncf.torch # noqa: F401
+    import nncf # noqa: F401

after_nncf = compile_and_run_test_model(compile_forward)
assert torch.allclose(before_nncf, after_nncf)

@@ -18,7 +18,6 @@
from torch import nn

import nncf
-import nncf.torch
from nncf.common.quantization.structs import QuantizationScheme
from nncf.parameters import CompressWeightsMode
from nncf.parameters import StripFormat

@@ -17,7 +17,6 @@
from torch.quantization import FakeQuantize

import nncf
-import nncf.torch
from nncf.quantization.advanced_parameters import OverflowFix
from nncf.torch.function_hook.wrapper import get_hook_storage
from tests.common.quantization.data_generators import generate_lazy_sweep_data
1 change: 0 additions & 1 deletion tests/torch2/test_patching.py
@@ -18,7 +18,6 @@

def test_patching():
# Check that patching torch functions is disabled
-    import nncf.torch # noqa: F401

with pytest.raises(AttributeError):
getattr(torch.relu, "_original_op")
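Building on the check above, the same attribute can be used to probe whether patching is active (a sketch; `_original_op` is the attribute the test itself inspects):

```python
# Sketch: the test expects AttributeError because legacy patching is disabled in
# torch2 mode; when patching is enabled, patched torch functions carry an
# `_original_op` attribute (the one inspected above).
import torch

is_patched = hasattr(torch.relu, "_original_op")
print(f"torch.relu patched by NNCF: {is_patched}")
```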