From f757e51c328b575962f8607693f4fef650e1ddb4 Mon Sep 17 00:00:00 2001 From: "Qiaoyu (Joey) Deng" Date: Fri, 26 Jun 2026 09:53:53 -0700 Subject: [PATCH] chore(release): updates for pre 0.2.1 release - version 0.2.0 stays unchanged. --- .pre-commit-config.yaml | 6 +- AGENTS.md | 6 + CONTRIBUTING.md | 1 + Makefile | 46 +- agents/llm-first-principles.md | 5 + ci/__init__.py | 4 + ci/nox/__init__.py | 4 + ci/nox/merge_pytest_results.py | 109 ++ ci/nox/noxfile.py | 95 ++ ci/nox/utils.py | 106 ++ docs/src/conf.py | 7 +- docs/src/index.md | 62 +- docs/src/landing_page.md | 61 + docs/src/pruning/basics.md | 76 ++ docs/src/pruning/config.md | 181 +++ docs/src/pruning/images/pruning_magnitude.svg | 134 ++ docs/src/pruning/images/pruning_schedule.svg | 62 + docs/src/pruning/images/pruning_schemes.svg | 141 ++ docs/src/pruning/index.md | 9 + docs/src/pruning/overview.md | 91 ++ docs/src/tutorials/mnist_palettization.ipynb | 16 +- ...tization_and_activation_quantization.ipynb | 16 +- docs/src/tutorials/mnist_quantization.ipynb | 18 +- docs/src/utils/activation_comparison.md | 313 +++++ docs/tests/test_tutorials.py | 65 + pyproject.toml | 34 +- scripts/make/install_pre_commit_hooks.sh | 2 +- scripts/make/log_versions.py | 35 +- scripts/make/print_api_list.py | 2 +- scripts/make/run_tests_on_latest_coreai.sh | 29 - scripts/make/setup_env.sh | 47 +- scripts/release/release_utils.py | 22 +- src/coreai_opt/_utils/config_utils.py | 47 +- src/coreai_opt/_utils/fx_utils.py | 168 +++ .../_utils/insertion/torch_function/modes.py | 185 +-- .../torch_function/module_boundary_tracker.py | 4 +- .../registered_optimizers_tracker.py | 1 + .../torch_function/state_spec_resolver.py | 212 +++ .../_utils/insertion/torch_function/utils.py | 29 +- src/coreai_opt/_utils/metadata_utils.py | 6 +- src/coreai_opt/_utils/registry_utils.py | 33 + src/coreai_opt/_utils/torch_utils.py | 50 - src/coreai_opt/config/__init__.py | 1 + .../coreai_utils/_utils/sparse_utils.py | 1 - 
src/coreai_opt/inspection/__init__.py | 12 +- src/coreai_opt/inspection/_common.py | 97 ++ src/coreai_opt/inspection/_eager_mode.py | 576 ++++++++ src/coreai_opt/inspection/_formatting.py | 88 +- src/coreai_opt/inspection/_graph_mode.py | 285 ++-- src/coreai_opt/inspection/model_inspector.py | 84 +- src/coreai_opt/inspection/types.py | 120 +- .../kmeans/kmeans_support_mixins.py | 11 + .../palettization/kmeans/palettizer.py | 5 + .../kmeans/supported_ops_registry.py | 16 + src/coreai_opt/palettization/spec/factory.py | 1 - .../palettization/spec/fake_palettize.py | 25 +- .../palettization/spec/granularity.py | 3 +- src/coreai_opt/pruning/spec/spec.py | 42 +- .../quantization/_eager/quantizer.py | 22 +- .../_eager/supported_ops_registry.py | 2 + .../quantization/_graph/_annotation_config.py | 29 + .../_graph/_annotation_pattern_registry.py | 21 +- .../quantization/_graph/_annotation_utils.py | 208 +-- .../quantization/_graph/quantizer.py | 139 +- src/coreai_opt/quantization/_utils.py | 5 +- src/coreai_opt/quantization/quantizer.py | 54 +- src/coreai_opt/quantization/spec/__init__.py | 6 + src/coreai_opt/quantization/spec/errors.py | 1 + src/coreai_opt/quantization/spec/factory.py | 10 + .../quantization/spec/fake_quantize.py | 26 +- .../quantization/spec/qparams_calculator.py | 162 ++- .../quantization/spec/range_calculator.py | 12 +- src/coreai_opt/quantization/spec/spec.py | 51 +- tests/conftest.py | 1163 +---------------- tests/coreai_utils/test_sparse_utils.py | 193 +++ tests/coreai_utils/test_sparsify_weights.py | 112 ++ tests/export/export_utils.py | 7 +- tests/export/test_eager_mil_export.py | 2 +- tests/export/test_eager_mlir_export.py | 34 +- tests/export/test_graph_mode_mlir_export.py | 38 +- tests/export/test_kmeans_export.py | 8 +- tests/export/test_pruning_export.py | 2 +- tests/export/test_pt2e_mil_export.py | 2 +- tests/fixtures/__init__.py | 4 + tests/fixtures/compression.py | 112 ++ tests/fixtures/fp4.py | 163 +++ tests/fixtures/fp8.py | 179 +++ 
tests/fixtures/palettization.py | 161 +++ tests/fixtures/pruning.py | 66 + tests/fixtures/quantization.py | 522 ++++++++ tests/models/__init__.py | 4 + .../test_kmeans_fake_palettize.py | 22 +- tests/palettization/test_kmeans_palettizer.py | 18 +- .../test_kmeans_palettizer_mnist.py | 41 +- .../test_annotation_pattern_registry.py | 46 +- tests/quantization/test_eager_quant.py | 16 +- tests/quantization/test_factory.py | 40 +- .../quantization/test_graph_mode_quantizer.py | 19 +- tests/quantization/test_is_state_node.py | 18 +- tests/quantization/test_qparams_calculator.py | 163 +++ tests/quantization/test_quantization.py | 499 ++++++- .../quantization/test_state_spec_resolver.py | 441 +++++++ tests/test_inspection.py | 995 +++++++++++++- tests/test_joint_compression.py | 14 +- tests/test_nox_utils.py | 93 ++ tests/test_utils/general.py | 29 +- tests/test_utils/test_config_utils.py | 59 + tests/test_utils/test_general.py | 25 + tests/test_utils/test_registry_utils.py | 15 + tests/test_utils/test_torch_utils.py | 2 +- tests/utils.py | 12 + 111 files changed, 7703 insertions(+), 2291 deletions(-) create mode 100644 ci/__init__.py create mode 100644 ci/nox/__init__.py create mode 100644 ci/nox/merge_pytest_results.py create mode 100644 ci/nox/noxfile.py create mode 100644 ci/nox/utils.py create mode 100644 docs/src/landing_page.md create mode 100644 docs/src/pruning/basics.md create mode 100644 docs/src/pruning/config.md create mode 100644 docs/src/pruning/images/pruning_magnitude.svg create mode 100644 docs/src/pruning/images/pruning_schedule.svg create mode 100644 docs/src/pruning/images/pruning_schemes.svg create mode 100644 docs/src/pruning/index.md create mode 100644 docs/src/pruning/overview.md create mode 100644 docs/src/utils/activation_comparison.md create mode 100644 docs/tests/test_tutorials.py delete mode 100755 scripts/make/run_tests_on_latest_coreai.sh create mode 100644 src/coreai_opt/_utils/fx_utils.py create mode 100644 
src/coreai_opt/_utils/insertion/torch_function/state_spec_resolver.py create mode 100644 src/coreai_opt/inspection/_common.py create mode 100644 src/coreai_opt/inspection/_eager_mode.py create mode 100644 tests/coreai_utils/test_sparse_utils.py create mode 100644 tests/fixtures/__init__.py create mode 100644 tests/fixtures/compression.py create mode 100644 tests/fixtures/fp4.py create mode 100644 tests/fixtures/fp8.py create mode 100644 tests/fixtures/palettization.py create mode 100644 tests/fixtures/pruning.py create mode 100644 tests/fixtures/quantization.py create mode 100644 tests/models/__init__.py create mode 100644 tests/quantization/test_state_spec_resolver.py create mode 100644 tests/test_nox_utils.py create mode 100644 tests/test_utils/test_config_utils.py create mode 100644 tests/test_utils/test_general.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 304e228..d1c94c6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -62,7 +62,7 @@ repos: language: system entry: end-of-file-fixer types: [text] - exclude: ^\.codegenius/learnings\.md$ + exclude: ^\.codegenius/ - id: check-yaml name: check yaml @@ -174,7 +174,7 @@ repos: language: system entry: mdformat --number files: \.(md)$ - exclude: ^(CHANGELOG\.md|\.codegenius/learnings\.md)$ + exclude: ^(CHANGELOG\.md$|\.codegenius/) args: [ --number, # Force 1., 2., 3. instead of all 1. 
--wrap=keep, # Preserve existing line breaks @@ -255,7 +255,7 @@ repos: language: system entry: pymarkdown --disable-rules MD013,MD024 scan files: \.(md)$ - exclude: ^\.codegenius/learnings\.md$ + exclude: ^\.codegenius/ # ---------------------------------------------------------------------------- # 3.5 Lint Shell Script (bashate) diff --git a/AGENTS.md b/AGENTS.md index 0ffd2ac..2b0b709 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -17,6 +17,12 @@ - `pytest -n auto path/to/test.py` — run a test file - `pytest path/to/test.py::test_name` — run a single test +## uv usage + +Always pass `--no-sync` to `uv run`: `uv run --no-sync --active …`. + +`uv run` implicitly syncs the active project to its default-groups before running, which re-resolves dependencies and can clobber a venv's group-pinned packages — e.g. the torch pin in `.venv-lowest-torch`/`.venv-highest-torch` gets re-anchored back to the default torch. Our Make targets always prepare the environment first via `use_env`/`setup_env.sh`, so by the time `uv run` executes the deps are already correct. A `uv run` invocation should be a read-only run of a command in that prepared env, never a dependency mutation — `--no-sync` enforces that. + ## Editing Guidelines - Use `@path` to reference small files (loaded into every session automatically). diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 792e84d..d83cb00 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -69,6 +69,7 @@ Before pushing your changes, run these locally: - `make check` — full lint and `mypy` type-check pass - `make test` — full test suite (parallelized with `pytest-xdist`) - `make test-fast` — excludes tests marked `@pytest.mark.slow` for quicker iteration +- `make test-smoke` — builds the package, installs it into a clean environment, and verifies that imports plus basic quantization and palettization work end to end A clean `make check` and `make test` are required before a pull request will be reviewed. 
diff --git a/Makefile b/Makefile index 526067a..421b733 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,9 @@ # Use of this source code is governed by a BSD-3-Clause license that can # be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause -.PHONY: _maybe_patch_pyproject all api-list build check clean distclean distclean-all docs docs-clean docs-open env env-all env-docs env-highest-torch env-latest-coreai env-tutorial set-auto-venv test test-coreai-compression test-cov test-export test-fast test-highest-pytorch test-lowest-pytorch test-slow version +.PHONY: _maybe_patch_pyproject all api-list build check clean distclean distclean-all docs docs-clean docs-open env env-all env-docs env-highest-torch env-tutorial set-auto-venv test test-cov test-fast test-highest-pytorch test-lowest-pytorch test-slow test-smoke test-tutorials version + +SHELL := /bin/bash # Directory holding this Makefile, derived from its own location so the same # recipes work in both contexts: @@ -161,7 +163,7 @@ endif # ============================================================================= # Default target - run full workflow -all: clean distclean-all env-all check test-export test-lowest-pytorch test-highest-pytorch build +all: clean distclean-all env-all check test-lowest-pytorch test-highest-pytorch build # ============================================================================= # Environment Setup @@ -177,10 +179,6 @@ env-highest-torch: _maybe_patch_pyproject @$(SETUP_ENV) --venv $(VENV_HIGHEST_TORCH) --python-version $(PYTHON_VERSION) --with-highest_tested_torch @$(call write_active_venv,$(VENV_HIGHEST_TORCH)) -# Set up development environment with latest CoreAI for export testing -env-latest-coreai: _maybe_patch_pyproject - @$(SETUP_ENV) --venv .venv_latest_coreai --python-version $(PYTHON_VERSION) --with-latest-coreai --without-stable-coreai - # Set up environment for running tutorials (quantization notebook) env-tutorial: _maybe_patch_pyproject 
@$(SETUP_ENV) --venv $(VENV_TUTORIAL) --python-version $(PYTHON_VERSION) --with-tutorial @@ -198,7 +196,7 @@ env-all: _maybe_patch_pyproject # Build package build: - @$(call use_env,VENV) && uv run --active python $(SCRIPTS)/make/build.py + @$(call use_env,VENV) && uv run --no-sync --active python $(SCRIPTS)/make/build.py # ============================================================================= # Code Quality @@ -207,13 +205,13 @@ build: # Print public API surface (symbols declared in __all__ across all public packages). # Pass MODULE= to inspect a single module: make api-list MODULE=coreai_opt.quantization.spec.spec api-list: - @$(call use_env,VENV) && uv run --active python $(SCRIPTS)/make/print_api_list.py $(MODULE) + @$(call use_env,VENV) && uv run --no-sync --active python $(SCRIPTS)/make/print_api_list.py $(MODULE) # Run linting and type checking. check: @$(call use_env,VENV) && \ echo "Running linting and formatting checks..." && \ - uv run --active pre-commit run --all-files && \ + uv run --no-sync --active pre-commit run --all-files && \ echo "All checks passed!" # ============================================================================= @@ -236,20 +234,19 @@ test-fast: test-slow: @$(MAKE) test PYTEST_ARGS="--marker slow" -# Run export tests with latest CoreAI (pass PYTEST_ARGS for custom flags) -test-export: env-latest-coreai - @$(SCRIPTS)/make/run_tests_on_latest_coreai.sh --path tests/export/ $(PYTEST_ARGS) - -# Run coreai compression tests with latest CoreAI (pass PYTEST_ARGS for custom flags) -test-coreai-compression: env-latest-coreai - @$(SCRIPTS)/make/run_tests_on_latest_coreai.sh --path tests/coreai_utils/ $(PYTEST_ARGS) +# Run smoke tests only (pass PYTEST_ARGS for custom flags, e.g., make test-smoke PYTEST_ARGS="--junitxml=results.xml"). +test-smoke: + @$(call use_env,VENV) && \ + echo "Running smoke tests..." 
&& \ + uv run --no-sync --active nox -f $(MAKEFILE_DIR)ci/nox/noxfile.py -s smoke_tests -- $(PYTEST_ARGS) && \ + echo "All smoke tests passed!" # Run tests on lowest supported PyTorch version (pass PYTEST_ARGS for custom flags) test-lowest-pytorch: @echo "Running tests on lowest PyTorch version supported..." - @$(call use_env,VENV_LOWEST_TORCH,--with-lowest_tested_torch --without-stable-coreai) && \ + @$(call use_env,VENV_LOWEST_TORCH,--with-lowest_tested_torch) && \ echo "Testing with lowest supported PyTorch versions" && \ - uv run --active python $(SCRIPTS)/make/log_versions.py && \ + uv run --no-sync --active python $(SCRIPTS)/make/log_versions.py && \ $(RUN_TESTS) $(PYTEST_ARGS) && \ echo "All tests passed!" @@ -258,10 +255,17 @@ test-highest-pytorch: @echo "Running tests on highest PyTorch version supported..." @$(call use_env,VENV_HIGHEST_TORCH,--with-highest_tested_torch) && \ echo "Testing with latest supported PyTorch versions" && \ - uv run --active python $(SCRIPTS)/make/log_versions.py && \ + uv run --no-sync --active python $(SCRIPTS)/make/log_versions.py && \ $(RUN_TESTS) $(PYTEST_ARGS) && \ echo "All tests passed!" +# Run tutorial notebook tests +test-tutorials: + @$(call use_env,VENV_TUTORIAL,--with-tutorial --with-test) && \ + echo "Running tutorial notebook tests..." && \ + $(RUN_TESTS) --path $(DOCS_DIR)/tests/test_tutorials.py $(PYTEST_ARGS) && \ + echo "All tutorial tests passed!" 
+ # ============================================================================= # Maintenance # ============================================================================= @@ -288,7 +292,7 @@ set-auto-venv: # Show current version version: - @python -c "exec(open('./src/coreai_opt/_about.py').read()); print(__version__)" + @python -c "exec(open('$(MAKEFILE_DIR)src/coreai_opt/_about.py').read()); print(__version__)" # ============================================================================= # Documentation @@ -316,7 +320,7 @@ endif @echo "==> [4/5] Setting up docs environment" && \ $(call use_env,VENV_DOCS,--with-docs) && \ echo "==> [5/5] Building documentation" && \ - cd $(DOCS_DIR) && uv run --active sphinx-build -E -b html src build/html + cd $(DOCS_DIR) && uv run --no-sync --active sphinx-build -E -b html src build/html ifndef _DOCS_ALL @echo "" @echo "════════════════════════════════════════════════════════════════════" diff --git a/agents/llm-first-principles.md b/agents/llm-first-principles.md index 8160521..79995b1 100644 --- a/agents/llm-first-principles.md +++ b/agents/llm-first-principles.md @@ -9,6 +9,7 @@ Behavioral guidelines to reduce common LLM coding mistakes. Merge with project-s - [3. Surgical Changes](#3-surgical-changes) - [4. Goal-Driven Execution](#4-goal-driven-execution) - [Success Indicators](#success-indicators) +- [References](#references) ## 1. Think Before Coding @@ -78,3 +79,7 @@ These principles are working when: - Diffs contain fewer unrelated changes. - Fewer rewrites stem from overcomplication. - Clarifying questions arrive before mistakes, not after. + +## References + +The principles above are adapted from Andrej Karpathy's coding guidance, via the [andrej-karpathy-skills](https://github.com/multica-ai/andrej-karpathy-skills) project (MIT License). diff --git a/ci/__init__.py b/ci/__init__.py new file mode 100644 index 0000000..30b83ae --- /dev/null +++ b/ci/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2026 Apple Inc. 
+# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/ci/nox/__init__.py b/ci/nox/__init__.py new file mode 100644 index 0000000..30b83ae --- /dev/null +++ b/ci/nox/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/ci/nox/merge_pytest_results.py b/ci/nox/merge_pytest_results.py new file mode 100644 index 0000000..5c00dfb --- /dev/null +++ b/ci/nox/merge_pytest_results.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 + +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""Merge multiple pytest JUnit XML result files into a single file. + +This script is used in CI to combine session-specific pytest results +(e.g., pytest-results-3.10.xml, pytest-results-3.11.xml) into a single +pytest-results.xml file. + +Usage: + python merge_pytest_results.py [--input-dir DIR] [--output FILE] [--pattern PATTERN] + +Examples: + python merge_pytest_results.py --input-dir results --output combined.xml +""" + +import argparse +import logging +import sys +from pathlib import Path + +from junitparser import JUnitXml + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(levelname)s: %(message)s", +) + + +def merge_junit_xml_files( + input_dir: str = "test-results", + output_file: str = "test-results/pytest-results.xml", + pattern: str = "pytest-results-*.xml", +) -> bool: + """Merge multiple JUnit XML files into a single file. 
+ + Args: + input_dir: Directory containing the XML files to merge + output_file: Path to the output merged XML file + pattern: Glob pattern for matching input files + + Returns: + True if merge was successful, False otherwise + """ + input_path = Path(input_dir) + xml_files = list(input_path.glob(pattern)) + + if not xml_files: + logging.info(f"No files matching '{pattern}' found in {input_dir}, nothing to merge") + return True + + logging.info(f"Found {len(xml_files)} XML files to merge:") + for f in sorted(xml_files): + logging.info(f" - {f}") + + merged = JUnitXml() + for xml_file in sorted(xml_files): + try: + merged += JUnitXml.fromfile(str(xml_file)) + except Exception as e: + logging.error(f"Error parsing {xml_file}: {e}") + return False + + output_path = Path(output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + merged.write(str(output_path)) + + logging.info(f"Successfully merged into {output_file}") + return True + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Merge multiple pytest JUnit XML result files into a single file." + ) + parser.add_argument( + "--input-dir", + default="test-results", + help="Directory containing XML files to merge (default: test-results)", + ) + parser.add_argument( + "--output", + default="test-results/pytest-results.xml", + help="Output file path (default: test-results/pytest-results.xml)", + ) + parser.add_argument( + "--pattern", + default="pytest-results-*.xml", + help="Glob pattern for input files (default: pytest-results-*.xml)", + ) + + args = parser.parse_args() + + success = merge_junit_xml_files( + input_dir=args.input_dir, + output_file=args.output, + pattern=args.pattern, + ) + + return not success + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/ci/nox/noxfile.py b/ci/nox/noxfile.py new file mode 100644 index 0000000..bfa3270 --- /dev/null +++ b/ci/nox/noxfile.py @@ -0,0 +1,95 @@ +# Copyright 2026 Apple Inc. 
+# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""Nox sessions for testing against multiple Python versions. + +This module defines nox sessions to test the coreai-opt package against: +1. Supported Python versions (blocking for CI) +""" + +import os +import sys +from pathlib import Path + +from nox import Session, options +from nox_uv import session + +from coreai_opt._utils.repo_utils import find_repo_root + +# Find repository root (where pyproject.toml is located) +REPO_ROOT = find_repo_root(__file__) + +# Add repository root to sys.path so we can import ci package +sys.path.insert(0, str(REPO_ROOT)) +os.environ.setdefault("UV_PROJECT", str(REPO_ROOT)) + +from ci.nox.utils import ( # noqa: E402 + change_dir_to_project_root, + get_pytest_executable, + get_supported_python_versions, +) + +options.default_venv_backend = "uv" +options.error_on_missing_interpreters = True + + +@session(python=get_supported_python_versions(), uv_extras=["coreai"], uv_groups=["test"]) +def smoke_tests(session: Session) -> None: + """Smoke test the package build and coreai_opt imports and basic functionality. + + Builds the package using the nox session's Python version, installs it + in a clean environment, and runs smoke tests to verify functionality. 
+ """ + change_dir_to_project_root(session) + session.log(f"Building package with Python {session.python}") + session.install("build") + session.run("make", "build", external=True) + session.log("Installing built package") + + # Find the built wheel + wheels = list(Path("dist").glob("*.whl")) + if not wheels: + session.error(f"Build unsuccessful for Python {session.python}") + session.error("No wheel found in dist/") + latest_wheel = max(wheels, key=lambda p: p.stat().st_mtime) + session.install(str(latest_wheel)) + session.log("Build Succeeded!") + + # setuptools is needed by torch.utils.cpp_extension (used by PT2E quantization); + # required on Python 3.12+ where distutils was removed from stdlib. + session.install("setuptools") + + session.log("Running smoke tests") + + # Use run_tests.sh to properly handle --junit and other custom flags + # The script handles --junit by converting it to --junitxml + # Pass the session's pytest executable to ensure we use the nox venv's pytest + # Process posargs to handle --junit flag for unique filenames per Python version + if session.posargs and "--junit" in session.posargs: + test_args = [arg for arg in session.posargs if arg != "--junit"] + test_args.extend( + [ + f"--junitxml=test-results/pytest-results-{session.python}.xml", + "--cov-append", + ] + ) + else: + test_args = list(session.posargs) if session.posargs else [] + session.run( + str(REPO_ROOT / "scripts" / "make" / "run_tests.sh"), + "--pytest", + get_pytest_executable(session), + "--path", + str(REPO_ROOT / "tests" / "test_smoke.py"), + # Disable pytest-xdist for smoke tests because it makes the test suite much slower + # This can be overridden by the user by setting workers in test_args + "--workers", + "0", + "--noconftest", + *test_args, + external=True, + ) + + session.log("Smoke test passed!") diff --git a/ci/nox/utils.py b/ci/nox/utils.py new file mode 100644 index 0000000..643dcf8 --- /dev/null +++ b/ci/nox/utils.py @@ -0,0 +1,106 @@ +# Copyright 2026 Apple Inc. 
+# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""Utility functions for nox sessions.""" + +import os +import tomllib +from pathlib import Path + +from nox import Session +from packaging.specifiers import SpecifierSet + +from coreai_opt._utils.repo_utils import find_repo_root + +# Find repository root (where pyproject.toml is located) +REPO_ROOT = find_repo_root(__file__) + + +def change_dir_to_project_root(session: Session) -> None: + """Change to the project root used for builds and ``uv``. + + ``UV_PROJECT``, when set, points at the directory that holds ``uv.lock`` and + the build-ready ``pyproject.toml``; otherwise the package root fills both + roles. + """ + session.chdir(os.environ.get("UV_PROJECT") or str(REPO_ROOT)) + + +def get_pytest_executable(session: Session) -> str: + """Get the command to run pytest using the nox session's Python executable. + + Args: + session: Nox session + + Returns: + Command string to run pytest via python -m pytest + """ + python_path = Path(session.bin) / "python" + return f"{python_path} -m pytest" + + +def _get_minimum_python_minor_version(specifier: SpecifierSet) -> int: + """Extract the minimum Python minor version from a specifier set. + + Args: + specifier: A SpecifierSet parsed from requires-python. + + Returns: + The minimum minor version (e.g., 10 for ">=3.10"). + + Raises: + ValueError: If no lower bound is specified in the specifier. 
+ """ + for spec in specifier: + if spec.operator in (">=", ">"): + return int(spec.version.split(".")[1]) + raise ValueError(f"No lower bound found in specifier: {specifier}") + + +def get_supported_python_versions() -> list[str]: + """Parse requires-python from pyproject.toml and return supported versions.""" + with open(REPO_ROOT / "pyproject.toml", "rb") as f: + pyproject = tomllib.load(f) + + specifier = SpecifierSet(pyproject["project"]["requires-python"]) + min_minor = _get_minimum_python_minor_version(specifier) + + # Generate candidate versions from lower bound up to a reasonable upper limit + all_versions = [f"3.{minor}" for minor in range(min_minor, min_minor + 10)] + return [v for v in all_versions if specifier.contains(v)] + + +def build_pytest_args( + default_args: list[str], + posargs: list[str] | None = None, + python_version: str | None = None, +) -> list[str]: + """Build pytest arguments from posargs with optional session-specific handling. + + Args: + posargs: Optional pytest arguments passed from command line (session.posargs) + default_args: Default arguments to use when posargs is empty or None. + Required; a copy of this list is returned in that case. + python_version: Optional Python version for session-specific junitxml filename. + When provided and --junit is in posargs, adds + --junitxml=test-results/pytest-results-{python_version}.xml + and --cov-append if --cov is also present. 
+ + Returns: + List of pytest arguments + """ + if not posargs: + return list(default_args) + + pytest_args = list(posargs) + + # Handle session-specific junitxml for multi-python-version sessions + if python_version and "--junit" in posargs: + pytest_args.append(f"--junitxml=test-results/pytest-results-{python_version}.xml") + # Use --cov-append to accumulate coverage across sessions + if "--cov" in posargs or any(arg.startswith("--cov=") for arg in posargs): + pytest_args.append("--cov-append") + + return pytest_args diff --git a/docs/src/conf.py b/docs/src/conf.py index 23b3df7..ac707b2 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -32,10 +32,9 @@ _docs_dir = Path(__file__).parent.parent.resolve() sys.path.insert(0, str(_docs_dir)) -# Find the repo / package root by walking up until we hit ``src/coreai_opt``. -# In OSS layout this is the directory immediately above docs/; in internal -# layout it's three levels up from docs/src/conf.py (since src/ is a sibling -# of external/, not external/docs/). +# Find the package root by walking up until we hit ``src/coreai_opt``. In both +# layouts that root is the directory immediately above docs/: the OSS root after +# export, or ``external/`` in the internal repo. _search = _docs_dir while _search != _search.parent: if (_search / "src" / "coreai_opt").exists(): diff --git a/docs/src/index.md b/docs/src/index.md index 840d4e7..92b8f7c 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,67 +1,8 @@ -# Core AI Optimization Documentation - -## What is `coreai-opt`? - -`coreai-opt` is a Python library for compressing PyTorch models for deployment on Apple Silicon. It allows you to apply compression-based optimizations (such as quantization or palettization) to any PyTorch model, producing a transformed PyTorch model that can be converted to a Core AI model and run with the [Core AI](https://developer.apple.com/documentation/coreai) framework. 
- -Model compression can help reduce the memory footprint of your model (disk size and at runtime), reduce inference latency, reduce power consumption, or optimize them all at once. - -```{mermaid} -flowchart LR - A[PyTorch model] --> B(coreai-opt) - B --> C["Transformed
PyTorch model
(compressed)"] - C --> D("coreai-torch
(convert)") - D --> E["Core AI model
(.aimodel)"] - style A color:#999,fill:none,stroke:none - style C color:#999,fill:none,stroke:none - style E color:#999,fill:none,stroke:none - linkStyle default stroke:#999,stroke-width:1.5px +```{include} landing_page.md ``` -`coreai-opt` is built around the following ideas: - -- **PyTorch native.** All APIs operate on PyTorch models. Compression is another transformation in your PyTorch workflow. The output of every compressor is itself a PyTorch model that can be validated, fine-tuned, and exported like any other model. - -- **Integrates with existing PyTorch code.** Adding post-training compression, calibration-based, or compression-aware training to an existing PyTorch pipeline takes a few additional lines of code. All three use the same compressor object. - -- **Aligned with Apple Silicon.** Default configurations and the majority of the available optimization options align with what the [Core AI](https://developer.apple.com/documentation/coreai) runtime executes efficiently, on one or many of the Apple Silicon platforms. Compressed PyTorch models can be seamlessly converted to `.aimodel` for deployment via Core AI. - -## Types of compression - -Available APIs cover the following categories of compression: - -- **[Quantization](quantization/index.md)** approximates weights and/or activations using a quantization function. Weight precisions include INT2, INT4, INT8 and FP4, FP8; activation precisions include INT8 and FP8. -- **[Palettization](palettization/index.md)**, also known as codebook-style compression, clusters weights into a look-up table of centroids and stores indices in their place. Weights can be palettized to N ∈ {1, 2, 3, 4, 6, 8} bits. -- **Pruning** zeros out weights with the smallest magnitudes and stores the remaining weights using sparse representations. 
- -These techniques can also be combined and applied in a hybrid fashion — for example, applying different palettization bit widths to different weights, or combining weight palettization with activation quantization — to build customized optimization recipes. - -## Compression workflows - -The process of applying compression to a model typically involves the following stages. - -- **Data-free compression**: Weight-only compression that needs only the model — no calibration or training data. (Test data and an evaluation metric are still used to validate the result.) The fastest workflow — typically seconds to minutes even for large models. Often works well for reducing the model down to 8 bits, or even 6 or 4 bits, with only a slight decrease in accuracy. Typical approaches used for getting more aggressive compression, effective bits-per-weight (bpw) < 5 bits, involve using more granular compression (e.g. per-block quantization, per-grouped-channel palettization) and/or mixed-bit compression (assigning different bits to different weights, based on their effect on accuracy). - -- **Calibration-based compression**: Post-training compression with calibration data. Often used when quantizing activations. A small amount of representative data (e.g. ~128 samples) lets compressors observe activation ranges and weight sensitivities. - -- **Fine-tuning-based compression**: Compression-aware fine-tuning (e.g. quantization-aware training) with full training data. The compressor is integrated into your training loop so the model adapts to compression error as it trains. The most time-intensive workflow, but typically the only way to recover accuracy at the most aggressive compression ratios for weights (4 bits and below), and/or for models that are sensitive to activation quantization. - -`coreai-opt`'s APIs allow you to easily move from one stage to the next while evaluating accuracy after each stage and escalating to a more expensive workflow only when needed. 
- -## Getting started - -For an overview of the generic structure of `coreai-opt` APIs, see [How to use coreai-opt](introduction/how_to_use_coreaiopt.md). - -For end-to-end examples on API usage and common workflows, see [MNIST examples](examples/toy_models.md) and [model examples](examples/model_examples.md). - -## Links to related Core AI components - -- **[coreai-torch](https://github.com/apple/coreai-torch)** — Python library for converting PyTorch models to the Core AI (`.aimodel`) format. -- **[coreai-models](https://github.com/apple/coreai-models)** — GitHub repository with example models demonstrating how to convert, optimize, and re-author models for deployment with Core AI. Several of the LLMs in there are compressed to ~4–5 bits using `coreai-opt`. The repo also contains a number of AI skills, including some that wrap `coreai-opt` workflows. -- **[Core AI framework](https://developer.apple.com/documentation/coreai)** — Apple's on-device AI framework that runs `.aimodel` models. - ```{toctree} :maxdepth: 1 :caption: Introduction @@ -98,6 +39,7 @@ palettization/index utils/joint_compression utils/mixed_precision utils/model_inspection +utils/activation_comparison utils/casting utils/coreai_compression ``` diff --git a/docs/src/landing_page.md b/docs/src/landing_page.md new file mode 100644 index 0000000..9416f2a --- /dev/null +++ b/docs/src/landing_page.md @@ -0,0 +1,61 @@ +# Core AI Optimization Documentation + +## What is `coreai-opt`? + +`coreai-opt` is a Python library for compressing PyTorch models for deployment on Apple Silicon. It allows you to apply compression-based optimizations (such as quantization or palettization) to any PyTorch model, producing a transformed PyTorch model that can be converted to a Core AI model and run with the [Core AI](https://developer.apple.com/documentation/coreai) framework. 
+ +Model compression can help reduce the memory footprint of your model (both on disk and at runtime), reduce inference latency, reduce power consumption, or achieve all of these at once. + +```{mermaid} +flowchart LR + A[PyTorch model] --> B(coreai-opt) + B --> C["Transformed
PyTorch model
(compressed)"] + C --> D("coreai-torch
(convert)") + D --> E["Core AI model
(.aimodel)"] + style A color:#999,fill:none,stroke:none + style C color:#999,fill:none,stroke:none + style E color:#999,fill:none,stroke:none + linkStyle default stroke:#999,stroke-width:1.5px +``` + +`coreai-opt` is built around the following ideas: + +- **PyTorch native.** All APIs operate on PyTorch models. Compression is another transformation in your PyTorch workflow. The output of every compressor is itself a PyTorch model that can be validated, fine-tuned, and exported like any other model. + +- **Integrates with existing PyTorch code.** Adding post-training compression, calibration-based compression, or compression-aware training to an existing PyTorch pipeline takes a few additional lines of code. All three use the same compressor object. + +- **Aligned with Apple Silicon.** Default configurations and the majority of the available optimization options align with what the [Core AI](https://developer.apple.com/documentation/coreai) runtime executes efficiently, on one or more of the Apple Silicon platforms. Compressed PyTorch models can be seamlessly converted to `.aimodel` for deployment via Core AI. + +## Types of compression + +Available APIs cover the following categories of compression: + +- **[Quantization](quantization/index.md)** approximates weights and/or activations using a quantization function. Weight precisions include INT2, INT4, INT8 and FP4, FP8; activation precisions include INT8 and FP8. +- **[Palettization](palettization/index.md)**, also known as codebook-style compression, clusters weights into a look-up table of centroids and stores indices in their place. Weights can be palettized to N ∈ {1, 2, 3, 4, 6, 8} bits. +- **[Pruning](pruning/index.md)** zeros out weights with the smallest magnitudes and stores the remaining weights using sparse representations.
+ +These techniques can also be combined and applied in a hybrid fashion — for example, applying different palettization bit widths to different weights, or combining weight palettization with activation quantization — to build customized optimization recipes. + +## Compression workflows + +The process of applying compression to a model typically involves the following stages. + +- **Data-free compression**: Weight-only compression that needs only the model — no calibration or training data. (Test data and an evaluation metric are still used to validate the result.) The fastest workflow — typically seconds to minutes even for large models. Often works well for reducing the model down to 8 bits, or even 6 or 4 bits, with only a slight decrease in accuracy. Typical approaches used for getting more aggressive compression, effective bits-per-weight (bpw) < 5 bits, involve using more granular compression (e.g. per-block quantization, per-grouped-channel palettization) and/or mixed-bit compression (assigning different bits to different weights, based on their effect on accuracy). + +- **Calibration-based compression**: Post-training compression with calibration data. Often used when quantizing activations. A small amount of representative data (e.g. ~128 samples) lets compressors observe activation ranges and weight sensitivities. + +- **Fine-tuning-based compression**: Compression-aware fine-tuning (e.g. quantization-aware training) with full training data. The compressor is integrated into your training loop so the model adapts to compression error as it trains. The most time-intensive workflow, but typically the only way to recover accuracy at the most aggressive compression ratios for weights (4 bits and below), and/or for models that are sensitive to activation quantization. + +`coreai-opt`'s APIs allow you to easily move from one stage to the next while evaluating accuracy after each stage and escalating to a more expensive workflow only when needed. 
+ +## Getting started + +For an overview of the generic structure of `coreai-opt` APIs, see [How to use coreai-opt](introduction/how_to_use_coreaiopt.md). + +For end-to-end examples on API usage and common workflows, see [MNIST examples](examples/toy_models.md) and [model examples](examples/model_examples.md). + +## Links to related Core AI components + +- **[coreai-torch](https://github.com/apple/coreai-torch)** — Python library for converting PyTorch models to the Core AI (`.aimodel`) format. +- **[coreai-models](https://github.com/apple/coreai-models)** — GitHub repository with example models demonstrating how to convert, optimize, and re-author models for deployment with Core AI. Several of the LLMs in there are compressed to ~4–5 bits using `coreai-opt`. The repo also contains a number of AI skills, including some that wrap `coreai-opt` workflows. +- **[Core AI framework](https://developer.apple.com/documentation/coreai)** — Apple's on-device AI framework that runs `.aimodel` models. diff --git a/docs/src/pruning/basics.md b/docs/src/pruning/basics.md new file mode 100644 index 0000000..231f65c --- /dev/null +++ b/docs/src/pruning/basics.md @@ -0,0 +1,76 @@ +# Basics + +Pruning a model is the process of sparsifying the weight matrices within a model, thereby reducing its storage size by packing weights more efficiently. This can be done by setting a fraction of the values in the model’s weight matrices to zero. + +Pruned weights can be represented more efficiently using a sparse representation rather than the typical dense representation. In a sparse representation, only non-zero values are stored, along with a bit mask, that takes the value 1 at the indices of the non-zero values. 
For example, if the weight values are + +```text +[0, 7, 0, 0, -3.2, 0, 0, 56.3] +``` + +the sparse representation contains a bit mask with 1s in the locations where the value is non-zero: + +```text +[0, 1, 0, 0, 1, 0, 0, 1] +``` + +This is accompanied by the non-zero data, which for this example is: + +```text +[7, -3.2, 56.3] +``` + +```{figure} images/pruning_magnitude.svg +:alt: Magnitude pruning of a weight tensor and its sparse representation +:align: center +:width: 100% +:class: imgnoborder + +Magnitude pruning zeros out the smallest-magnitude weights and stores the result as a bit mask plus a packed list of non-zero values. +``` + +## Pruning Algorithm + +The pruning algorithm selects which elements to zero out. + +- `MagnitudePruning`: A simple way to pick the elements to zero out is by the magnitude of the element. The magnitude pruner sorts the elements by absolute value and zeros out the smallest ones until the `target_sparsity` is reached. + +## Pruning Schemes + +Sparsity can be introduced either in an unstructured way (the 0s introduced follow no pattern) or in a structured way (0s will be grouped together based on a pattern). This can be configured using the `PruningScheme`. + +- `Unstructured`: in this pruning scheme, there is no constraint on where the 0s are introduced into the tensor. For example, in the case of magnitude pruning with 50% sparsity, the pruner zeros out the half of the tensor's elements with the smallest magnitudes, wherever they may be located across the tensor. +- `ChannelStructured`: this pruning scheme constrains the 0s to entire channels (slices along a chosen `axis` of the tensor) — every element within a pruned channel is zeroed together. For example, in the case of magnitude pruning with 50% sparsity, channels are ranked by their L1 norm and the half with the smallest norms are zeroed out across all of their elements, while the other half are kept intact.
The realized sparsity is rounded down to the nearest multiple of `1/num_channels`. + +```{figure} images/pruning_schemes.svg +:alt: Comparison of unstructured and channel-structured pruning at 20% sparsity +:align: center +:width: 100% +:class: imgnoborder + +Unstructured pruning zeros individual cells anywhere in the tensor, while channel-structured pruning zeros entire channels together. Both reach the same overall 20% sparsity. +``` + +## Pruning Schedule + +When the sparsity is applied to the module, it introduces error into the module, as a portion of the weight values no longer contributes to the model's output. In models that are sensitive to sparsity, it might help to apply the sparsity in an incremental manner while fine-tuning the model to adapt to the sparsification. The Pruning Schedule allows applying sparsity based on a certain schedule. + +- `ConstantSparsitySchedule`: This is a simple schedule which mimics a step function. Up to `begin_step` the sparsity is 0%. Starting from `begin_step`, the schedule applies the entire `target_sparsity` to the model. This is a good first step to check how the model behaves with sparsity. For robust models and smaller amounts of sparsity, this works well and is recommended. +- `PolynomialDecaySchedule`: This schedule applies the sparsity based on a polynomial function which can be configured. The sparsity at step `s` within the schedule window is + +```text +sparsity(s) = s_target + (s_initial − s_target) · (1 − t)^power + +where t = (s − begin_step) / (total_iters − 1) +``` + +Starting from `begin_step`, the schedule incrementally applies the sparsity in increments of `update_frequency` up to the `target_sparsity`, following the polynomial described by the polynomial exponent `power` until `total_iters` is reached. Beyond `total_iters`, it maintains the `target_sparsity`.
+ +```{figure} images/pruning_schedule.svg +:alt: Comparison of ConstantSparsitySchedule and PolynomialDecaySchedule over training steps +:align: center +:width: 100% +:class: imgnoborder + +Sparsity over training steps under each schedule, both targeting 70% sparsity. The constant schedule jumps from 0% to the target at `begin_step`; the polynomial schedule ramps up smoothly with a slow start (power = 3) before plateauing at the target. +``` diff --git a/docs/src/pruning/config.md b/docs/src/pruning/config.md new file mode 100644 index 0000000..bea0835 --- /dev/null +++ b/docs/src/pruning/config.md @@ -0,0 +1,181 @@ +# Config API + +Pruning Configs follow the same philosophy as the [Palettization Config](../palettization/config.md). +They are simpler as pruning applies only to the weights in the model. +(Hence there are no `op_input_spec` and `op_output_spec` fields in the {class}`~coreai_opt.pruning.config.ModuleMagnitudePrunerConfig` and {class}`~coreai_opt.pruning.config.OpMagnitudePrunerConfig`.) + +## PruningSpec + +{class}`~coreai_opt.pruning.spec.PruningSpec` defines the following key properties (for full list see API reference): + +- `target_sparsity`: Fraction of elements to zero, in `[0, 1]`. Default: 0.5. +- `pruning_scheme`: Structural pattern of sparsity. Allowed: {class}`~coreai_opt.pruning.spec.Unstructured`() or {class}`~coreai_opt.pruning.spec.ChannelStructured`(axis=...), defaults to the former. 
+ +```python +from coreai_opt.pruning import PruningSpec +from coreai_opt.pruning.spec import ( + ChannelStructured, + default_weight_pruning_spec, +) + +# 50% unstructured magnitude pruning (default) +spec = default_weight_pruning_spec() + +# 75% unstructured +spec = PruningSpec(target_sparsity=0.75) + +# 50% channel-structured along axis 0 — entire channels are pruned together +spec = PruningSpec( + target_sparsity=0.5, + pruning_scheme=ChannelStructured(axis=0), +) +``` + +:::{note} +**Realized sparsity for `ChannelStructured`**: + +Channel-structured pruning prunes whole channels along `axis`, so the realized sparsity is rounded down to the nearest multiple of `1/num_channels`. For `num_channels=10` and `target_sparsity=0.5`, exactly 5 channels are pruned and the realized sparsity matches the target. For `num_channels=7` and the same target sparsity, only 3 channels are pruned, giving 3/7 ≈ 43% realized sparsity. `Unstructured` rounds at the element level, so this is only a concern for `ChannelStructured`. +::: + +## Config classes and their defaults + +The pruning config system mirrors palettization's three-class hierarchy: + +- {class}`~coreai_opt.pruning.MagnitudePrunerConfig` — the top-level config for the entire model. It holds a `global_config`, plus optional `module_type_configs` and `module_name_configs` for overrides. Same precedence as palettization: `module_name_configs` > `module_type_configs` > `global_config`. + +- {class}`~coreai_opt.pruning.ModuleMagnitudePrunerConfig` — controls pruning for all ops within a module's scope (or all modules if used as a `global_config`). Like {class}`~coreai_opt.palettization.config.ModuleKMeansPalettizerConfig`, it specifies a default `op_state_spec` for ops in the module and allows overrides via `op_type_config`, `op_name_config`, and `module_state_spec`. 
For a given op's weight, the spec is resolved in this priority order (highest first): `module_state_spec`, the matching entry in `op_name_config`, the matching entry in `op_type_config`, then the module's `op_state_spec`. It also exposes a `sparsity_schedule` field — when set, `pruner.step()` ramps sparsity over training (see [Pruning with Fine-Tuning](overview.md#pruning-with-fine-tuning)); when unset, the spec's `target_sparsity` is applied statically. + +- {class}`~coreai_opt.pruning.config.OpMagnitudePrunerConfig` — controls pruning for a specific op type or op name. Only `op_state_spec` is used. + +### Default behavior when no arguments are provided + +Creating any of these config classes with no arguments gives you a ready-to-use **50% unstructured magnitude pruning** configuration: + +```python +# All three of these produce equivalent default pruning settings: +config = MagnitudePrunerConfig() +# is equivalent to: +config = MagnitudePrunerConfig(global_config=ModuleMagnitudePrunerConfig()) +# which is equivalent to: +config = MagnitudePrunerConfig( + global_config=ModuleMagnitudePrunerConfig( + op_state_spec={ + "weight": default_weight_pruning_spec(), + "in_proj_weight": default_weight_pruning_spec(), + }, + ) +) + +op_config = OpMagnitudePrunerConfig() +# is equivalent to: +op_config = OpMagnitudePrunerConfig( + op_state_spec={ + "weight": default_weight_pruning_spec(), + "in_proj_weight": default_weight_pruning_spec(), + }, +) +``` + +- The default applies `default_weight_pruning_spec()` — 50% target sparsity, unstructured, magnitude-based — to parameters named `"weight"` and `"in_proj_weight"`. Other state tensors (e.g., `"bias"`) are left uncompressed. + +- If you need different behavior — such as pruning custom parameter names, excluding certain modules, or applying different sparsity targets to different layers — see the [Examples](#examples) section below. + +## Examples + +Several examples below configure specific module types or module names. 
To determine these for your model, see [How to get names + types](../quantization/config.md#how-to-get-names--types-for-modules-and-ops). Since pruning only supports eager execution mode, only the eager mode guidance in that section is relevant. + +### Apply 50% pruning globally, 75% to linear layers + +Apply 50% magnitude pruning to all supported layers, and override `linear` layers to 75%. + +```python +# programmatic +import torch.nn as nn +from coreai_opt.pruning import ( + MagnitudePrunerConfig, + ModuleMagnitudePrunerConfig, + PruningSpec, +) + +# 50% on all supported layers globally (the default) +config = MagnitudePrunerConfig() + +# override Linear layers to 75% +config.set_module_type( + nn.Linear, + ModuleMagnitudePrunerConfig( + op_state_spec={"weight": PruningSpec(target_sparsity=0.75)}, + ), +) +``` + +The snippet above applies 50% pruning globally (covering Conv2d and all other supported modules), then overrides Linear layers to 75%. + +#### Config chaining + +The setters also return the config itself, so multiple modifications can be chained into a single expression. The snippet above is equivalent to: + +```python +config = MagnitudePrunerConfig().set_module_type( + nn.Linear, + ModuleMagnitudePrunerConfig( + op_state_spec={"weight": PruningSpec(target_sparsity=0.75)}, + ), +) +``` + +```yaml +# yaml +magnitude_pruning_config: + global_config: + op_state_spec: + weight: + target_sparsity: 0.5 + pruning_scheme: { type: unstructured } + module_type_configs: + torch.nn.modules.linear.Linear: + op_state_spec: + weight: + target_sparsity: 0.75 + pruning_scheme: { type: unstructured } +``` + +### Apply pruning to specific module types only + +When you want to prune only specific module types and leave everything else uncompressed, construct the config explicitly without a `global_config`. Each module type gets its own `ModuleMagnitudePrunerConfig`, and modules not listed in `module_type_configs` are skipped. 
+ +```python +# programmatic — explicit (scoped to specific module types) +from coreai_opt.pruning import ( + MagnitudePrunerConfig, + ModuleMagnitudePrunerConfig, + PruningSpec, +) + +config = MagnitudePrunerConfig( + module_type_configs={ + "torch.nn.modules.linear.Linear": ModuleMagnitudePrunerConfig( + op_state_spec={"weight": PruningSpec(target_sparsity=0.75)}, + ), + "torch.nn.modules.conv.Conv2d": ModuleMagnitudePrunerConfig( + op_state_spec={"weight": PruningSpec(target_sparsity=0.5)}, + ), + }, +) +``` + +```yaml +# yaml +magnitude_pruning_config: + module_type_configs: + torch.nn.modules.linear.Linear: + op_state_spec: + weight: + target_sparsity: 0.75 + pruning_scheme: { type: unstructured } + torch.nn.modules.conv.Conv2d: + op_state_spec: + weight: + target_sparsity: 0.5 + pruning_scheme: { type: unstructured } +``` diff --git a/docs/src/pruning/images/pruning_magnitude.svg b/docs/src/pruning/images/pruning_magnitude.svg new file mode 100644 index 0000000..7ec0199 --- /dev/null +++ b/docs/src/pruning/images/pruning_magnitude.svg @@ -0,0 +1,134 @@ + + + + + + + + + + + + +Float weights + + + +2.1 + +0.1 + +2.5 + +3.6 + + +0.6 + +3.2 + +0.1 + +0.3 + + +0.4 + +0.2 + +3.5 + +1.2 + + +Prune + + + + +Clamped weights + + + +2.1 + +0.0 + +2.5 + +3.6 + + +0.0 + +3.2 + +0.0 + +0.0 + + +0.0 + +0.0 + +3.5 + +1.2 + + +Store + + + + +Bit mask & non-zero floats + + + + +1 + +0 + +1 + +1 + + + +0 + +1 + +0 + +0 + + + +0 + +0 + +1 + +1 + + + +2.1 + +2.5 + +3.6 + +3.2 + + + +3.5 + +1.2 + + diff --git a/docs/src/pruning/images/pruning_schedule.svg b/docs/src/pruning/images/pruning_schedule.svg new file mode 100644 index 0000000..80a8edb --- /dev/null +++ b/docs/src/pruning/images/pruning_schedule.svg @@ -0,0 +1,62 @@ + + + + + + + + + + + + + + 0.0 + + 0.2 + + 0.4 + + 0.6 + + 0.8 + + 1.0 + + + + + + 0 + + 20 + + 40 + + 60 + + 80 + + 100 + + + +Training step +Sparsity + + + + + + + + + +Schedules (target_sparsity = 0.7) + + +ConstantSparsitySchedule(begin_step=30) + + 
+PolynomialDecaySchedule(begin_step=10, total_iters=80, power=3.0) + diff --git a/docs/src/pruning/images/pruning_schemes.svg b/docs/src/pruning/images/pruning_schemes.svg new file mode 100644 index 0000000..672428b --- /dev/null +++ b/docs/src/pruning/images/pruning_schemes.svg @@ -0,0 +1,141 @@ + + + + + + + + + +Unstructured + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +Individual cells zeroed (target_sparsity = 0.2) + + +Channel-structured + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +Whole channels zeroed along axis=1 (target_sparsity = 0.2) + + diff --git a/docs/src/pruning/index.md b/docs/src/pruning/index.md new file mode 100644 index 0000000..c192b91 --- /dev/null +++ b/docs/src/pruning/index.md @@ -0,0 +1,9 @@ +# Pruning + +```{toctree} +:maxdepth: 2 + +basics +overview +config +``` diff --git a/docs/src/pruning/overview.md b/docs/src/pruning/overview.md new file mode 100644 index 0000000..fc51464 --- /dev/null +++ b/docs/src/pruning/overview.md @@ -0,0 +1,91 @@ +# API Overview + +## Post-training Pruning + +Post-training pruning sparsifies the model in a single shot. The mask is computed during the `prepare()` call and the prepared model immediately reflects the effects of sparsity. + +Unless the original PyTorch model already has a large fraction of weights close to zero across all of its weight parameters, post-training pruning will almost always degrade accuracy. It is most useful as a quick way to evaluate the impact of sparsity on model size and inference latency before committing to a fine-tuning workflow. + +```python +import coreai_opt as opt +from coreai_opt.pruning import MagnitudePruner, MagnitudePrunerConfig +import torch + +model = MyModel().eval() +example_inputs = (torch.randn(1, 3, 224, 224),) + +# Default config: 50% unstructured magnitude pruning on every supported weight. 
+config = MagnitudePrunerConfig() + +# Apply sparsity to the model. After the prepare API is called, +# the model will have 50% sparsity on every supported weight parameter +pruner = MagnitudePruner(model, config) +prepared_model = pruner.prepare(example_inputs) + +# Validate the model with the effects of sparsity +val_metric = validate(prepared_model, val_dataset) + +# Deployment is similar to the Quantizer. +# We invoke the 'finalize' API to update the PyTorch model and make it compatible +# for conversion with either CoreAI or CoreML + +finalized_model_for_coreai = pruner.finalize(backend=opt.ExportBackend.CoreAI) +# OR +finalized_model_for_coreml = pruner.finalize(backend=opt.ExportBackend.CoreML) +``` + +## Pruning with Fine-Tuning + +In most cases, fine-tuning is required to recover good accuracy after pruning. We can use a sparsity schedule on the module config and call `pruner.step()` to gradually ramp up sparsity over training. + +```python +from coreai_opt.pruning import ( + MagnitudePruner, + MagnitudePrunerConfig, + ModuleMagnitudePrunerConfig, + PruningSpec, +) +from coreai_opt.pruning.config import PolynomialDecaySchedule +import torch + +model = MyModel() +example_inputs = (...,) +num_epochs = 5 + +# 70% target sparsity ramped in via a polynomial schedule over num_epochs. +# Schedule starts at step 0 (sparsity=0) and advances on every pruner.step() +# call along a cubic curve until reaching target_sparsity at step total_iters. +config = MagnitudePrunerConfig( + global_config=ModuleMagnitudePrunerConfig( + op_state_spec={"weight": PruningSpec(target_sparsity=0.7)}, + sparsity_schedule=PolynomialDecaySchedule( + begin_step=0, total_iters=num_epochs, power=3.0 + ), + ), +) +pruner = MagnitudePruner(model, config) +prepared_model = pruner.prepare(example_inputs) + +# ---------- training loop -------------------- +# We fine-tune the model while incrementing the sparsity schedule. 
+# The pruner.step() API advances the sparsity schedule and recomputes +# the masks against the current weight magnitudes for the next sparsity level. +# The step() API can be called at the epoch frequency or at the batch step frequency +# based on the configuration of the schedule +optimizer = torch.optim.SGD(prepared_model.parameters(), lr=1e-3) +for epoch in range(num_epochs): + prepared_model.train() + for batch, target in train_dataloader: + optimizer.zero_grad() + loss = criterion(prepared_model(batch), target) + loss.backward() + optimizer.step() + pruner.step() + + val_metric = validate(prepared_model, val_dataloader) + +# ----------- deployment ------------------ +# same as before +``` + +For more details on how to use {class}`~coreai_opt.pruning.MagnitudePrunerConfig`, {class}`~coreai_opt.pruning.ModuleMagnitudePrunerConfig` to apply different settings to different weights in the model, see [Pruning Config](config.md). diff --git a/docs/src/tutorials/mnist_palettization.ipynb b/docs/src/tutorials/mnist_palettization.ipynb index 1104070..41a1945 100644 --- a/docs/src/tutorials/mnist_palettization.ipynb +++ b/docs/src/tutorials/mnist_palettization.ipynb @@ -75,6 +75,20 @@ "from torchvision import datasets, transforms" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "select-device", + "metadata": {}, + "outputs": [], + "source": [ + "# Use the MPS (Apple Silicon GPU) backend when available; otherwise fall back to CPU.\n", + "if torch.backends.mps.is_available():\n", + " DEVICE = torch.device(\"mps\")\n", + "else:\n", + " DEVICE = torch.device(\"cpu\")" + ] + }, { "cell_type": "code", "execution_count": 22, @@ -426,7 +440,7 @@ "loss_fn = torch.nn.CrossEntropyLoss()\n", "optimizer = create_adam_optimizer(basic_cnn_model)\n", "\n", - "basic_cnn_model = basic_cnn_model.to(\"mps\")\n", + "basic_cnn_model = basic_cnn_model.to(DEVICE)\n", "\n", "epoch_results = []\n", "for epoch in range(EPOCHS):\n", diff --git 
a/docs/src/tutorials/mnist_palettization_and_activation_quantization.ipynb b/docs/src/tutorials/mnist_palettization_and_activation_quantization.ipynb index 52eb387..9dee87c 100644 --- a/docs/src/tutorials/mnist_palettization_and_activation_quantization.ipynb +++ b/docs/src/tutorials/mnist_palettization_and_activation_quantization.ipynb @@ -73,6 +73,20 @@ "from torchvision import datasets, transforms" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "select-device", + "metadata": {}, + "outputs": [], + "source": [ + "# Use the MPS (Apple Silicon GPU) backend when available; otherwise fall back to CPU.\n", + "if torch.backends.mps.is_available():\n", + " DEVICE = torch.device(\"mps\")\n", + "else:\n", + " DEVICE = torch.device(\"cpu\")" + ] + }, { "cell_type": "code", "execution_count": 21, @@ -435,7 +449,7 @@ "loss_fn = torch.nn.CrossEntropyLoss()\n", "optimizer = create_adam_optimizer(basic_cnn_model)\n", "\n", - "basic_cnn_model = basic_cnn_model.to(\"mps\")\n", + "basic_cnn_model = basic_cnn_model.to(DEVICE)\n", "\n", "epoch_results = []\n", "for epoch in range(EPOCHS):\n", diff --git a/docs/src/tutorials/mnist_quantization.ipynb b/docs/src/tutorials/mnist_quantization.ipynb index 957cc63..b792ce3 100644 --- a/docs/src/tutorials/mnist_quantization.ipynb +++ b/docs/src/tutorials/mnist_quantization.ipynb @@ -77,6 +77,20 @@ "from torchvision import datasets, transforms" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "select-device", + "metadata": {}, + "outputs": [], + "source": [ + "# Use the MPS (Apple Silicon GPU) backend when available; otherwise fall back to CPU.\n", + "if torch.backends.mps.is_available():\n", + " DEVICE = torch.device(\"mps\")\n", + "else:\n", + " DEVICE = torch.device(\"cpu\")" + ] + }, { "cell_type": "code", "execution_count": 24, @@ -428,7 +442,7 @@ "loss_fn = torch.nn.CrossEntropyLoss()\n", "optimizer = create_adam_optimizer(basic_cnn_model)\n", "\n", - "basic_cnn_model = 
basic_cnn_model.to(\"mps\")\n", + "basic_cnn_model = basic_cnn_model.to(DEVICE)\n", "\n", "epoch_results = []\n", "for epoch in range(EPOCHS):\n", @@ -856,7 +870,7 @@ "QAT_EPOCHS = 5\n", "qat_optimizer = create_adam_optimizer(wa_prepared, lr=1e-4)\n", "\n", - "wa_prepared.to(\"mps\")\n", + "wa_prepared.to(DEVICE)\n", "\n", "for epoch in range(QAT_EPOCHS):\n", " with wa_quantizer.training_mode():\n", diff --git a/docs/src/utils/activation_comparison.md b/docs/src/utils/activation_comparison.md new file mode 100644 index 0000000..0a6ae57 --- /dev/null +++ b/docs/src/utils/activation_comparison.md @@ -0,0 +1,313 @@ +# Comparing Intermediate Activations + +A common debugging question after preparing a compressed model is: *which layers are most affected by compression?* Comparing intermediate activation tensors of the compressed model against the uncompressed model, on the same inputs, can be a simple yet effective technique to identify the most impacted layers and tune the config accordingly (e.g., raise bit-width or skip quantization). + +The recipe is: + +1. Pick a small batch of representative inputs. +2. Run the uncompressed model and capture intermediate activations. +3. Run the prepared (compressed) model on the *same* inputs and capture activations at corresponding points. +4. For each pair of tensors, compute a similarity metric — signal-to-noise ratio (SNR) is a common choice. +5. Sort layers by SNR; the lowest values point to the layers most affected by compression. + +The examples below use quantization (`Quantizer`), but the same recipe applies to any other compressor (e.g., `KMeansPalettizer`). + +The wiring of step 3 differs between [eager mode and graph mode](../quantization/overview.md#two-execution-modes-graph-and-eager) because the prepared model has a different structure in each. The rest of this page walks through both. 
+ +```{note} +The code snippets below use `torch.randn(...)` for brevity, but random inputs drive the model through activation regions it rarely sees in practice, so the resulting SNR can rank layers differently from what real data would produce. Use representative data when applying this technique in practice. +``` + +## SNR helper + +```python +import torch + + +def snr_db(reference: torch.Tensor, noisy: torch.Tensor) -> float: + """SNR in dB between a reference tensor and a noisy approximation of it.""" + signal_power = reference.float().pow(2).mean() + noise_power = (reference.float() - noisy.float()).pow(2).mean() + if noise_power == 0: + return float("inf") + return 10.0 * torch.log10(signal_power / noise_power).item() +``` + +Higher SNR ⇒ smaller compression-induced error. A drop of tens of dB at one layer relative to its neighbors is a strong hint that the layer is sensitive to the configured spec. + +## Example model + +The code on this page uses the following toy `Conv2d → ReLU → Linear` model on MNIST-sized inputs (1×28×28): + +```python +import torch.nn as nn + + +class ToyModel(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(1, 8, 3, padding=1) + self.relu = nn.ReLU() + self.linear = nn.Linear(8 * 28 * 28, 10) + + def forward(self, x): + x = self.relu(self.conv(x)) + return self.linear(x.flatten(1)) +``` + +## Eager mode + +In eager mode, `quantizer.prepare()` returns an `nn.Module` with the same submodule structure as the original — `prepared_model.layer1.conv1` lives at the same path as `model.layer1.conv1`. A forward hook on that path captures the *quantized* output on the prepared model and the *float* output on the original. + +For weights, `prepare()` registers PyTorch parametrizations in-place, so accessing `.weight` on either model after `prepare()` returns the fake-quantized value — not the original float. To retain the float reference, save copies of the weights before calling `prepare()`. 
+ +### Activation collector + +The following helpers capture each named submodule's output tensor using forward hooks: + +```python +from collections import OrderedDict +import torch + + +def make_capture(store: dict[str, torch.Tensor], name: str): + def hook(_module, _inputs, output): + store[name] = output.detach().clone() + + return hook + + +def collect_outputs( + target: torch.nn.Module, names: list[str], inputs +) -> dict[str, torch.Tensor]: + store: dict[str, torch.Tensor] = OrderedDict() + handles = [ + target.get_submodule(n).register_forward_hook(make_capture(store, n)) + for n in names + ] + try: + with torch.no_grad(): + target(*inputs) + finally: + for h in handles: + h.remove() + return store +``` + +### Putting it together + +```python +import torch +from coreai_opt.quantization import Quantizer, QuantizerConfig +from coreai_opt.quantization.config import ExecutionMode + +model = ToyModel().eval() +example_inputs = (torch.randn(1, 1, 28, 28),) # Use representative inputs in practice + +# Collect all named submodules for activation capture. +target_names = [name for name, _ in model.named_modules() if name] + +# Save weight copies before prepare — prepare() registers parametrizations in-place, +# making .weight return the fake-quantized value on both models after this point. +original_weights = { + name: p.detach().clone() for name, p in model.named_parameters() if "weight" in name +} + +# Capture float activations before preparing. +fp_acts = collect_outputs(model, target_names, example_inputs) + +# Prepare and run to capture quantized activations. +config = QuantizerConfig(execution_mode=ExecutionMode.EAGER) +quantizer = Quantizer(model, config) +prepared_model = quantizer.prepare(example_inputs) +prepared_model.eval() +q_acts = collect_outputs(prepared_model, target_names, example_inputs) + +# Print weight SNR. 
+for key, orig_w in original_weights.items(): + module_name = key.rsplit(".weight", 1)[0] + quant_w = prepared_model.get_submodule(module_name).weight + print(f"{key}: SNR = {snr_db(orig_w, quant_w):.2f} dB") + +# Print activation SNR. +for name in target_names: + print(f"{name}: SNR = {snr_db(fp_acts[name], q_acts[name]):.2f} dB") +``` + +Running `ToyModel` with the default INT8 config produces: + +```text +conv.weight: SNR = 47.17 dB +linear.weight: SNR = 48.13 dB +conv: SNR = 38.87 dB +relu: SNR = 38.72 dB +linear: SNR = 37.46 dB +``` + +Biases are not quantized and are omitted (SNR = ∞). Activation SNRs are lower than weight SNRs and reflect accumulated error: each activation output is affected by the quantization of the weights feeding into that op, plus the quantization of its own output. + +## Graph mode + +In graph mode (the default), `quantizer.prepare()` runs `torch.export` and returns an `fx.GraphModule`. The original `nn.Module` hierarchy is flattened into a flat graph of named nodes (`conv2d`, `relu`, `linear`, `linear_weight`, …), and fake-quantize submodules (`activation_post_process_0`, `activation_post_process_1`, …) are inserted around the ops being quantized. + +To compare cleanly, also export the *original* model with `torch.export`. Both graphs then share identical op-node names, so an op's float activation and its quantized counterpart can be looked up by the same key — no `nn.Module` ↔ graph-node mapping needed. However, on the prepared side, an op's *quantized* output sits on the fake-quantize node downstream of the op, not on the op itself. Build the op-to-FQ lookup by walking the prepared graph once. 
+ +```mermaid +%%{init: {"flowchart": {"rankSpacing": 25, "nodeSpacing": 15}}}%% +graph LR + subgraph orig["Original graph (exported)"] + direction TB + ow([conv_weight]) --> oc[conv2d] + ox([x]) --> oc + oc --> or_[relu] + or_ --> of[flatten] + olw([linear_weight]) --> ol[linear] + of --> ol + ol --> oout([output]) + end + + subgraph prep["Prepared graph (W8A8)"] + direction TB + pw([conv_weight]) --> pfq1(FQ_1) --> pc[conv2d] + px([x]) --> pfq0(FQ_0) --> pc + pc --> pr[relu] + pr --> pfq2(FQ_2) --> pf[flatten] + pf --> pfq3(FQ_3) --> pl[linear] + plw([linear_weight]) --> pfq4(FQ_4) --> pl + pl --> pfq5(FQ_5) --> pout([output]) + end + + orig ~~~ prep +``` + +*Left: original exported graph. Right: after W8A8 quantization. Rounded rectangles are FQ nodes (activation_post_process_N), present only in the prepared graph; op node names are identical in both graphs.* + +### Activation collector + +`register_forward_hook` only fires on `call_module` nodes — most ops in `fx.GraphModule` are `call_function` and have no module to attach to. 
To capture every tensor output by node name, drive the model with [`torch.fx.Interpreter`](https://docs.pytorch.org/docs/stable/fx.html#torch.fx.Interpreter): + +```python +from collections import OrderedDict +import torch + + +class ActivationCollector(torch.fx.Interpreter): + """Run an fx.GraphModule and capture every tensor output, keyed by node name.""" + + def __init__(self, graph_module: torch.fx.GraphModule): + super().__init__(graph_module) + self.captured: dict[str, torch.Tensor] = OrderedDict() + + def run_node(self, n: torch.fx.Node): + out = super().run_node(n) + if isinstance(out, torch.Tensor): + self.captured[n.name] = out.detach().clone() + return out + + +def collect_node_outputs(gm, inputs) -> dict[str, torch.Tensor]: + collector = ActivationCollector(gm) + with torch.no_grad(): + collector.run(*inputs) + return collector.captured +``` + +### Mapping ops to their fake-quantize nodes + +When the prepared graph adds an `activation_post_process` node consuming an op's output, the *quantized* counterpart of that op output lives on the FQ, not on the op itself. Build the lookup with a single walk over the prepared graph: + +```python +def op_to_fq(prepared: torch.fx.GraphModule) -> dict[str, str]: + """Map each op (or weight/input) to the activation_post_process node consuming it.""" + mapping: dict[str, str] = {} + for node in prepared.graph.nodes: + if node.op != "call_module" or "activation_post_process" not in str( + node.target + ): + continue + if node.args: + mapping[node.args[0].name] = node.name + return mapping +``` + +This map covers *everything* the prepared graph fake-quantizes — weights, inputs, and op outputs alike. 
For a `Conv2d → ReLU → Linear` model with default quantization, the map looks like: + +```python +{ + "x": "activation_post_process_0", # input FQ + "conv_weight": "activation_post_process_1", # weight FQ + "relu": "activation_post_process_2", # output FQ for the conv→relu fused chain + "flatten": "activation_post_process_3", # shared quantizer through flatten + "linear_weight": "activation_post_process_4", # weight FQ + "linear": "activation_post_process_5", # output FQ +} +``` + +Two patterns to notice in this map: + +- **Weight-only configs.** For weight-only configs, the map only has `*_weight` entries — weights are the only things being fake-quantized. Op nodes like `linear` and `conv2d` don't appear, but on the prepared side those ops already consume the fake-quantized weight, so their outputs already reflect the quantization. Just compare them by the same name on both sides. +- **Pattern fusion.** `conv2d` has no entry because graph mode fuses `conv → relu` and places the single output FQ after `relu`, not after `conv2d`. Comparing `conv2d` by the same name on both sides still works — the prepared graph's `conv2d` is computed from fake-quantized inputs and weights, so the SNR there reflects input + weight quantization error. The output-activation quantization error for the fused block shows up one row down, at `relu`. + +### Putting it together + +```python +import torch +from coreai_opt.quantization import Quantizer, QuantizerConfig + +model = ToyModel().eval() +example_inputs = (torch.randn(1, 1, 28, 28),) # Use representative inputs in practice + +# 1. Export the original model so we can interpret it node-by-node, by name. +exported_original = torch.export.export(model, example_inputs).module() + +# 2. Prepare the (graph-mode) compressed model. Same input names; adds FQs. +prepared = Quantizer(model, QuantizerConfig()).prepare(example_inputs) +prepared.eval() + +# 3. Capture every tensor activation in both graphs. 
+float_acts = collect_node_outputs(exported_original, example_inputs) +quant_acts = collect_node_outputs(prepared, example_inputs) + +# 4. For each op in the float graph, compare against its FQ node on the prepared +# side if one exists, else against the same-named node. +fq_map = op_to_fq(prepared) +for name, fp in float_acts.items(): + target = fq_map.get(name, name) + if target not in quant_acts or fp.shape != quant_acts[target].shape: + continue + print(f"{name:20s} -> {target:30s} SNR = {snr_db(fp, quant_acts[target]):.2f} dB") +``` + +You get one SNR row per *node* — weights, inputs, and op outputs alike. + +Running `ToyModel` with the default INT8 config produces: + +```text +conv_weight -> activation_post_process_1 SNR = 47.17 dB +conv_bias -> conv_bias SNR = inf dB +linear_weight -> activation_post_process_4 SNR = 48.13 dB +linear_bias -> linear_bias SNR = inf dB +x -> activation_post_process_0 SNR = 43.20 dB +conv2d -> conv2d SNR = 42.40 dB +relu -> activation_post_process_2 SNR = 38.94 dB +flatten -> activation_post_process_3 SNR = 38.94 dB +linear -> activation_post_process_5 SNR = 35.74 dB +``` + +Biases report SNR = ∞ — they are not quantized and compare identically on both sides. `conv2d` (42.40 dB) sits higher than `relu` (38.94 dB) because it has no downstream FQ; its error comes only from the fake-quantized input and weights, while `relu` additionally incurs output-activation quantization from FQ_2. `relu` and `flatten` share the same SNR because `flatten` is a reshape over identical values backed by a shared quantizer. `linear` has the lowest SNR, accumulating error from both its weight and the quantized activation arriving from `flatten`. + +## Acting on the results + +Two patterns are common in the SNR table: + +**Sudden drop at a single layer.** A sharp fall of tens of dB at one op — often visible on a weight or output FQ node — indicates that layer is particularly sensitive to the configured spec. 
Target it directly: + +- **Skip the layer.** Set its entry to `None` to leave it at full precision — `module_name_configs={"...": None}`, `module_type_configs={...: None}`, `op_name_config={"...": None}`, or `op_type_config={"...": None}`. `QuantizerConfig.without(...)` is the convenience shortcut for the module-level case. See [Skip quantization for a specific layer type](../quantization/config.md#example-skip-quantization-for-a-specific-layer-type). +- **Raise its precision.** Replace the spec at that scope with a higher-bit dtype or a finer granularity (e.g., `PerBlockGranularity` instead of `PerChannelGranularity`). [Apply different configs to different module types](../quantization/config.md#example-apply-different-configs-to-different-module-types) shows the chaining pattern. + +**Gradual drift across a sequence of layers.** When SNR declines steadily over several consecutive layers with no obvious single culprit, quantization error accumulates as it propagates through the network. Two approaches can help: + +- **Skip the first few layers in the declining run.** Leaving the earliest affected layers at full precision prevents the initial error from accumulating downstream. +- **Raise precision across the region.** Apply a higher-bit or finer-granularity spec to all layers in the sequence. This limits how much error each layer contributes. + +Re-run the same comparison after each config change to confirm the SNR at the targeted layers improved without regressing elsewhere. diff --git a/docs/tests/test_tutorials.py b/docs/tests/test_tutorials.py new file mode 100644 index 0000000..32314a1 --- /dev/null +++ b/docs/tests/test_tutorials.py @@ -0,0 +1,65 @@ +# Copyright 2026 Apple Inc. 
+# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + + +"""Test that tutorial notebooks execute without errors.""" + +from __future__ import annotations + +from pathlib import Path + +import papermill as pm +import pytest + +from coreai_opt._utils.repo_utils import find_repo_root + +NOTEBOOK_CELL_TIMEOUT_SECONDS = 300 + +# Notebooks whose filename contains this token must export deployment models. +MNIST_TOKEN = "mnist" +EXPECTED_MNIST_EXPORTS = [ + "exported_model.aimodel", +] + +_repo_root = find_repo_root(__file__) +_tutorials_dir = _repo_root / "docs" / "src" / "tutorials" +_notebooks = sorted(_tutorials_dir.glob("*.ipynb")) + + +def _notebook_id(path: Path) -> str: + return path.stem + + +def test_tutorials_dir_is_non_empty() -> None: + """Guard against an empty parametrize set silently producing zero tests.""" + assert _notebooks, f"No tutorial notebooks found under {_tutorials_dir}" + + +@pytest.mark.parametrize("notebook", _notebooks, ids=_notebook_id) +def test_tutorial_notebook_executes(notebook: Path, tmp_path: Path) -> None: + """Execute a tutorial notebook end-to-end with papermill and verify outputs. + + ``SAVE_DIRECTORY`` is injected as a pytest ``tmp_path`` so the notebook + writes its dataset and exported models into a temporary directory rather + than the source tree. Any notebook whose filename contains "mnist" must + export ``exported_model.aimodel`` (see ``EXPECTED_MNIST_EXPORTS``). 
+ """ + pm.execute_notebook( + str(notebook), + str(tmp_path / notebook.name), + parameters={"SAVE_DIRECTORY": str(tmp_path)}, + kernel_name="python3", + execution_timeout=NOTEBOOK_CELL_TIMEOUT_SECONDS, + ) + + if MNIST_TOKEN not in notebook.stem: + return + + for name in EXPECTED_MNIST_EXPORTS: + export_path = tmp_path / name + assert export_path.exists(), ( + f"MNIST notebook {notebook.name} did not produce expected export: " + f"{name} (looked in {tmp_path})" + ) diff --git a/pyproject.toml b/pyproject.toml index 33537b1..cd46b21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ dev = [ ] test = [ "junitparser>=4.0.2", - "pytest>=7.0.0", + "pytest>=9.0.0", "pytest-cov>=4.0.0", "pytest-xdist>=3.0.0", ] @@ -89,8 +89,13 @@ docs = [ "sphinx-copybutton==0.5.2", "sphinx-llm==0.3.0", "sphinxcontrib-mermaid==1.0.0", - { include-group = "stable-coreai" }, + { include-group = "coreai" }, ] +# `coreai` group mirrors the `coreai` optional-dependency (extra) so `uv sync` +# installs CoreAI by default (see tool.uv.default-groups). The extra is the +# single source of truth; this group self-references it to avoid duplicating the +# package list. 
+coreai = [ "coreai-opt[coreai]" ] # Used in CI to force latest mimimum supported torch version # These torch versions must be in bounds of torch versions listed in project dependencies highest_tested_torch = [ @@ -98,12 +103,6 @@ highest_tested_torch = [ "torchao==0.17.0", "torchvision==0.26.0", ] -# Latest CoreAI nightly -latest-coreai = [ - "coreai-core", - "coreai-torch", - "scikit-learn>=1.7.2", -] # Used in CI to force lowest mimimum supported torch version # These torch versions must be in bounds of torch versions listed in project dependencies lowest_tested_torch = [ @@ -133,12 +132,6 @@ pre-commit = [ "toml-sort>=0.24.3", "tomli-w>=1.0.0", ] -# TODO: deduplicate coreai dependencies across groups -stable-coreai = [ - "coreai-core==1.0.0b1", - "coreai-torch==0.4.0", - "scikit-learn>=1.7.2", -] tamm-export = [] # Since torch and torchao are project dependencies, we need to include torchvision in dev # This allows standard `make test` to find torchvision @@ -151,7 +144,7 @@ tutorial = [ "nbconvert>=7.17.1", "papermill>=2.7.0", "torchinfo>=1.8.0", - { include-group = "stable-coreai" }, + { include-group = "coreai" }, { include-group = "torchvision" }, ] @@ -167,7 +160,7 @@ find.where = [ "src" ] # make env group installed by default with uv sync [tool.uv] -default-groups = [ "dev", "stable-coreai" ] +default-groups = [ "dev", "coreai" ] index = [ { explicit = true, name = "pytorch-cpu", url = "https://download.pytorch.org/whl/cpu" }, { explicit = true, name = "pytorch-cu128", url = "https://download.pytorch.org/whl/cu128" }, @@ -188,14 +181,6 @@ conflicts = [ { group = "highest_tested_torch" }, { group = "lowest_tested_torch" }, ], - [ - { group = "stable-coreai" }, - { group = "latest-coreai" }, - ], - [ - { extra = "coreai" }, - { group = "latest-coreai" }, - ], ] [tool.uv.sources] torch = [ @@ -281,6 +266,7 @@ markers = [ "slow: marks tests as slow (deselect with '-m \"not slow\"')", ] norecursedirs = [ ".oss-export" ] +strict = true testpaths = [ "tests" ] 
[tool.towncrier] diff --git a/scripts/make/install_pre_commit_hooks.sh b/scripts/make/install_pre_commit_hooks.sh index f883f1c..36c080b 100755 --- a/scripts/make/install_pre_commit_hooks.sh +++ b/scripts/make/install_pre_commit_hooks.sh @@ -60,5 +60,5 @@ if ! command -v lychee &>/dev/null; then fi # Configure git hooks. Force-install so the shebang points to the current venv. -uv run --active pre-commit install -f +uv run --no-sync --active pre-commit install -f echo "✓ Pre-commit hooks configured to use $VIRTUAL_ENV" diff --git a/scripts/make/log_versions.py b/scripts/make/log_versions.py index 2f52da6..5d5b6cc 100755 --- a/scripts/make/log_versions.py +++ b/scripts/make/log_versions.py @@ -8,12 +8,11 @@ """Log package versions and Python executable information.""" import sys -from importlib.metadata import PackageNotFoundError, version +from importlib.metadata import PackageNotFoundError, distributions, version -PACKAGES: dict[str, list[str]] = { - "Torch": ["torch", "torchvision", "torchao"], - "CoreAI": ["coreai-core", "coreai-torch"], -} +TORCH_PACKAGES: list[str] = ["torch", "torchvision", "torchao"] + +COREAI_NAME_SUBSTRING = "coreai" def _get_version(pkg: str) -> str: @@ -23,14 +22,32 @@ def _get_version(pkg: str) -> str: return "not installed" +def _find_coreai_versions() -> dict[str, str]: + """Return {name: version} for every installed distribution whose name contains 'coreai'.""" + found: dict[str, str] = {} + for dist in distributions(): + name = dist.name + if name and COREAI_NAME_SUBSTRING in name.lower(): + found[name] = dist.version + return dict(sorted(found.items())) + + def main() -> None: print("=== Python ===") print(f"Python version: {sys.version}") print(f"Python executable: {sys.executable}") - for section, packages in PACKAGES.items(): - print(f"=== {section} ===") - for pkg in packages: - print(f"{pkg}: {_get_version(pkg)}") + + print("=== Torch ===") + for pkg in TORCH_PACKAGES: + print(f"{pkg}: {_get_version(pkg)}") + + print("=== 
CoreAI ===") + coreai_versions = _find_coreai_versions() + if coreai_versions: + for name, pkg_version in coreai_versions.items(): + print(f"{name}: {pkg_version}") + else: + print("no coreai packages installed") if __name__ == "__main__": diff --git a/scripts/make/print_api_list.py b/scripts/make/print_api_list.py index cc5b697..344c327 100755 --- a/scripts/make/print_api_list.py +++ b/scripts/make/print_api_list.py @@ -20,7 +20,7 @@ # Inspect a single module (dotted name or file path): python scripts/make/print_api_list.py coreai_opt.quantization.spec.spec - python scripts/make/print_api_list.py src/coreai_opt/quantization/spec/spec.py + python scripts/make/print_api_list.py path/to/module.py make api-list MODULE=coreai_opt.quantization.spec.spec """ diff --git a/scripts/make/run_tests_on_latest_coreai.sh b/scripts/make/run_tests_on_latest_coreai.sh deleted file mode 100755 index ba43202..0000000 --- a/scripts/make/run_tests_on_latest_coreai.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/usr/bin/env bash - -# Copyright 2026 Apple Inc. -# -# Use of this source code is governed by a BSD-3-Clause license that can -# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause - -# Run tests inside the latest-CoreAI virtual environment. -# -# Usage: -# ./run_tests_on_latest_coreai.sh --path tests/export/ -# ./run_tests_on_latest_coreai.sh --path tests/coreai_utils/ --marker "not slow" -# -# All arguments are forwarded to run_tests.sh (--path is required). - -set -euo pipefail - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" -VENV_PATH="$REPO_ROOT/.venv_latest_coreai" - -echo "Running tests with latest CoreAI..." - -# shellcheck source=/dev/null -source "$VENV_PATH/bin/activate" -(cd "$REPO_ROOT" && uv run --active python scripts/make/log_versions.py) -"$SCRIPT_DIR/run_tests.sh" "$@" - -echo "All tests passed!" 
diff --git a/scripts/make/setup_env.sh b/scripts/make/setup_env.sh index 8007922..e4af6e4 100755 --- a/scripts/make/setup_env.sh +++ b/scripts/make/setup_env.sh @@ -85,6 +85,20 @@ if [[ -z "$AVAILABLE_GROUPS" ]]; then exit 1 fi +group_torch_pin() { + local group="$1" + awk -v group="$group" ' + $0 ~ "^" group "[[:space:]]*=[[:space:]]*\\[" { in_group = 1; next } + in_group && /^\]/ { exit } + in_group && match($0, /"torch[[:space:]]*==[[:space:]]*[0-9][0-9.]*/) { + version = substr($0, RSTART, RLENGTH) + sub(/"torch[[:space:]]*==[[:space:]]*/, "", version) + print version + exit + } + ' "$PYPROJECT_TOML" +} + # Parse command line arguments VENV=".venv" PYTHON_VERSION="" @@ -94,10 +108,8 @@ ALL_GROUPS=false ENSURE_MODE=false # Groups excluded from --all-groups due to mutual conflicts in pyproject.toml. -# stable-coreai is omitted because it's in default-groups; use --without-stable-coreai -# explicitly when a conflicting group (e.g. lowest_tested_torch) is requested. # tamm-export is omitted because it's opt-in only (never in default-groups or --all-groups). -CONFLICTING_GROUPS=("highest_tested_torch" "latest-coreai" "lowest_tested_torch") +CONFLICTING_GROUPS=("highest_tested_torch" "lowest_tested_torch") show_help() { echo "Usage: $0 [OPTIONS]" @@ -108,7 +120,7 @@ show_help() { echo " --venv Virtual environment name (default: .venv)" echo " --python-version Python version (required)" echo " --with- Install additional dependency group (e.g., --with-docs, --with-turi)" - echo " --without- Exclude a default dependency group (e.g., --without-stable-coreai)" + echo " --without- Exclude a default dependency group (e.g., --without-coreai)" echo " --all-groups Install all non-conflicting dependency groups" echo " --ensure Quick check mode: skip setup if venv exists and deps are present" echo " --help Show this help message" @@ -244,13 +256,23 @@ fi # --ensure mode: skip setup if venv exists and required deps are already installed. 
# This is the fast path called by Make targets to avoid re-running full setup. if [[ "$ENSURE_MODE" == "true" ]] && [ -f "$VENV/bin/python" ]; then - # Build a single Python command that checks all sentinel imports at once IMPORT_STMTS="import pytest" if [[ ${#EXTRA_GROUPS[@]} -gt 0 ]]; then for GROUP in "${EXTRA_GROUPS[@]}"; do case "$GROUP" in docs) IMPORT_STMTS+="; import sphinx" ;; - highest_tested_torch | lowest_tested_torch) IMPORT_STMTS+="; import torchao" ;; + highest_tested_torch | lowest_tested_torch) + IMPORT_STMTS+="; import torchao" + EXPECTED_TORCH="$(group_torch_pin "$GROUP")" + # These groups always pin torch, so an empty result means the + # pyproject parse regressed — fail loudly instead of silently + # skipping the version check (which would reintroduce the bug). + if [[ -z "$EXPECTED_TORCH" ]]; then + echo "Error: could not parse a torch pin for group '$GROUP' in $PYPROJECT_TOML" >&2 + exit 1 + fi + IMPORT_STMTS+="; import torch; assert torch.__version__.split('+')[0] == '$EXPECTED_TORCH'" + ;; rio) IMPORT_STMTS+="; import turi_lightning" ;; tamm-export) IMPORT_STMTS+="; import tamm_export" ;; esac @@ -261,7 +283,16 @@ if [[ "$ENSURE_MODE" == "true" ]] && [ -f "$VENV/bin/python" ]; then exit 0 fi - # Deps missing — fall through to full setup + # Deps missing or pinned torch mismatch — fall through to full setup. + # If a pinned-torch group expected a specific version, surface the mismatch + # so the rebuild isn't silent. + if [[ -n "${EXPECTED_TORCH:-}" ]]; then + ACTUAL_TORCH="$("$VENV/bin/python" -c \ + "import torch; print(torch.__version__.split('+')[0])" 2>/dev/null || true)" + if [[ -n "$ACTUAL_TORCH" && "$ACTUAL_TORCH" != "$EXPECTED_TORCH" ]]; then + echo "Note: $VENV has torch $ACTUAL_TORCH, expected $EXPECTED_TORCH; rebuilding." 
>&2 + fi + fi fi echo "==========================================" @@ -298,7 +329,7 @@ elif [[ ${#EXTRA_GROUPS[@]} -gt 0 ]]; then SYNC_CMD+=(--group "$GROUP") done fi -# Apply explicit group exclusions (e.g., --without-stable-coreai) +# Apply explicit group exclusions (e.g., --without-coreai) if [[ ${#EXCLUDE_GROUPS[@]} -gt 0 ]]; then for GROUP in "${EXCLUDE_GROUPS[@]}"; do SYNC_CMD+=(--no-group "$GROUP") diff --git a/scripts/release/release_utils.py b/scripts/release/release_utils.py index 7c03d4f..b274635 100644 --- a/scripts/release/release_utils.py +++ b/scripts/release/release_utils.py @@ -18,7 +18,23 @@ from datetime import UTC, datetime from pathlib import Path -VERSION_FILE_PATH = Path("src") / "coreai_opt" / "_about.py" +# coreai_opt's _about.py lives under `external/src` in the internal repo but at +# the top-level `src` in the exported OSS tree, so resolve against both layouts. +_VERSION_FILE_CANDIDATES = ( + Path("src") / "coreai_opt" / "_about.py", + Path("external") / "src" / "coreai_opt" / "_about.py", +) + + +def _resolve_version_file(repo_root: Path) -> Path: + """Return the path to coreai_opt's ``_about.py`` for either repo layout.""" + for rel in _VERSION_FILE_CANDIDATES: + candidate = repo_root / rel + if candidate.is_file(): + return candidate + checked = ", ".join(str(c) for c in _VERSION_FILE_CANDIDATES) + msg = f"Could not locate coreai_opt/_about.py under {repo_root} (checked {checked})" + raise FileNotFoundError(msg) def get_short_sha() -> str: @@ -55,7 +71,7 @@ def get_dev_release_version(base_version: str) -> str: def get_package_version(repo_root: Path) -> str: """Read package version from _about.py.""" - about_file = repo_root / VERSION_FILE_PATH + about_file = _resolve_version_file(repo_root) spec = importlib.util.spec_from_file_location("_about", about_file) if spec is None or spec.loader is None: msg = f"Could not load module spec from {about_file}" @@ -75,7 +91,7 @@ def write_version(repo_root: Path, version: str) -> None: Raises: 
RuntimeError: If ``_about.py`` does not contain a ``__version__`` assignment. """ - about = repo_root / VERSION_FILE_PATH + about = _resolve_version_file(repo_root) content = about.read_text(encoding="utf-8") updated = re.sub( r'(__version__\s*=\s*)["\'].*?["\']', diff --git a/src/coreai_opt/_utils/config_utils.py b/src/coreai_opt/_utils/config_utils.py index 2ebe485..1014bfd 100644 --- a/src/coreai_opt/_utils/config_utils.py +++ b/src/coreai_opt/_utils/config_utils.py @@ -4,21 +4,65 @@ # be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause """Utilities for config related items""" + from __future__ import annotations +import logging +from collections.abc import Iterable, Mapping from enum import Enum, auto from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) # Constant representing all tensors for an input/output/state spec ALL_TENSORS = "*" +def get_last_matching_spec( + identifiers: Iterable[int | str], + spec_dict: Mapping[int | str, Any], +) -> tuple[Any, bool]: + """Return the last value in ``spec_dict`` whose key matches any identifier. + + Iterates ``spec_dict`` keys in declaration order, tracking the last match. + Later-declared keys take precedence over earlier ones. Falls back to the + ``"*"`` wildcard key if no specific match is found. Warns when multiple + keys match. Returns ``(value, True)`` on match — value may be ``None`` if + the entry is explicit-None to disable. Returns ``(None, False)`` if no key + matched at all. + """ + identifiers_set = set(identifiers) + matching_keys: list[int | str] = [] + last_value = None + found = False + for key, value in spec_dict.items(): + if key in identifiers_set: + matching_keys.append(key) + last_value = value + found = True + if found: + if len(matching_keys) > 1: + logger.warning( + "Multiple spec keys matched for identifiers %s against spec keys %s: " + "%s. 
Using the last matching key '%s'.", + list(identifiers_set), + list(spec_dict.keys()), + matching_keys, + matching_keys[-1], + ) + return last_value, True + if ALL_TENSORS in spec_dict: + return spec_dict[ALL_TENSORS], True + return None, False + + def is_yaml_file(file_path: Path) -> bool: """ Returns True if file_path points to a file ending in .yaml or .yml suffix, False otherwise. """ - return file_path.is_file() and file_path.suffix.lower() in ['.yaml', '.yml'] + return file_path.is_file() and file_path.suffix.lower() in [".yaml", ".yml"] class ConfigLevel(Enum): @@ -31,6 +75,7 @@ class ConfigLevel(Enum): - MODULE_TYPE: Applied to specific module types (e.g., all Conv2d) - GLOBAL: Applied to all modules """ + MODULE_NAME = auto() MODULE_TYPE = auto() GLOBAL = auto() diff --git a/src/coreai_opt/_utils/fx_utils.py b/src/coreai_opt/_utils/fx_utils.py new file mode 100644 index 0000000..8f16831 --- /dev/null +++ b/src/coreai_opt/_utils/fx_utils.py @@ -0,0 +1,168 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""Utilities for working with torch.fx graphs and nodes.""" + +import logging +import re + +import torch +from torch.fx import Node + +logger = logging.getLogger(__name__) + + +def get_node_type(node: Node, warn_on_failure: bool = True) -> str | None: + """Extract the op type string from an FX node's ``torch_fn`` metadata. + + The ``torch_fn`` metadata entry is a two-element tuple where the second + element encodes the ATen operator in *namespace.op_name* form. This + function returns the *op_name* part (after the dot). + + Args: + node (Node): An FX graph node. + warn_on_failure (bool): If True, log a warning if node type could not be found. + + Returns: + str | None: The op type string, or ``None`` if unavailable. 
+ """ + try: + _, torch_fn = node.meta.get("torch_fn") + return torch_fn.split(".")[1] + except (AttributeError, IndexError, TypeError, ValueError): + if warn_on_failure: + warning_msg = f"Unable to determine node type for node {node.name}. Skipping the node." + logger.warning(warning_msg) + return None + + +def normalize_module_fqn(path: str) -> str: + """Normalize module path from nn_module_stack to match named_modules format. + + Handles various torch.export contexts including decorators (@torch.no_grad, + @wraps), array indexing, and nested _modules['X'] patterns. + + Examples: + "model.layers.0.norm" -> "model.layers.0.norm" + "L['self'].model" -> "model" + "L['fn'].model" -> "model" + "L['args'][0].model.layers[0]" -> "model.layers.0" + "_modules['model']._modules['layers']._modules['0']" -> "model.layers.0" + """ + # Remove torch.export prefixes (self, fn, args[N]) + path = re.sub(r"^(?:L\['(?:self|fn)'\]\.|L\['args'\]\[\d+\]\.)", "", path) + + # Convert _modules['X'] and array indexing [N] to dot notation in one pass + path = re.sub( + r"_modules\['([^']+)'\]|\[(\d+)\]", lambda m: "." + (m.group(1) or m.group(2)), path + ) + + # Collapse multiple dots and strip leading/trailing dots + return re.sub(r"\.+", ".", path).strip(".") + + +def get_local_state_name(state: torch.fx.Node | str) -> str | None: + """Return the local state name (the last dotted component of the state identifier). + + For ``get_attr`` nodes the identifier is ``node.target``; for string inputs + (e.g., an ``OpInfo.op_name``) the string itself is used directly. + Returns ``None`` for ``call_function`` nodes (e.g., ``lut_to_dense``), + which represent already-compressed state and have no traditional parameter name. + + Args: + state (torch.fx.Node | str): An FX ``get_attr`` node, a ``call_function`` + state node, or a state name string. + + Returns: + str | None: The last dotted component, or ``None`` for call_function nodes. 
+ + Example: + >>> get_local_state_name("model.mod1.mod2.weight") + 'weight' + >>> get_local_state_name("model_weight") + 'model_weight' + """ + if isinstance(state, str): + return state.rsplit(".", 1)[-1] + if state.op != "get_attr": + # call_function nodes identified as state (e.g., lut_to_dense from palettization) + # don't have a traditional state name - they are already compressed + return None + return state.target.rsplit(".", 1)[-1] + + +def is_coreai_compressed_state_node(node: Node) -> bool: + """Return True if node represents model state (not a computation). + + A node is state if it is: + + 1. A ``get_attr`` node (model parameter or buffer access). + 2. A ``call_function`` node targeting a coreai state-producing op: + + - ``coreai.lut_to_dense``: palettized weight decompression. + - ``coreai.constexpr_blockwise_shift_scale``: block shift/scale on weights. + - ``coreai.sparse_to_dense``: sparse weight decompression (pruning). + - ``coreai.sparse_with_bitmask_to_dense``: bitmask-based sparse decompression (pruning). + + Note: + Update this function if new coreai ops are introduced that produce state tensors + from compressed representations, or if existing op names change. + + Args: + node (Node): An FX graph node. + + Returns: + bool: True if the node is a state node, False otherwise. + """ + if node.op == "get_attr": + return True + if node.op != "call_function": + return False + target = node.target + if not isinstance(target, torch._ops.OpOverload) or target.namespace != "coreai": + return False + return target._opname in ( + "lut_to_dense", + "constexpr_blockwise_shift_scale", + "sparse_to_dense", + "sparse_with_bitmask_to_dense", + ) + + +def get_module_boundary_nodes( + nodes_in_module: list[Node], +) -> tuple[list[tuple[Node, Node]], list[Node]]: + """Return the input and output boundary nodes for a set of nodes in a module. + + Args: + nodes_in_module (list[Node]): All FX nodes belonging to the module's subtree, + in topological order. 
+ + Returns: + tuple: A pair ``(input_consumer_tuples, output_nodes)`` where: + + - ``input_consumer_tuples``: ``(external_node, consumer_node)`` pairs in which + ``external_node`` is outside the module and ``consumer_node`` (inside the module) + consumes it. State nodes are excluded from ``external_node``. + - ``output_nodes``: Nodes inside the module that have at least one user outside it, + in topological order. + """ + input_consumer_tuples: list[tuple[Node, Node]] = [] + output_nodes: list[Node] = [] + nodes_in_module_set = set(nodes_in_module) + + for node in nodes_in_module: + for input_node in node.all_input_nodes: + if ( + not is_coreai_compressed_state_node(input_node) + and input_node not in nodes_in_module_set + ): + input_consumer_tuples.append((input_node, node)) + for user in node.users: + if user not in nodes_in_module_set: + output_nodes.append(node) + break + + return input_consumer_tuples, output_nodes diff --git a/src/coreai_opt/_utils/insertion/torch_function/modes.py b/src/coreai_opt/_utils/insertion/torch_function/modes.py index 034e5bb..a9ec7a5 100644 --- a/src/coreai_opt/_utils/insertion/torch_function/modes.py +++ b/src/coreai_opt/_utils/insertion/torch_function/modes.py @@ -11,7 +11,6 @@ import logging import re import types -from collections import defaultdict from collections.abc import Callable, Generator, Mapping from typing import Any, Literal, NamedTuple @@ -30,6 +29,7 @@ from .module_boundary_tracker import ModuleBoundaryInfo, ModuleBoundaryTracker from .preregistration_tracker import PreregistrationTracker from .registered_optimizers_tracker import RegisteredOptimizersTracker +from .state_spec_resolver import StateSpecResolver from .types import ( ModuleCompressionComponents, OpCompressionComponents, @@ -39,6 +39,7 @@ any_tensor_optimizable, get_func_base_name, get_func_name, + get_optimizer_from_components_dict, is_optimizable_tensor, normalize_args_kwargs, ) @@ -185,12 +186,11 @@ def __init__( 
optimization_type_name=optimization_type_name, ) self.traversed_modules: set[nn.Module] = set() - self.states_to_register: dict[ - torch.nn.Parameter | torch.Tensor, CompressionSimulatorBase | None - ] = {} - self.states_to_names = self._get_states_to_names() - self.states_to_modules = self._get_states_to_modules() - self.module_priority_dict = module_priority_dict + self.state_resolver = StateSpecResolver( + model=model, + module_components_dict=module_components_dict, + module_priority_dict=module_priority_dict, + ) self.module_boundary_tracker = ModuleBoundaryTracker() self.preregistration_tracker = PreregistrationTracker() # Set default value to True for @@ -246,52 +246,6 @@ def _fill_module_has_module_spec_dict( module_has_module_spec_dict, ) - def _get_states_to_names(self) -> Mapping[torch.nn.Parameter | torch.Tensor, list[str]]: - """ - Get a dictionary mapping states to all aliases of the state. A state can have - multiple aliases if it is shared by multiple modules in the model. - """ - states_to_names: defaultdict[ - torch.nn.Parameter | torch.Tensor, list[str] - ] = defaultdict(list) - for n, s in itertools.chain( - self.model.named_parameters(remove_duplicate=False), - self.model.named_buffers(remove_duplicate=False) - ): - states_to_names[s].append(n) - return states_to_names - - def _in_states_to_names(self, tensor_or_value: Any) -> bool: - """ - Helper function to check whether tensor_or_value is in states_to_names by first checking - whether tensor_or_value is a tensor. - """ - if not isinstance(tensor_or_value, torch.Tensor): - return False - return tensor_or_value in self.states_to_names - - def _get_states_to_modules( - self, - ) -> Mapping[torch.nn.Parameter | torch.Tensor, list[NamedModule]]: - """ - Get a dictionary mapping state tensors to the modules they are defined in. - - Returns a Mapping where each state tensor maps to a list of (module_name, module) - tuples. A state can belong to multiple modules if it is shared. 
- """ - states_to_modules: dict[torch.nn.Parameter | torch.Tensor, list[NamedModule]] = {} - - for module_name, module in self.named_modules.items(): - # Check parameters and buffers defined directly in this module (recurse=False) - for state in itertools.chain( - module.parameters(recurse=False), module.buffers(recurse=False) - ): - if state not in states_to_modules: - states_to_modules[state] = [] - states_to_modules[state].append(NamedModule(module_name, module)) - - return states_to_modules - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any: self.remove_hooks() return super().__exit__(exc_type, exc_val, exc_tb) @@ -513,7 +467,7 @@ def _get_module_component_override_if_applicable( # Check for module-level spec override for the corresponding input/output # components dict - override_optimizer, found = self._get_optimizer_from_components_dict( + override_optimizer, found = get_optimizer_from_components_dict( func, [module_boundary.index for module_boundary in module_boundaries], components_dict @@ -575,7 +529,7 @@ def _split_module_boundary_info_list_by_module( def register_all_states(self) -> None: """ - Parametrize all states which appear in self.state_to_register. + Parametrize every state for which self.state_resolver has a cached optimizer. """ states_to_parametrize: list[StateParametrizationInfo] = [] # Go through all modules to identify modules to parametrize. 
We cannot @@ -586,7 +540,8 @@ def register_all_states(self) -> None: module.named_parameters(recurse=False, remove_duplicate=False), module.named_buffers(recurse=False, remove_duplicate=False) ): - if optimizer := self.states_to_register.get(state): + optimizer = self.state_resolver.get_optimizer(state) + if optimizer: states_to_parametrize.append(StateParametrizationInfo( module, state_name, @@ -623,7 +578,7 @@ def _preregister_input_and_state_optimization( ): components_dict = ( op_compression_components.op_state_components - if self._in_states_to_names(tensor) + if self.state_resolver.is_state_tensor(tensor) else op_compression_components.op_input_components ) if not is_optimizable_tensor(tensor): @@ -632,8 +587,13 @@ def _preregister_input_and_state_optimization( ) continue - if self._in_states_to_names(tensor): - self._create_state_optimizer(func, tensor, components_dict) + if self.state_resolver.is_state_tensor(tensor): + self.state_resolver.resolve( + func, + tensor, + NamedModule(self.current_module_name, self.current_module), + components_dict, + ) else: pending_registration = self._create_pending_activation_registration( func=func, @@ -663,9 +623,9 @@ def _warn_non_quantizable_tensor_setting( Only specific matches will raise a warning. If the user has set "*" to quantize all inputs/outputs/states, no warning will be raised. """ - if self._in_states_to_names(tensor): + if self.state_resolver.is_state_tensor(tensor): setting_type = "state" - tensor_identifiers = self._get_local_state_tensor_names(tensor) + tensor_identifiers = self.state_resolver.get_all_local_names(tensor) elif is_output: setting_type = "output" tensor_identifiers = [idx] @@ -692,105 +652,6 @@ def _warn_non_quantizable_tensor_setting( # looking for other identifiers for the same tensor. return - def _get_local_state_tensor_names( - self, - state_tensor: torch.nn.Parameter | torch.Tensor - ) -> list[str]: - """ - Get the local names associated with a state tensor. 
- """ - fqns = self.states_to_names[state_tensor] - return [fqn.rsplit(".")[-1] for fqn in fqns] - - def _create_state_optimizer( - self, - func: Callable, - state_tensor: torch.Tensor, - components_dict: Mapping[int | str, _PartialConstructor | None], - ) -> None: - """ - Create optimizer for a state of a function and store it in - self.states_to_register for future registration. - - Check module_state_spec before op_state_spec by looking up which module - defines this state tensor and seeing if it has module_state_components configured. - """ - state_names = self._get_local_state_tensor_names(state_tensor) - - # If state_tensor is already processed (can occur if state_tensor is shared), skip further - # processing. This guarantees the same state tensor to be quantized consistently across the - # entire model. - if state_tensor in self.states_to_register: - return - - # Check for module-level state spec override - modules_for_state = self.states_to_modules.get(state_tensor, []) - # Sort the list of modules to check by priority based on how the original config was - # defined. - sorted_modules_for_state = sorted( - modules_for_state, - key=lambda module_name_and_module: self.module_priority_dict.get( - module_name_and_module.name, float("inf") - ), - ) - - final_optimizer = None - found_component = False - - # Go through each of the sorted modules and check if there is an applicable - # module_state_component matching the current state tensor. 
Use the first match as the final - # optimizer and stop further processing (the optimizer may be None if the component is None) - for module_name_and_module in sorted_modules_for_state: - module_components = self.module_components_dict.get(module_name_and_module) - - if module_components and module_components.module_state_components: - override_optimizer, found_component = self._get_optimizer_from_components_dict( - func, state_names, module_components.module_state_components - ) - - if found_component: - final_optimizer = override_optimizer - break - - # If no module_state_component was found, fall back to op_state_component - if not found_component: - final_optimizer, _ = self._get_optimizer_from_components_dict( - func, state_names, components_dict - ) - - self.states_to_register[state_tensor] = final_optimizer - - @staticmethod - def _get_optimizer_from_components_dict( - func: Callable, - tensor_identifiers: int | str | list[int | str], - components_dict: Mapping[int | str, _PartialConstructor | None], - ) -> tuple[CompressionSimulatorBase | None, bool]: - """ - Return the appropriate optimizer from components dict. - - The tensor identifier(s) will either be an integer index or a string kwarg - obtained from parsing the args and kwargs of the function. We attempt to match - the identifier with a config entry in components_dict. A match could - either be a direct match with the index or string name, or if the config - was configured with "*" to match all tensors. - - If no direct match or "*" was found, this means the tensor is not meant to be - optimized in which case None is returned. - - Return None if no optimizer was found, or if the components entry is None - to begin with. 
- """ - if not isinstance(tensor_identifiers, list): - tensor_identifiers = [tensor_identifiers] - - # Try direct matches first, then fall back to wildcard - for identifier in (*tensor_identifiers, "*"): - if identifier in components_dict: - constructor = components_dict[identifier] - return (constructor(op_to_optimize=func), True) if constructor else (None, True) - return (None, False) - def _get_op_compression_components( self, func_name: str, func_type: str, module_component: ModuleCompressionComponents ) -> OpCompressionComponents: @@ -864,9 +725,7 @@ def _create_pending_activation_registration( PendingOptimizerRegistration """ # Get optimizer from components - act_optimizer, _ = self._get_optimizer_from_components_dict( - func, tensor_idx, components_dict - ) + act_optimizer, _ = get_optimizer_from_components_dict(func, tensor_idx, components_dict) # Build optimizer name based on activation type func_name = get_func_name(func, func_counter) diff --git a/src/coreai_opt/_utils/insertion/torch_function/module_boundary_tracker.py b/src/coreai_opt/_utils/insertion/torch_function/module_boundary_tracker.py index 73a5555..222aee3 100644 --- a/src/coreai_opt/_utils/insertion/torch_function/module_boundary_tracker.py +++ b/src/coreai_opt/_utils/insertion/torch_function/module_boundary_tracker.py @@ -154,9 +154,7 @@ def record_module_boundary_tensors( ) def get_module_boundaries_for_tensor( - self, - tensor_counter: int, - boundary_type: Literal["input", "output"] + self, tensor_counter: int, boundary_type: Literal["input", "output"] ) -> list[ModuleBoundaryInfo]: """ For a given tensor identified by its counter, return a list of diff --git a/src/coreai_opt/_utils/insertion/torch_function/registered_optimizers_tracker.py b/src/coreai_opt/_utils/insertion/torch_function/registered_optimizers_tracker.py index effdd3a..f9a3b58 100644 --- a/src/coreai_opt/_utils/insertion/torch_function/registered_optimizers_tracker.py +++ 
b/src/coreai_opt/_utils/insertion/torch_function/registered_optimizers_tracker.py @@ -24,6 +24,7 @@ class RegisteredOptimizersTracker: This is used during the optimization phase to validate that the same optimizers are applied in the same order as during registration. """ + def __init__(self) -> None: """Initialize an empty registered optimizers tracker.""" self._registered: RegisteredOptimizersDict = {} diff --git a/src/coreai_opt/_utils/insertion/torch_function/state_spec_resolver.py b/src/coreai_opt/_utils/insertion/torch_function/state_spec_resolver.py new file mode 100644 index 0000000..c145642 --- /dev/null +++ b/src/coreai_opt/_utils/insertion/torch_function/state_spec_resolver.py @@ -0,0 +1,212 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + + +"""Priority-aware resolver for state-tensor compression optimizers.""" + +import itertools +from collections.abc import Callable, Mapping +from typing import Any, ClassVar, NamedTuple + +import torch +import torch.nn as nn + +from coreai_opt._utils.spec_utils import PartialConstructor as _PartialConstructor +from coreai_opt._utils.torch_utils import NamedModule +from coreai_opt.config.spec import CompressionSimulatorBase + +from .types import ModuleCompressionComponents +from .utils import get_optimizer_from_components_dict + + +class _StateInventoryEntry(NamedTuple): + """Inventory entry for one state tensor. + + Attributes: + owners (list[NamedModule]): Modules that own this state tensor (one + entry per owning module, even if the module aliases the tensor + under multiple attribute names). + local_names (list[str]): All local attribute names any owning module + uses for this tensor. Aliases under different attributes on the + same module each contribute an entry. 
+ """ + + owners: list[NamedModule] + local_names: list[str] + + +class StateSpecResolver: + """Resolve and cache compression-simulator optimizers for state tensors. + + Owns the model's state inventory and the priority-aware optimizer cache. + + Responsibilities: + + - Identifying which tensors are model states and what local name(s) any + owning module uses for them. + - Resolving the optimizer for a given (state, current call site) pair, + honoring ``module_state_spec`` > ``op_state_spec`` precedence and the + module priority assigned by the quantizer. + - Caching the resolved optimizer with a priority annotation so subsequent + visits from lower-priority modules cannot overwrite a higher-priority + decision. + + The cache invariant: for each state tensor, ``_optimizer_cache`` holds the + optimizer chosen by the highest-priority module that has visited the + tensor so far. ``module_state_spec`` matches are cached at + ``_MODULE_STATE_PRIORITY`` (a sentinel below all ``op_state_spec`` + priorities) so they cannot be overridden by ``op_state_spec`` from any + module. + + Lookups for both ``op_state_spec`` and ``module_state_spec`` consider all + local names the state has across all owning modules. This matches the + pre-refactor behavior of ``_create_state_optimizer``. + """ + + _MODULE_STATE_PRIORITY: ClassVar[int] = -1 + + def __init__( + self, + model: nn.Module, + module_components_dict: Mapping[NamedModule, ModuleCompressionComponents], + module_priority_dict: Mapping[str, int], + ) -> None: + self._module_components_dict = module_components_dict + self._module_priority_dict = module_priority_dict + self._state_inventory = self._build_state_inventory(model) + self._optimizer_cache: dict[torch.Tensor, tuple[CompressionSimulatorBase | None, int]] = {} + + @staticmethod + def _build_state_inventory( + model: nn.Module, + ) -> dict[torch.Tensor, _StateInventoryEntry]: + """Build a map from state tensor to its inventory entry. 
+ + For each tensor reachable as a parameter or buffer of any module, + record the owning modules and every local attribute name they use + for it. + """ + inventory: dict[torch.Tensor, _StateInventoryEntry] = {} + for module_name, module in model.named_modules(): + named_module = NamedModule(module_name, module) + for state_name, state in itertools.chain( + module.named_parameters(recurse=False, remove_duplicate=False), + module.named_buffers(recurse=False, remove_duplicate=False), + ): + if state not in inventory: + inventory[state] = _StateInventoryEntry(owners=[], local_names=[]) + entry = inventory[state] + # Dedupe on insertion to maintain a stable traversal order for + # get_all_local_names. Shared state lists are typically short so + # the O(n) dedup cost is acceptable. + if named_module not in entry.owners: + entry.owners.append(named_module) + if state_name not in entry.local_names: + entry.local_names.append(state_name) + return inventory + + def is_state_tensor(self, value: Any) -> bool: + """Return True iff ``value`` is a Tensor reachable as a parameter or buffer of the model.""" + if not isinstance(value, torch.Tensor): + return False + return value in self._state_inventory + + def get_all_local_names(self, state: torch.Tensor) -> list[str]: + """Return every local attribute name any module uses for ``state``. + + Used both for surfacing warnings when configuration mentions an + apparent state name that the resolver cannot match, and as the + identifier list for ``op_state_spec`` and ``module_state_spec`` + lookups. If the state is not in the inventory, returns an empty list. + """ + entry = self._state_inventory.get(state) + return list(entry.local_names) if entry else [] + + def resolve( + self, + func: Callable, + state_tensor: torch.Tensor, + current_module: NamedModule, + components_dict: Mapping[int | str, _PartialConstructor | None], + ) -> None: + """Resolve and cache the optimizer for ``state_tensor`` at the current call site. 
+ + Algorithm: + + 1. Get the current module's priority. + 2. Skip if a strictly higher-priority spec is already cached for this + tensor (strict ``>`` so equal priorities allow last-writer-wins, + matching the pre-refactor unsorted loop behavior). + 3. On the first visit to ``state_tensor``, walk all owners in priority + order looking for a ``module_state_spec`` match. First match wins + and is cached at ``_MODULE_STATE_PRIORITY`` so it cannot be + overridden by any subsequent ``op_state_spec`` resolution. + We check whether ``state_tensor`` is present in + ``self._optimizer_cache`` to tell whether this is the first visit. + 4. If no ``module_state_spec`` matched (or on a later visit), look up the + ``op_state_spec`` using **all** local names the state has across all owning + modules, then cache the result at the current module's priority. + """ + current_priority = self._module_priority_dict.get(current_module.name, float("inf")) + if self._should_skip(state_tensor, current_priority): + return + + # This block runs only once, the first time the state is encountered. + if state_tensor not in self._optimizer_cache: + optimizer, found = self._resolve_module_state(func, state_tensor) + if found: + self._optimizer_cache[state_tensor] = (optimizer, self._MODULE_STATE_PRIORITY) + return + + optimizer = self._resolve_op_state(func, state_tensor, components_dict) + self._optimizer_cache[state_tensor] = (optimizer, current_priority) + + def get_optimizer(self, state: torch.Tensor) -> CompressionSimulatorBase | None: + """Return the cached optimizer for ``state``, or None if absent or resolved to None.""" + optimizer, _ = self._optimizer_cache.get(state, (None, None)) + return optimizer + + def _should_skip(self, state_tensor: torch.Tensor, current_priority: int) -> bool: + """Return True if a strictly higher-priority result is already cached for + ``state_tensor``. 
+ """ + _, cached_priority = self._optimizer_cache.get(state_tensor, (None, None)) + return cached_priority is not None and current_priority > cached_priority + + def _resolve_module_state( + self, func: Callable, state_tensor: torch.Tensor + ) -> tuple[CompressionSimulatorBase | None, bool]: + """Walk owner modules in priority order, returning the first ``module_state_spec`` match + and a found flag. + """ + local_state_names = self.get_all_local_names(state_tensor) + for named_module in self._priority_sorted_owner_modules(state_tensor): + module_components = self._module_components_dict.get(named_module) + if not module_components or not module_components.module_state_components: + continue + optimizer, found = get_optimizer_from_components_dict( + func, local_state_names, module_components.module_state_components + ) + if found: + return optimizer, True + return None, False + + def _resolve_op_state( + self, + func: Callable, + state_tensor: torch.Tensor, + components_dict: Mapping[int | str, _PartialConstructor | None], + ) -> CompressionSimulatorBase | None: + """Look up the ``op_state_spec`` optimizer using all local names for ``state_tensor``.""" + local_state_names = self.get_all_local_names(state_tensor) + optimizer, _ = get_optimizer_from_components_dict(func, local_state_names, components_dict) + return optimizer + + def _priority_sorted_owner_modules(self, state_tensor: torch.Tensor) -> list[NamedModule]: + """Return the owners of ``state_tensor`` sorted by ascending priority""" + return sorted( + self._state_inventory[state_tensor].owners, + key=lambda nm: self._module_priority_dict.get(nm.name, float("inf")), + ) diff --git a/src/coreai_opt/_utils/insertion/torch_function/utils.py b/src/coreai_opt/_utils/insertion/torch_function/utils.py index af7435d..3111822 100644 --- a/src/coreai_opt/_utils/insertion/torch_function/utils.py +++ b/src/coreai_opt/_utils/insertion/torch_function/utils.py @@ -7,7 +7,7 @@ import itertools import logging -from 
collections.abc import Callable +from collections.abc import Callable, Mapping from typing import Any import torch @@ -15,6 +15,10 @@ from torch.fx.node import map_aggregate from torch.fx.operator_schemas import create_type_hint, normalize_function +from coreai_opt._utils.config_utils import get_last_matching_spec +from coreai_opt._utils.spec_utils import PartialConstructor as _PartialConstructor +from coreai_opt.config.spec import CompressionSimulatorBase + logger = logging.getLogger(__name__) _OPTIMIZABLE_DTYPES = {torch.float64, torch.float32, torch.float16, torch.bfloat16} @@ -73,6 +77,7 @@ def normalize_args_kwargs( kwargs["self"] = kwargs.pop("input") return args, kwargs + def get_func_base_name(func: Callable) -> str: """ Return the function base name @@ -85,6 +90,7 @@ def get_func_base_name(func: Callable) -> str: return func.__name__ return func_name.rsplit(".", maxsplit=1)[-1] + def get_func_name(func: Callable, func_count: int) -> str: """Return the function name using the base name and func_count. @@ -96,6 +102,7 @@ def get_func_name(func: Callable, func_count: int) -> str: return get_func_base_name(func) return f"{get_func_base_name(func)}_{func_count}" + def is_optimizable_tensor(tensor: Any) -> bool: """ Return True if the tensor is optimizable, False otherwise. @@ -118,3 +125,23 @@ def any_tensor_optimizable(args: list[Any], kwargs: dict[str, Any]) -> bool: Return True if any of the tensors in args and kwargs are optimizable, False otherwise. """ return any(is_optimizable_tensor(tensor) for tensor in itertools.chain(args, kwargs.values())) + + +def get_optimizer_from_components_dict( + func: Callable, + tensor_identifiers: int | str | list[int | str], + components_dict: Mapping[int | str, _PartialConstructor | None], +) -> tuple[CompressionSimulatorBase | None, bool]: + """Return the appropriate optimizer from ``components_dict``. + + Delegates matching to :func:`~coreai_opt._utils.config_utils.get_last_matching_spec`. 
+ Returns ``(optimizer, True)`` on match (optimizer may be ``None`` if the + components-dict entry was explicit-None to disable). Returns + ``(None, False)`` if no identifier matched. + """ + if not isinstance(tensor_identifiers, list): + tensor_identifiers = [tensor_identifiers] + constructor, found = get_last_matching_spec(tensor_identifiers, components_dict) + if found: + return (constructor(op_to_optimize=func), True) if constructor else (None, True) + return None, False diff --git a/src/coreai_opt/_utils/metadata_utils.py b/src/coreai_opt/_utils/metadata_utils.py index b8e3d11..2563920 100644 --- a/src/coreai_opt/_utils/metadata_utils.py +++ b/src/coreai_opt/_utils/metadata_utils.py @@ -95,7 +95,8 @@ def register( # Get dict representation, excluding None values metadata_dict: dict[str, Any] = self.model_dump( - exclude_none=True, exclude={"param_name"}, + exclude_none=True, + exclude={"param_name"}, ) for metadata_key, value in metadata_dict.items(): @@ -129,5 +130,6 @@ def register_version(cls, model: torch.nn.Module) -> None: """ model.register_buffer( - METADATA_VERSION_BUFFER, torch.tensor(METADATA_VERSION_VALUE), + METADATA_VERSION_BUFFER, + torch.tensor(METADATA_VERSION_VALUE), ) diff --git a/src/coreai_opt/_utils/registry_utils.py b/src/coreai_opt/_utils/registry_utils.py index 6121daf..675bac6 100644 --- a/src/coreai_opt/_utils/registry_utils.py +++ b/src/coreai_opt/_utils/registry_utils.py @@ -71,6 +71,39 @@ class ClassRegistryMixin(RegistryMixin): def get_class(cls, key: Any) -> Any: return cls._get_object(key) + @classmethod + def resolve(cls, data: str | type) -> type: + """Resolve a string key or class type against this registry. + + Args: + data (str | type): Either a string key registered in the registry, or a + class type. If a class is given, it is returned unchanged provided + it is one of the registered values. + + Returns: + type: The registered class corresponding to ``data``. 
+ + Raises: + ValueError: If ``data`` is a string not registered as a key, or a class + that is not registered as a value. + """ + if isinstance(data, str): + if data in cls.REGISTRY: + return cls.get_class(data) # type: ignore[no-any-return] + raise ValueError( + f"No class is registered with key: '{data}' " + f"in registry {cls.__name__}. " + f"Available keys: {sorted(cls.list_registry_keys())}" + ) + if data in cls.list_registry_values(): + return data + name = getattr(data, "__name__", data) + available_classes = sorted(c.__name__ for c in cls.list_registry_values()) + raise ValueError( + f"{name} is not a registered class in {cls.__name__}. " + f"Available classes: {available_classes}" + ) + class FunctionRegistryMixin(RegistryMixin): @classmethod diff --git a/src/coreai_opt/_utils/torch_utils.py b/src/coreai_opt/_utils/torch_utils.py index b632817..915891c 100644 --- a/src/coreai_opt/_utils/torch_utils.py +++ b/src/coreai_opt/_utils/torch_utils.py @@ -16,7 +16,6 @@ import torch import torch.nn.utils.parametrize as P -from torch.fx import Node from coreai_opt._utils.version_utils import version_ge as _version_ge @@ -422,55 +421,6 @@ def export_model( ) from e -def get_node_type(node: Node, warn_on_failure: bool = True) -> str | None: - """Extract the op type string from an FX node's ``torch_fn`` metadata. - - The ``torch_fn`` metadata entry is a two-element tuple where the second - element encodes the ATen operator in *namespace.op_name* form. This - function returns the *op_name* part (after the dot). - - Args: - node (Node): An FX graph node. - warn_on_failure (bool): If True, log a warning if node type could not be found. - - Returns: - str | None: The op type string, or ``None`` if unavailable. - """ - try: - _, torch_fn = node.meta.get("torch_fn") - return torch_fn.split(".")[1] - except (AttributeError, IndexError, TypeError, ValueError): - if warn_on_failure: - warning_msg = f"Unable to determine node type for node {node.name}. Skipping the node." 
- logger.warning(warning_msg) - return None - - -def normalize_module_fqn(path: str) -> str: - """Normalize module path from nn_module_stack to match named_modules format. - - Handles various torch.export contexts including decorators (@torch.no_grad, - @wraps), array indexing, and nested _modules['X'] patterns. - - Examples: - "model.layers.0.norm" -> "model.layers.0.norm" - "L['self'].model" -> "model" - "L['fn'].model" -> "model" - "L['args'][0].model.layers[0]" -> "model.layers.0" - "_modules['model']._modules['layers']._modules['0']" -> "model.layers.0" - """ - # Remove torch.export prefixes (self, fn, args[N]) - path = re.sub(r"^(?:L\['(?:self|fn)'\]\.|L\['args'\]\[\d+\]\.)", "", path) - - # Convert _modules['X'] and array indexing [N] to dot notation in one pass - path = re.sub( - r"_modules\['([^']+)'\]|\[(\d+)\]", lambda m: "." + (m.group(1) or m.group(2)), path - ) - - # Collapse multiple dots and strip leading/trailing dots - return re.sub(r"\.+", ".", path).strip(".") - - def mmap_module_state_dict(module: torch.nn.Module, path: str | PathLike[str]) -> None: """Serialize ``module.state_dict()`` to a safetensors file at ``path`` and reload it via mmap, replacing the module's parameters/buffers with mmap diff --git a/src/coreai_opt/config/__init__.py b/src/coreai_opt/config/__init__.py index af84fc7..92c88c4 100644 --- a/src/coreai_opt/config/__init__.py +++ b/src/coreai_opt/config/__init__.py @@ -4,6 +4,7 @@ # be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause """Configuration and specification modules for coreai_opt.""" + from .compression_config import ( CompressionConfig, ModuleCompressionConfig, diff --git a/src/coreai_opt/coreai_utils/_utils/sparse_utils.py b/src/coreai_opt/coreai_utils/_utils/sparse_utils.py index 8bb3702..d47278b 100644 --- a/src/coreai_opt/coreai_utils/_utils/sparse_utils.py +++ b/src/coreai_opt/coreai_utils/_utils/sparse_utils.py @@ -5,7 +5,6 @@ """Sparsification utilities for Core AI Optimization 
passes.""" -# TODO: add test enhancements for sparse utils. from __future__ import annotations import logging diff --git a/src/coreai_opt/inspection/__init__.py b/src/coreai_opt/inspection/__init__.py index 90d1364..e345d64 100644 --- a/src/coreai_opt/inspection/__init__.py +++ b/src/coreai_opt/inspection/__init__.py @@ -21,9 +21,19 @@ """ from .model_inspector import ModelInspector -from .types import ModelSummary, ModuleContext, ModuleInfo, OpInfo, SourceFrame +from .types import ( + BoundaryEdge, + InputEdge, + ModelSummary, + ModuleContext, + ModuleInfo, + OpInfo, + SourceFrame, +) __all__ = [ + "BoundaryEdge", + "InputEdge", "ModelInspector", "ModelSummary", "ModuleContext", diff --git a/src/coreai_opt/inspection/_common.py b/src/coreai_opt/inspection/_common.py new file mode 100644 index 0000000..ec4dfba --- /dev/null +++ b/src/coreai_opt/inspection/_common.py @@ -0,0 +1,97 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + + +"""Shared utilities for building module trees from discovered ops.""" + +from .types import ( + ModuleInfo, + OpInfo, +) + +FORWARD_FUNCTION_NAME = "forward" + + +def _get_or_create_child(parent: ModuleInfo, module_name: str, module_type: str) -> ModuleInfo: + """Get an existing child module or create a new one.""" + if module_name not in parent.child_modules: + parent.child_modules[module_name] = ModuleInfo( + module_name=module_name, + module_type=module_type, + child_modules={}, + ops=[], + input_ops={}, + output_ops={}, + ) + return parent.child_modules[module_name] + + +def build_module_tree( + root_module_type: str, + all_ops: list[OpInfo], +) -> ModuleInfo: + """Build a ModuleInfo tree from ops with populated module stacks. + + Each op is attached to the deepest module named in its ``module_stack``. 
+ Ops whose ``module_stack`` consists entirely of root entries (or whose + stack is empty after skipping root entries) are attached to the root. + + Args: + root_module_type (str): Fully-qualified type name of the root module. + all_ops (list[OpInfo]): All discovered ops. State ops and ops with an empty + ``module_stack`` are excluded from the tree. + + Returns: + ModuleInfo: The root of the module tree. + """ + root = ModuleInfo( + module_name="", + module_type=root_module_type, + child_modules={}, + ops=[], + input_ops={}, + output_ops={}, + ) + + for op in all_ops: + if op.is_state: + continue + if not op.module_stack: + continue + current = root + for ctx in op.module_stack: + if ctx.module_name == "": + continue + current = _get_or_create_child(current, ctx.module_name, ctx.module_type) + current.ops.append(op) + + return root + + +def filter_module_tree(module: ModuleInfo, keep_op_names: set[str]) -> ModuleInfo: + """Recursively filter a ModuleInfo tree, keeping only matching ops. + + Boundary info (input_ops/output_ops) is preserved from the unfiltered tree. + + Args: + module (ModuleInfo): The module tree to filter. + keep_op_names (set[str]): Op names to retain in the tree. + + Returns: + ModuleInfo: A filtered copy with only matching ops. + """ + filtered_children = { + fqn: filter_module_tree(child, keep_op_names) for fqn, child in module.child_modules.items() + } + filtered_ops = [op for op in module.ops if op.op_name in keep_op_names] + + return ModuleInfo( + module_name=module.module_name, + module_type=module.module_type, + child_modules=filtered_children, + ops=filtered_ops, + input_ops=module.input_ops, + output_ops=module.output_ops, + ) diff --git a/src/coreai_opt/inspection/_eager_mode.py b/src/coreai_opt/inspection/_eager_mode.py new file mode 100644 index 0000000..2e1c30d --- /dev/null +++ b/src/coreai_opt/inspection/_eager_mode.py @@ -0,0 +1,576 @@ +# Copyright 2026 Apple Inc. 
+# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + + +"""Eager mode op discovery implementation. + +Runs a forward pass with ``TorchFunctionMode`` interception to discover +operations in an ``nn.Module`` without exporting to a graph. +""" + +import itertools +import linecache +import sys +import weakref +from collections import defaultdict +from collections.abc import Callable +from typing import Any + +import torch +import torch.nn as nn +from torch.overrides import TorchFunctionMode + +from coreai_opt._utils.insertion.torch_function.modes import _is_interceptable_func +from coreai_opt._utils.insertion.torch_function.module_boundary_tracker import ( + TensorIdVersion, +) +from coreai_opt._utils.insertion.torch_function.utils import ( + get_func_base_name, + get_func_name, +) +from coreai_opt._utils.python_utils import fqn as _fqn +from coreai_opt._utils.torch_utils import NamedModule, flatten_tensors_to_list +from coreai_opt.base_model_compressor import _BaseModelCompressor +from coreai_opt.palettization import KMeansPalettizer +from coreai_opt.quantization import Quantizer +from coreai_opt.quantization._eager import EagerQuantizer +from coreai_opt.quantization.config.quantization_config import ExecutionMode + +from ._common import ( + FORWARD_FUNCTION_NAME, + build_module_tree, + filter_module_tree, +) +from .types import ( + BoundaryEdge, + InputEdge, + ModelSummary, + ModuleContext, + ModuleInfo, + OpInfo, + SourceFrame, +) + + +class _EagerOpDiscoveryMode(TorchFunctionMode): + """TorchFunctionMode that discovers ops during a forward pass. + + Args: + model (nn.Module): The model to inspect. + op_type_resolver (Callable[[Callable], str | None] | None): + When provided, maps a torch function to its compressor-defined + op type. A non-None return marks the op as compressor-supported + (for post-hoc filtering) and uses the returned string as op_type. 
+ When None, all ops use base_name as op_type and no filtering + metadata is collected. + """ + + def __init__( + self, + model: nn.Module, + op_type_resolver: Callable[[Callable], str | None] | None = None, + ) -> None: + super().__init__() + self.model = model + self._op_type_resolver = op_type_resolver + self.parents: list[NamedModule] = [] + self.traversed_modules: set[nn.Module] = set() + self.hooks: list[torch.utils.hooks.RemovableHook] = [] + + # Per-module function call counts: module_name → base_name → count + self.func_counts: defaultdict[str, defaultdict[str, int]] = defaultdict( + lambda: defaultdict(int) + ) + + # Discovered ops + self.all_ops: list[OpInfo] = [] + self.supported_op_names: set[str] = set() + self._seen_op_names: set[str] = set() + + # Tensor connectivity: (id, version) → InputEdge (producing op + output slot). + # Tracks activation tensors only; state tensors are tracked separately below. + self._tensor_producers: dict[TensorIdVersion, InputEdge] = {} + + self._states_to_names = { + id(state): name + for name, state in itertools.chain( + self.model.named_parameters(), + self.model.named_buffers(), + ) + } + # State tensors are looked up purely by object id — no version tracking. Version + # is irrelevant for states: they are not produced by ops and their identity is + # stable for the model's lifetime regardless of in-place mutations. + self._state_op_infos: dict[int, OpInfo] = {} + + # Ephemeral OpInfos for tensors that are neither registered states nor produced + # by any intercepted op (e.g. raw tensor attributes, global tensors). These are + # created on demand in _resolve_inputs and keyed by id(tensor). They are NOT added + # to all_ops — they exist only as input references so they appear in the consuming + # op's inputs tuple and therefore in the formatted op inputs display. 
+ self._ephemeral_op_infos: dict[int, OpInfo] = {} + self._ephemeral_counter: int = 0 + + self._module_input_producers: dict[str, list[InputEdge | None]] = {} + self._module_output_producers: dict[str, list[InputEdge | None]] = {} + + # Capture model-level input tensors BEFORE the module-loop hooks so that + # ``_capture_input_tensors`` fires before ``_enter_module("")``, ensuring + # root input tensors are registered in ``_tensor_producers`` by the time + # ``_enter_module("")`` resolves their producers. + self.hooks.append(model.register_forward_pre_hook(self._capture_input_tensors)) + + for name, module in model.named_modules(remove_duplicate=True): + pre_hook = module.register_forward_pre_hook(self._enter_module(name)) + post_hook = module.register_forward_hook(self._exit_module(name), always_call=True) + self.hooks.append(pre_hook) + self.hooks.append(post_hook) + + # Registered after the module-loop hooks so that ``_exit_module("")`` + # fires before output capture. + self.hooks.append( + model.register_forward_hook(self._capture_output_tensors, always_call=True) + ) + + @property + def current_module_name(self) -> str: + return self.parents[-1].name if self.parents else "" + + @property + def current_module(self) -> nn.Module: + return self.parents[-1].module if self.parents else self.model + + def _add_op(self, op_info: OpInfo) -> None: + assert op_info.op_name not in self._seen_op_names, f"duplicate op_name {op_info.op_name}" + self._seen_op_names.add(op_info.op_name) + self.all_ops.append(op_info) + + def _get_or_create_ephemeral(self, tensor_id: int) -> InputEdge: + """Return the ephemeral InputEdge for this tensor id, creating one if needed.""" + op_info = self._ephemeral_op_infos.get(tensor_id) + if op_info is None: + op_info = OpInfo( + op_name=f"untracked_{self._ephemeral_counter}", + op_type=None, + module_stack=(), + source_frames=(), + inputs=(), + outputs={}, + is_state=False, + ) + self._ephemeral_op_infos[tensor_id] = op_info + 
self._ephemeral_counter += 1 + return InputEdge(op=op_info, output_idx=None) + + def _resolve_boundary_tensor(self, t: torch.Tensor) -> InputEdge | None: + """Resolve a module-boundary tensor to its producer entry. + + Mirrors the three-way lookup in ``_resolve_inputs`` so module boundary + detection is consistent with op-level input detection: + + - Known activation: returns the existing ``InputEdge``. + - Registered state: returns ``None`` (states are filtered in + ``_populate_boundary_ops_eager`` and handled via ``module_state_spec``). + - Untracked tensor (raw attribute, global): creates or reuses an ephemeral + ``OpInfo`` so the tensor appears in ``module inputs`` for the consuming module. + """ + key = TensorIdVersion(id(t), t._version) + entry = self._tensor_producers.get(key) + if entry is not None: + return entry + if id(t) in self._states_to_names: + return None + return self._get_or_create_ephemeral(id(t)) + + def _enter_module(self, name: str) -> Callable: + def hook(module: nn.Module, inputs: Any) -> None: + self.parents.append(NamedModule(name, module)) + if module not in self.traversed_modules: + # First visit only — mirrors the quantizer's enter_module guard. + # Resolve each input tensor now while it is fresh, using the same + # three-way lookup as _resolve_inputs so that untracked tensors + # (e.g. raw attributes passed as forward arguments) appear in + # module inputs rather than being silently dropped. + self._module_input_producers[name] = [ + self._resolve_boundary_tensor(t) for t in flatten_tensors_to_list(inputs) + ] + + return hook + + def _exit_module(self, name: str) -> Callable: + def hook(module: nn.Module, inputs: Any, outputs: Any) -> None: + assert self.parents[-1].name == name + if module not in self.traversed_modules: + # First visit: store the full InputEdge (producing op + output_idx) + # so that _populate_boundary_ops_eager can identify which output slot of + # each producing op corresponds to each module output tensor. 
+ self._module_output_producers[name] = [ + self._tensor_producers.get(TensorIdVersion(id(t), t._version)) + for t in flatten_tensors_to_list(outputs) + ] + else: + # Re-traversal: __torch_function__ was skipped so the output tensors were + # never registered in _tensor_producers. Register them now, pointing back + # to the same entries from the first traversal, so that downstream modules + # (e.g. a module consuming this re-traversal's output) can resolve them. + # strict=False: a shared module should always return the same tensor + # count across traversals (same forward signature), but if it doesn't + # we truncate silently rather than crash — the downstream module will + # fall through to creating ephemeral ops, which is recoverable. + for entry, tensor in zip( + self._module_output_producers.get(name, []), + flatten_tensors_to_list(outputs), + strict=False, + ): + if entry is not None: + key = TensorIdVersion(id(tensor), tensor._version) + self._tensor_producers[key] = entry + weakref.finalize(tensor, self._tensor_producers.pop, key, None) + self.parents.pop() + self.traversed_modules.add(module) + + return hook + + def _capture_input_tensors(self, module: nn.Module, inputs: Any) -> None: + """Create placeholder-like ops for each module-level input tensor.""" + for i, tensor in enumerate(flatten_tensors_to_list(inputs)): + op_info = OpInfo( + op_name=f"input_{i}", + op_type=None, + module_stack=(), + source_frames=(), + inputs=(), + outputs={}, + is_state=False, + ) + self._add_op(op_info) + key = TensorIdVersion(id(tensor), tensor._version) + self._tensor_producers[key] = InputEdge(op=op_info, output_idx=None) + weakref.finalize(tensor, self._tensor_producers.pop, key, None) + + def _capture_output_tensors(self, module: nn.Module, inputs: Any, outputs: Any) -> None: + """Create output-like ops for each module-level output tensor.""" + for i, tensor in enumerate(flatten_tensors_to_list(outputs)): + key = TensorIdVersion(id(tensor), tensor._version) + entry = 
self._tensor_producers.get(key) + op_info = OpInfo( + op_name=f"output_{i}", + op_type=None, + module_stack=(), + source_frames=(), + inputs=(entry,) if entry is not None else (), + outputs={}, + is_state=False, + ) + self._add_op(op_info) + if entry is not None and entry.output_idx is not None: + existing = entry.outputs.get(entry.output_idx, ()) + entry.op.outputs[entry.output_idx] = existing + (op_info,) + + def _get_module_stack(self) -> tuple[ModuleContext, ...]: + return tuple( + ModuleContext(module_name=named_mod.name, module_type=_fqn(type(named_mod.module))) + for named_mod in self.parents + ) + + def _extract_source_frames(self) -> tuple[SourceFrame, ...]: + # inspect.stack() would work here but reads source context for every + # frame on the stack; walking frames manually and calling linecache only + # for the forward() frames we keep is significantly faster on large models. + frames: list[SourceFrame] = [] + frame = sys._getframe() + while frame is not None: + if frame.f_code.co_name == FORWARD_FUNCTION_NAME: + code_context = linecache.getline(frame.f_code.co_filename, frame.f_lineno).strip() + frames.append( + SourceFrame( + filename=frame.f_code.co_filename, + lineno=frame.f_lineno, + function_name=frame.f_code.co_name, + code_context=code_context, + ) + ) + frame = frame.f_back + # Reverse: outermost forward first (matching graph mode order) + return tuple(reversed(frames)) + + def _resolve_inputs( + self, input_tensor_keys: tuple[TensorIdVersion, ...] + ) -> tuple[InputEdge, ...]: + """Look up which previously-recorded ops produced the input tensors. + + Returns an ordered tuple of :class:`InputEdge` objects (duplicates + preserved), each carrying the producing op and its output slot. + """ + input_edges: list[InputEdge] = [] + for key in input_tensor_keys: + if key.id in self._states_to_names: + # State tensor: look up or create by id only — version is irrelevant. 
+ op_info = self._state_op_infos.get(key.id) + if op_info is None: + op_info = OpInfo( + op_name=self._states_to_names[key.id], + op_type=None, + module_stack=(), + source_frames=(), + inputs=(), + outputs={}, + is_state=True, + ) + self._add_op(op_info) + self._state_op_infos[key.id] = op_info + input_edges.append(InputEdge(op=op_info, output_idx=None)) + else: + entry = self._tensor_producers.get(key) + if entry is not None: + input_edges.append(entry) + else: + # Unknown tensor: not a registered state and not produced by any + # intercepted op (e.g. a raw tensor attribute or a global tensor). + # Create an ephemeral OpInfo so the consuming op's inputs tuple is + # complete and the correct arg index appears in the formatted output. + # Ephemeral ops are NOT added to all_ops — they never appear as their + # own nodes in the summary tree. + input_edges.append(self._get_or_create_ephemeral(key.id)) + + return tuple(input_edges) + + def _record_outputs(self, out: Any, op_info: OpInfo) -> None: + """Record that op_info produced these output tensors.""" + for idx, tensor in enumerate(flatten_tensors_to_list(out)): + key = TensorIdVersion(id(tensor), tensor._version) + self._tensor_producers[key] = InputEdge(op=op_info, output_idx=idx) + weakref.finalize(tensor, self._tensor_producers.pop, key, None) + + def _register_as_consumer(self, inputs: tuple[InputEdge, ...], consumer: OpInfo) -> None: + """Append consumer to each producer's outputs dict at the given slot. + + Deduplicates per (producer, output_idx) pair so that passing the same + tensor twice as arguments to one op only registers the consumer once per + slot. A consumer that uses both output[0] and output[1] of the same + producer is still registered in both slots. 
+ """ + seen: set[tuple[str, int | None]] = set() + for slot in inputs: + pair = (slot.op_name, slot.output_idx) + if pair in seen: + continue + seen.add(pair) + if slot.output_idx is not None: + existing = slot.outputs.get(slot.output_idx, ()) + slot.op.outputs[slot.output_idx] = existing + (consumer,) + + def __torch_function__( + self, + func: Callable, + types: list, + args: tuple = (), + kwargs: dict | None = None, + ) -> Any: + if kwargs is None: + kwargs = {} + + # Snapshot input tensor keys BEFORE func() executes. + # Critical for in-place ops: func() mutates _version, but + # the producer was recorded at the pre-mutation version. + input_tensor_keys = tuple( + TensorIdVersion(id(t), t._version) + for t in flatten_tensors_to_list((*args, *kwargs.values())) + ) + + out = func(*args, **kwargs) + + if self.current_module in self.traversed_modules: + return out + + if not _is_interceptable_func(func): + return out + + module_name = self.current_module_name + + # Compute op name with per-module counter + base_name = get_func_base_name(func) + count = self.func_counts[module_name][base_name] + local_name = get_func_name(func, count) + self.func_counts[module_name][base_name] = count + 1 + + op_name = f"{module_name}.{local_name}" if module_name else local_name + + # Determine op_type and track supported ops. + # In the case of compressor being None, simply use base_name as the op_type. + # If the compressor is not None and a registry is in play, the op_type for the same op may + # differ based on how the function is associated in the registry. + # Ops not registered will later be filtered out and will not appear as part of the module + # tree, but can still appear as op inputs or outputs for other ops and as module + # input and output ops. 
+ op_type = base_name + if self._op_type_resolver is not None: + resolved_type = self._op_type_resolver(func) + if resolved_type is not None: + op_type = resolved_type + self.supported_op_names.add(op_name) + + module_stack = self._get_module_stack() + source_frames = self._extract_source_frames() + input_edges = self._resolve_inputs(input_tensor_keys) + + op_info = OpInfo( + op_name=op_name, + op_type=op_type, + module_stack=module_stack, + source_frames=source_frames, + inputs=input_edges, + outputs={}, + is_state=False, + ) + + self._add_op(op_info) + + self._register_as_consumer(input_edges, op_info) + self._record_outputs(out, op_info) + + return out + + def remove_hooks(self) -> None: + for hook in self.hooks: + hook.remove() + self.hooks.clear() + + +def _get_op_type_resolver( + compressor: type[_BaseModelCompressor] | None, +) -> Callable[[Callable], str | None] | None: + """Get the op type resolver for a compressor in eager mode. + + The resolver maps a torch function to its compressor-defined op type. + Returns None if no compressor is specified. + """ + if compressor is None: + return None + if issubclass(compressor, Quantizer): + return EagerQuantizer.get_op_type_resolver() + if issubclass(compressor, KMeansPalettizer): + return KMeansPalettizer.get_op_type_resolver() + msg = f"No eager mode op type resolver for compressor {compressor.__name__}." + raise ValueError(msg) + + +def _populate_boundary_ops_eager( + module: ModuleInfo, + mode: _EagerOpDiscoveryMode, + subtree_ops_by_module: dict[str, list[OpInfo]], +) -> None: + """Populate input_ops and output_ops for all modules using hook-captured tensor boundaries. + + Recurses depth-first, then processes each module. Both lists are ordered by + module spec index (position in the flattened module input/output tensor list), + matching the index semantics of ``module_input_spec`` / ``module_output_spec`` + in the eager quantizer. + + Args: + module (ModuleInfo): The module to populate. 
+ mode (_EagerOpDiscoveryMode): The discovery mode instance, used to access + ``_module_input_producers`` and ``_module_output_producers``. + subtree_ops_by_module (dict[str, list[OpInfo]]): Pre-partitioned map from + module name to all ops in that module's subtree, in execution order. + Built once in :func:`parse_ops_for_eager` from ``mode.all_ops`` using + each op's ``module_stack`` to avoid an O(num_modules × num_ops) scan. + + Note: + Compression's ``RegisterEagerOptimizationMode`` uses ``ModuleBoundaryTracker`` for + boundary detection. The inspector uses inline resolution in ``_enter_module`` / + ``_exit_module`` hooks instead. Both share the same conceptual methodology (capture + tensors at module entry/exit), but the inspector resolves ``TensorIdVersion`` keys to + ``OpInfo`` objects immediately while they are live, avoiding the need for a stable counter + and a deferred lookup layer. Adopting ``ModuleBoundaryTracker`` here would align the hook + call sites but add indirection (counter → ``OpInfo``) without removing the + inspector-specific translation layer (state filtering, module→producers reconstruction). + """ + for child in module.child_modules.values(): + _populate_boundary_ops_eager(child, mode, subtree_ops_by_module) + + subtree_ops_list = subtree_ops_by_module.get(module.module_name, []) + # OpInfo defines __eq__ and __hash__ on op_name, so set membership is op_name-based. + subtree_ops = set(subtree_ops_list) + + # Build a map: (external producer op_name, output_idx) → [(consuming_op, input_slot), ...] + # covering all ops inside this subtree that have at least one external non-state input. + # Keying by (op_name, output_idx) distinguishes e.g. the two outputs of a chunk/split op + # when both flow into this module as separate forward args. 
+ external_to_consumers: dict[tuple[str, int | None], list[tuple[OpInfo, int]]] = {} + for op in subtree_ops_list: + for i, inp in enumerate(op.inputs): + if not inp.is_state and inp.op not in subtree_ops: + external_to_consumers.setdefault((inp.op_name, inp.output_idx), []).append((op, i)) + + # input_ops: keyed by module input spec index (position in _module_input_producers). + # Each key maps to all (op, input_slot) pairs that the tensor at that position feeds + # into — a module input tensor can fan out to multiple ops inside the module. + # Positions occupied by states or tensors with no consuming op in the subtree are absent. + module.input_ops = {} + for spec_idx, entry in enumerate(mode._module_input_producers.get(module.module_name, [])): + if entry is None or entry.is_state: + continue + consumers = external_to_consumers.get((entry.op_name, entry.output_idx), []) + if not consumers: + continue + module.input_ops[spec_idx] = [ + BoundaryEdge(op=op, index=input_idx) for op, input_idx in consumers + ] + + # output_ops: keyed by module output spec index (position in _module_output_producers). + # Each key maps to the single (op, output_slot) pair that produces the tensor at that + # position. Positions occupied by states or untracked tensors are absent. + module.output_ops = {} + for spec_idx, entry in enumerate(mode._module_output_producers.get(module.module_name, [])): + if entry is None: + continue + op_info, output_idx = entry.op, entry.output_idx + if op_info not in subtree_ops or op_info.is_state or output_idx is None: + continue + module.output_ops[spec_idx] = BoundaryEdge(op=op_info, index=output_idx) + + +def parse_ops_for_eager( + model: nn.Module, + example_inputs: tuple[Any, ...], + compressor: type[_BaseModelCompressor] | None = None, +) -> ModelSummary: + """Discover ops by running a forward pass with torch function interception. + + Args: + model (nn.Module): The model to inspect. 
+ example_inputs (tuple[Any, ...]): Example inputs for the forward pass. + compressor (type[_BaseModelCompressor] | None): A compressor class to + filter ops to only those supported by that compression algorithm. + When None, all interceptable ops are included. + + Returns: + ModelSummary: The discovered operations nested in a module hierarchy. + """ + op_type_resolver = _get_op_type_resolver(compressor) + mode = _EagerOpDiscoveryMode(model, op_type_resolver) + try: + with torch.no_grad(), mode: + model(*example_inputs) + finally: + mode.remove_hooks() + + root = build_module_tree(_fqn(type(model)), mode.all_ops) + + # Pre-partition mode.all_ops by module subtree in a single O(N × D) pass, where N is + # the total op count and D is the average module stack depth. Each op is appended to + # every ancestor module's list, so each per-module list is already in execution order. + subtree_ops_by_module: dict[str, list[OpInfo]] = defaultdict(list) + for op in mode.all_ops: + for ctx in op.module_stack: + subtree_ops_by_module[ctx.module_name].append(op) + + _populate_boundary_ops_eager(root, mode, subtree_ops_by_module) + + if op_type_resolver is not None: + root = filter_module_tree(root, mode.supported_op_names) + + return ModelSummary(model=root, mode=ExecutionMode.EAGER) diff --git a/src/coreai_opt/inspection/_formatting.py b/src/coreai_opt/inspection/_formatting.py index fc71401..7c0d3a0 100644 --- a/src/coreai_opt/inspection/_formatting.py +++ b/src/coreai_opt/inspection/_formatting.py @@ -13,11 +13,26 @@ from rich.text import Text from rich.tree import Tree -from .types import ModelSummary, ModuleInfo, OpInfo +from .types import InputEdge, ModelSummary, ModuleInfo, OpInfo _FRAMEWORK_PATH_MARKERS = ("torch/nn/modules/", "torch/nn/functional", "torch/_") -_LEGEND = "Legend: ■ module_name (module_type) ◆ op_name [op_type]" +_LEGEND = ( + "Legend:\n" + " ■ module_name (module_type) ◆ op_name [op_type]\n" + "\n" + " op inputs: {I: producer[N]} — I = op_input_spec 
index;" + " N = output slot of the producing op\n" + " op states: param_name — model parameter or buffer\n" + " op outputs: {N: [consumers]} — N = output slot index;" + " consumers = ops receiving that output\n" + " untracked_N — input tensor whose producer was not intercepted" + " (e.g. raw attribute or global tensor); still quantizable via op_input_spec\n" + " module inputs: {I: [op[N], ...]} — I = module_input_spec index;" + " op[N] = op and its input slot receiving data from outside; absent keys = non-quantizable\n" + " module outputs: {I: op[N]} — I = module_output_spec index;" + " op[N] = op and its output slot leaving the module; absent keys = non-quantizable" +) def _source_for_op(op: OpInfo) -> tuple[str, str]: @@ -37,6 +52,17 @@ def _source_for_op(op: OpInfo) -> tuple[str, str]: return f"{rel_path}:{frame.lineno}", frame.code_context +def _producer_output_label(inp: InputEdge) -> str: + """Return the display label for one input edge. + + When ``output_idx`` is ``None`` (registered states, ephemeral/untracked tensors) + only the name is shown. Otherwise the output slot index is appended: ``name[N]``. + """ + if inp.output_idx is None: + return inp.op_name + return f"{inp.op_name}[{inp.output_idx}]" + + def _styled_op_label(op: OpInfo) -> Text: """Build the styled multi-line label for an op leaf node.""" label = Text() @@ -49,17 +75,28 @@ def _styled_op_label(op: OpInfo) -> Text: label.append(op_type_str, style="yellow") label.append("]") - # Line 2: op inputs - if op.inputs: - input_names = ", ".join(inp.op_name for inp in op.inputs) - label.append(f"\n op inputs: {input_names}") + # Line 2: op inputs as {arg_idx: producer_label}, excluding states. + # arg_idx is the full positional index (matching op_input_spec), not the filtered position. 
+ non_state_input_items = [(i, inp) for i, inp in enumerate(op.inputs) if not inp.is_state] + if non_state_input_items: + parts = ", ".join(f"{i}: {_producer_output_label(inp)}" for i, inp in non_state_input_items) + label.append(f"\n op inputs: {{{parts}}}") + + # Line 3: op states + op_state_names = [inp._display_name for inp in op.inputs if inp.is_state] + if op_state_names: + state_names = ", ".join(name for name in op_state_names) + label.append(f"\n op states: {state_names}") - # Line 3: op outputs + # Line 4: op outputs if op.outputs: - output_names = ", ".join(out.op_name for out in op.outputs) - label.append(f"\n op outputs: {output_names}") + parts = ", ".join( + f"{idx}: [{', '.join(out.op_name for out in consumers)}]" + for idx, consumers in sorted(op.outputs.items()) + ) + label.append(f"\n op outputs: {{{parts}}}") - # Lines 4-5: source + # Lines 5-6: source source_path, source_code = _source_for_op(op) if source_path: label.append(f"\n filepath: {source_path}", style="dim") @@ -69,6 +106,21 @@ def _styled_op_label(op: OpInfo) -> Text: return label +def _format_input_ops(input_ops: dict) -> str: + """Format module input_ops dict as '{I: [op[N], ...], ...}'.""" + parts = [] + for k, edges in sorted(input_ops.items()): + edge_strs = [f"{e.op.op_name}[{e.index}]" for e in edges] + parts.append(f"{k}: [{', '.join(edge_strs)}]") + return "{" + ", ".join(parts) + "}" + + +def _format_output_ops(output_ops: dict) -> str: + """Format module output_ops dict as '{I: op[N], ...}'.""" + parts = [f"{k}: {e.op.op_name}[{e.index}]" for k, e in sorted(output_ops.items())] + return "{" + ", ".join(parts) + "}" + + def _styled_module_label(module: ModuleInfo) -> Text: """Build the styled label for a module node.""" label = Text() @@ -78,11 +130,9 @@ def _styled_module_label(module: ModuleInfo) -> Text: label.append(module.module_type, style="magenta") label.append(")") if module.input_ops: - input_names = ", ".join(op.op_name for op in module.input_ops) - 
label.append(f"\n module inputs: {input_names}", style="dim") + label.append(f"\n module inputs: {_format_input_ops(module.input_ops)}", style="dim") if module.output_ops: - output_names = ", ".join(op.op_name for op in module.output_ops) - label.append(f"\n module outputs: {output_names}", style="dim") + label.append(f"\n module outputs: {_format_output_ops(module.output_ops)}", style="dim") return label @@ -120,11 +170,13 @@ def format_model_summary(summary: ModelSummary, colorize: bool | None = None) -> root_label.append(summary.model.module_type, style="magenta") root_label.append(")") if summary.model.input_ops: - input_names = ", ".join(op.op_name for op in summary.model.input_ops) - root_label.append(f"\n module inputs: {input_names}", style="dim") + root_label.append( + f"\n module inputs: {_format_input_ops(summary.model.input_ops)}", style="dim" + ) if summary.model.output_ops: - output_names = ", ".join(op.op_name for op in summary.model.output_ops) - root_label.append(f"\n module outputs: {output_names}", style="dim") + root_label.append( + f"\n module outputs: {_format_output_ops(summary.model.output_ops)}", style="dim" + ) tree = Tree(root_label) _render_tree(summary.model, tree) diff --git a/src/coreai_opt/inspection/_graph_mode.py b/src/coreai_opt/inspection/_graph_mode.py index e3b6b48..13f0019 100644 --- a/src/coreai_opt/inspection/_graph_mode.py +++ b/src/coreai_opt/inspection/_graph_mode.py @@ -12,41 +12,47 @@ from __future__ import annotations import re +from collections import defaultdict import torch from torch.fx import Node -from coreai_opt._utils.torch_utils import ( - get_node_type as _get_node_type, - normalize_module_fqn as _normalize_module_fqn, +from coreai_opt._utils.fx_utils import ( + get_module_boundary_nodes, + get_node_type, + normalize_module_fqn, ) from coreai_opt.base_model_compressor import _BaseModelCompressor +from coreai_opt.quantization import Quantizer +from coreai_opt.quantization._graph.quantizer import GraphQuantizer 
from coreai_opt.quantization.config.quantization_config import ExecutionMode +from ._common import ( + FORWARD_FUNCTION_NAME, + build_module_tree, + filter_module_tree, +) from .types import ( - ModelSummary as _ModelOpSummary, - ModuleContext as _ModuleContext, - ModuleInfo as _ModuleSummary, - OpInfo as _OpInfo, - SourceFrame as _SourceFrame, + BoundaryEdge, + InputEdge, + ModelSummary, + ModuleContext, + ModuleInfo, + OpInfo, + SourceFrame, ) -# The function name used to identify relevant source frames. -# Only frames from ``forward`` methods are kept; all other frames -# (framework dispatch, C++ internals, etc.) are discarded. -_FORWARD_FUNCTION_NAME = "forward" - -def _extract_module_stack(node: Node) -> tuple[_ModuleContext, ...]: +def _extract_module_stack(node: Node) -> tuple[ModuleContext, ...]: """Build the module nesting hierarchy from ``nn_module_stack`` metadata.""" stack = node.meta.get("nn_module_stack", {}) return tuple( - _ModuleContext(module_name=_normalize_module_fqn(module_fqn), module_type=module_type) + ModuleContext(module_name=normalize_module_fqn(module_fqn), module_type=module_type) for module_fqn, module_type in stack.values() ) -def _parse_stack_trace(stack_trace: str | None) -> tuple[_SourceFrame, ...]: +def _parse_stack_trace(stack_trace: str | None) -> tuple[SourceFrame, ...]: """Parse the ``stack_trace`` metadata string into filtered source frames. 
The ``stack_trace`` stored in ``node.meta["stack_trace"]`` is a multi-line @@ -61,7 +67,7 @@ def _parse_stack_trace(stack_trace: str | None) -> tuple[_SourceFrame, ...]: if not stack_trace: return () - frames: list[_SourceFrame] = [] + frames: list[SourceFrame] = [] lines = stack_trace.strip().splitlines() # Lines come in pairs: the first is a location header of the form # File "path/to/file.py", line 42, in forward @@ -83,9 +89,9 @@ def _parse_stack_trace(stack_trace: str | None) -> tuple[_SourceFrame, ...]: if i + 1 < len(lines) and not lines[i + 1].strip().startswith("File "): code_context = lines[i + 1].strip() i += 1 - if function_name == _FORWARD_FUNCTION_NAME: + if function_name == FORWARD_FUNCTION_NAME: frames.append( - _SourceFrame( + SourceFrame( filename=filename, lineno=lineno, function_name=function_name, @@ -96,149 +102,140 @@ def _parse_stack_trace(stack_trace: str | None) -> tuple[_SourceFrame, ...]: return tuple(frames) -def _get_or_create_child( - parent: _ModuleSummary, module_name: str, module_type: str -) -> _ModuleSummary: - """Get an existing child module or create a new one.""" - if module_name not in parent.child_modules: - parent.child_modules[module_name] = _ModuleSummary( - module_name=module_name, - module_type=module_type, - child_modules={}, - ops=[], - input_ops=[], - output_ops=[], - ) - return parent.child_modules[module_name] - - -def _populate_boundary_ops(module: _ModuleSummary, get_attr_names: set[str]) -> None: - """Recursively populate ``input_ops`` and ``output_ops`` for a module tree.""" - for child in module.child_modules.values(): - _populate_boundary_ops(child, get_attr_names) - - all_subtree_ops = module.all_ops() - subtree_op_names = {op.op_name for op in all_subtree_ops} - ignore_names = subtree_op_names | get_attr_names - - module.input_ops = [ - op for op in all_subtree_ops if any(inp.op_name not in ignore_names for inp in op.inputs) - ] - module.output_ops = [ - op - for op in all_subtree_ops - if not op.outputs or 
any(out.op_name not in subtree_op_names for out in op.outputs) - ] - +def _populate_boundary_ops_graph( + root: ModuleInfo, + model: torch.fx.GraphModule, + node_name_to_op_info: dict[str, OpInfo], +) -> None: + """Populate input_ops and output_ops for all modules in topological order. -def parse_ops_in_graph(model: torch.fx.GraphModule) -> _ModelOpSummary: + Reuses the method for which graph mode Quantizer uses to determine module + boundary inputs and outputs. + A single pass over model.graph.nodes buckets each node under every module + in its nn_module_stack. Per module, a node is an input_op if any of its + non-get_attr inputs falls outside the module's subtree, and an output_op if + any of its users falls outside the subtree. + """ + module_to_nodes: defaultdict[str, list[Node]] = defaultdict(list) + for node in model.graph.nodes: + if node.op == "get_attr": + continue + for ctx in _extract_module_stack(node): + module_to_nodes[ctx.module_name].append(node) + + def _recurse(module: ModuleInfo) -> None: + for child in module.child_modules.values(): + _recurse(child) + + subtree_nodes = module_to_nodes.get(module.module_name, []) + input_consumer_tuples, output_nodes = get_module_boundary_nodes(subtree_nodes) + # input_ops: keyed by spec index (enumerate position in input_consumer_tuples, + # which already excludes state nodes via is_coreai_compressed_state_node). + module.input_ops = { + idx: [ + BoundaryEdge( + op=node_name_to_op_info[consumer.name], + index=consumer.all_input_nodes.index(external), + ) + ] + for idx, (external, consumer) in enumerate(input_consumer_tuples) + } + # output_ops: keyed by spec index (enumerate position in output_nodes). + module.output_ops = { + idx: BoundaryEdge( + op=node_name_to_op_info[node.name], + # outputs is always {0: consumers} in graph mode (see phase 2), so this yields 0. 
+ index=next(iter(node_name_to_op_info[node.name].outputs)), + ) + for idx, node in enumerate(output_nodes) + } + + _recurse(root) + + +def parse_ops_for_graph( + model: torch.fx.GraphModule, + compressor: type[_BaseModelCompressor] | None = None, +) -> ModelSummary: """Discover all operations in a graph exported model. Args: - model: An exported ``torch.fx.GraphModule`` (from ``torch.export``). + model (torch.fx.GraphModule): An exported ``torch.fx.GraphModule`` + (from ``torch.export``). + compressor (type[_BaseModelCompressor] | None): A compressor class to + filter ops to only those supported by that compression algorithm. + When ``None``, all ops are included. Returns: - A :class:`ModelSummary` with operations nested in a - :class:`ModuleInfo` tree mirroring the ``nn.Module`` hierarchy. + ModelSummary: Operations nested in a :class:`ModuleInfo` tree + mirroring the ``nn.Module`` hierarchy. + + Raises: + ValueError: If *compressor* is not supported in graph mode. """ - # Phase 1: Walk graph and create stub OpInfo (empty inputs/outputs) for every op node. - ops_by_name: dict[str, _OpInfo] = {} - get_attr_names: set[str] = set() - node_op_list: list[tuple[torch.fx.Node, _OpInfo]] = [] + # Phase 1: Build OpInfo stubs (empty inputs/outputs) for every node. 
+ node_name_to_op_info_dict: dict[str, OpInfo] = {} + node_op_list: list[tuple[torch.fx.Node, OpInfo]] = [] + all_ops: list[OpInfo] = [] + seen_op_names: set[str] = set() + root_module_type = "" for node in model.graph.nodes: - if node.op == "get_attr": - get_attr_names.add(node.name) - - op_type = _get_node_type(node, warn_on_failure=False) + op_type = get_node_type(node, warn_on_failure=False) module_stack = _extract_module_stack(node) source_frames = _parse_stack_trace(node.meta.get("stack_trace")) - op_info = _OpInfo( - op_name=node.name, + # One time processing to fill in root_module_type + if not root_module_type: + for ctx in module_stack: + if ctx.module_name == "": + root_module_type = ctx.module_type + break + + is_state = node.op == "get_attr" + op_info = OpInfo( + op_name=(node.target if is_state else node.name), op_type=op_type, module_stack=module_stack, source_frames=source_frames, inputs=(), - outputs=(), + outputs={}, + is_state=is_state, ) - ops_by_name[node.name] = op_info + node_name_to_op_info_dict[node.name] = op_info node_op_list.append((node, op_info)) + assert op_info.op_name not in seen_op_names, f"duplicate op_name {op_info.op_name}" + seen_op_names.add(op_info.op_name) + all_ops.append(op_info) - # Phase 2: Fill in inputs/outputs and build the module tree. - root = _ModuleSummary( - module_name="", - module_type="", - child_modules={}, - ops=[], - input_ops=[], - output_ops=[], - ) - + # Phase 2: Fill in op inputs/outputs. for node, op_info in node_op_list: - inputs = tuple( - ops_by_name[inp.name] for inp in node.all_input_nodes if inp.name in ops_by_name + # Graph mode: all outputs are at slot 0, so output_idx is always 0. 
+ op_info.inputs = tuple( + InputEdge(op=node_name_to_op_info_dict[inp.name], output_idx=0) + for inp in node.all_input_nodes + if inp.name in node_name_to_op_info_dict ) - outputs = tuple(ops_by_name[user.name] for user in node.users if user.name in ops_by_name) - op_info.inputs = inputs - op_info.outputs = outputs - - # Walk down the module stack, placing the op in the deepest module. - if op_info.module_stack: - current = root - for ctx in op_info.module_stack: - if ctx.module_name == "": - # Root module — update type info - if not root.module_type: - root.module_type = ctx.module_type - continue - current = _get_or_create_child(current, ctx.module_name, ctx.module_type) - - current.ops.append(op_info) - - _populate_boundary_ops(root, get_attr_names) - return _ModelOpSummary(model=root, mode=ExecutionMode.GRAPH) - - -def _filter_module_tree(module: _ModuleSummary, keep_names: set[str]) -> _ModuleSummary: - """Recursively filter a ``ModuleInfo`` tree, keeping only matching ops.""" - filtered_children = { - fqn: _filter_module_tree(child, keep_names) for fqn, child in module.child_modules.items() - } - filtered_ops = [op for op in module.ops if op.op_name in keep_names] - - return _ModuleSummary( - module_name=module.module_name, - module_type=module.module_type, - child_modules=filtered_children, - ops=filtered_ops, - input_ops=module.input_ops, - output_ops=module.output_ops, - ) - - -def filter_by_compressor( - summary: _ModelOpSummary, - compressor: type[_BaseModelCompressor] | None, - gm: torch.fx.GraphModule, -) -> _ModelOpSummary: - """Filter a summary to ops supported by the given compressor. - - Uses FX graph pattern matching to determine which ops the compressor - can target. - - Args: - summary: The full (unfiltered) op summary. - compressor: A compressor class. When ``None``, returns unchanged. - gm: The exported graph module, used for pattern matching. - - Returns: - A filtered :class:`ModelSummary`. 
- """ - if compressor is None: - return summary - - compressible_names = compressor.get_compressible_op_names(gm, ExecutionMode.GRAPH) - - filtered_root = _filter_module_tree(summary.model, compressible_names) - return _ModelOpSummary(model=filtered_root, mode=summary.mode) + # For graph mode, graph annotation has no concept of multiple outputs. Graph quantizer + # lumps them all as a single output index. Hardcode all outputs to index 0. + op_info.outputs = { + 0: tuple( + node_name_to_op_info_dict[user.name] + for user in node.users + if user.name in node_name_to_op_info_dict + ) + } + + # Phase 3: Build the module tree. + root = build_module_tree(root_module_type, all_ops) + _populate_boundary_ops_graph(root, model, node_name_to_op_info_dict) + + if compressor is not None: + if issubclass(compressor, Quantizer): + compressible_names = GraphQuantizer.get_compressible_op_names(model) + else: + msg = f"No graph mode op filtering for compressor {compressor.__name__}." + raise ValueError(msg) + root = filter_module_tree(root, compressible_names) + + return ModelSummary(model=root, mode=ExecutionMode.GRAPH) diff --git a/src/coreai_opt/inspection/model_inspector.py b/src/coreai_opt/inspection/model_inspector.py index 576a10d..a4c703c 100644 --- a/src/coreai_opt/inspection/model_inspector.py +++ b/src/coreai_opt/inspection/model_inspector.py @@ -16,14 +16,13 @@ from coreai_opt._utils.python_utils import fqn as _fqn from coreai_opt._utils.torch_utils import export_model as _export_model from coreai_opt.base_model_compressor import _BaseModelCompressor -from coreai_opt.quantization import Quantizer as _Quantizer +from coreai_opt.palettization import KMeansPalettizer +from coreai_opt.quantization import Quantizer from coreai_opt.quantization.config.quantization_config import ExecutionMode +from ._eager_mode import parse_ops_for_eager as _parse_ops_for_eager from ._formatting import format_model_summary as _format_model_summary -from ._graph_mode import ( - 
filter_by_compressor as _filter_by_compressor_graph_mode, - parse_ops_in_graph as _parse_ops_in_graph, -) +from ._graph_mode import parse_ops_for_graph as _parse_ops_for_graph from .types import ModelSummary, OpInfo @@ -51,8 +50,8 @@ class ModelInspector: torch.no_grad() context. Defaults to True. Raises: - TypeError: If *model* is not an ``nn.Module``. - NotImplementedError: If *execution_mode* is ``"eager"``. + TypeError: If *model* is not an ``nn.Module``, or if *model* is a + ``GraphModule`` and *execution_mode* is ``"eager"``. RuntimeError: If model export fails (graph mode). ValueError: If example_inputs is None without the right model/execution_mode combination, or if execution_mode is not either "eager" or "graph". @@ -112,39 +111,60 @@ def __init__( dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None = None, export_with_no_grad: bool = True, ) -> None: + self._validate_args( + model, example_inputs, execution_mode, compressor, dynamic_shapes, export_with_no_grad + ) + + if execution_mode == ExecutionMode.GRAPH: + gm = model + if not isinstance(gm, torch.fx.GraphModule): + gm = _export_model(model, example_inputs, dynamic_shapes, export_with_no_grad) + self._summary = _parse_ops_for_graph(gm, compressor) + else: + self._summary = _parse_ops_for_eager(model, example_inputs, compressor) - # Check that model is an accepted type + @staticmethod + def _validate_args( + model: torch.fx.GraphModule | torch.nn.Module, + example_inputs: tuple[Any, ...] 
| None, + execution_mode: ExecutionMode, + compressor: type[_BaseModelCompressor] | None, + dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None, + export_with_no_grad: bool, + ) -> None: + """Validate constructor arguments.""" if not isinstance(model, (torch.fx.GraphModule, torch.nn.Module)): msg = f"Expected a torch.fx.GraphModule or torch.nn.Module, got {type(model).__name__}" raise TypeError(msg) - # Check that model and execution_mode are GraphModule and ExecutionMode.GRAPH respectively - # when example_inputs is None - if example_inputs is None: - if not ( - isinstance(model, torch.fx.GraphModule) and execution_mode == ExecutionMode.GRAPH - ): - msg = ( - "example_inputs can only be None when model is a GraphModule and " - "execution_mode is ExecutionMode.GRAPH" - ) - raise ValueError(msg) + if execution_mode not in (ExecutionMode.GRAPH, ExecutionMode.EAGER): + msg = f"Unknown execution_mode {execution_mode}. Expected 'graph' or 'eager'." + raise ValueError(msg) - if compressor is not None and not issubclass(compressor, _Quantizer): + if example_inputs is None and not ( + isinstance(model, torch.fx.GraphModule) and execution_mode == ExecutionMode.GRAPH + ): + msg = ( + "example_inputs can only be None when model is a GraphModule and " + "execution_mode is ExecutionMode.GRAPH" + ) + raise ValueError(msg) + + if compressor is not None and not issubclass(compressor, (Quantizer, KMeansPalettizer)): msg = ( f"Unsupported compressor class {compressor.__name__}. " - "Currently only Quantizer is supported." + "Supported compressors: Quantizer, KMeansPalettizer." 
) raise ValueError(msg) if execution_mode == ExecutionMode.GRAPH: - gm = model - if not isinstance(gm, torch.fx.GraphModule): - gm = _export_model(model, example_inputs, dynamic_shapes, export_with_no_grad) - self._summary = _parse_ops_in_graph(gm) - self._summary = _filter_by_compressor_graph_mode(self._summary, compressor, gm) - - elif execution_mode == ExecutionMode.EAGER: + if compressor is not None and not issubclass(compressor, Quantizer): + msg = ( + f"Compressor {compressor.__name__} is not supported in graph mode. " + "Only Quantizer is supported for graph mode inspection." + ) + raise ValueError(msg) + else: if isinstance(model, torch.fx.GraphModule): msg = ( "Expected a torch.nn.Module for Eager execution_mode, got torch.fx.GraphModule" @@ -154,18 +174,14 @@ def __init__( warnings.warn( "dynamic_shapes is only supported in graph mode and will be ignored.", UserWarning, - stacklevel=2, + stacklevel=3, ) if not export_with_no_grad: warnings.warn( "export_with_no_grad is only supported in graph mode and will be ignored.", UserWarning, - stacklevel=2, + stacklevel=3, ) - raise NotImplementedError("Eager mode op discovery is not yet implemented.") - else: - msg = f"Unknown execution_mode {execution_mode}. Expected 'graph' or 'eager'." - raise ValueError(msg) @property def summary(self) -> ModelSummary: diff --git a/src/coreai_opt/inspection/types.py b/src/coreai_opt/inspection/types.py index 4313ddb..e9e8b1d 100644 --- a/src/coreai_opt/inspection/types.py +++ b/src/coreai_opt/inspection/types.py @@ -11,6 +11,7 @@ from dataclasses import dataclass from typing import Any, ClassVar +from coreai_opt._utils.fx_utils import get_local_state_name as _get_local_state_name from coreai_opt.quantization.config.quantization_config import ExecutionMode @@ -53,6 +54,54 @@ class ModuleContext: module_type: str +@dataclass(frozen=True) +class InputEdge: + """One input edge into an op, pairing the producing op with its output slot. 
+ + Used as the element type of :attr:`OpInfo.inputs`. Delegation properties + forward the most-accessed :class:`OpInfo` attributes so that code iterating + ``op.inputs`` does not need to go through ``.op`` for routine checks. + + Attributes: + op (OpInfo): The op that produced this input tensor. + output_idx (int | None): Which output slot of ``op`` this tensor came from, + or ``None`` for synthetic ops (registered states, ephemeral/untracked + tensors) that have no meaningful output slot. + """ + + op: OpInfo + output_idx: int | None + + # --- delegation properties ------------------------------------------------- + @property + def op_name(self) -> str: + return self.op.op_name + + @property + def op_type(self) -> str | None: + return self.op.op_type + + @property + def is_state(self) -> bool: + return self.op.is_state + + @property + def module_stack(self) -> tuple[ModuleContext, ...]: + return self.op.module_stack + + @property + def inputs(self) -> tuple[InputEdge, ...]: + return self.op.inputs + + @property + def outputs(self) -> dict[int, tuple[OpInfo, ...]]: + return self.op.outputs + + @property + def _display_name(self) -> str: + return self.op._display_name + + @dataclass(eq=False) class OpInfo: """Information about a single operation discovered in a model. @@ -71,21 +120,39 @@ class OpInfo: source_frames (tuple[SourceFrame, ...]): Source code locations from outermost ``forward()`` to innermost, showing the call chain that produced this op. May be empty if source information is unavailable. - inputs (tuple[OpInfo, ...]): Ordered input ops (ops, placeholders, - parameters) that feed into this op. - outputs (tuple[OpInfo, ...]): Consumer ops that receive the output - of this op, in graph order. + inputs (tuple[InputEdge, ...]): Ordered input edges. Each :class:`InputEdge` + carries the producing op and the output slot of that op the tensor came from. 
+ outputs (dict[int, tuple[OpInfo, ...]]): Dictionary mapping op outputs to a tuple of ops + consuming the output. + is_state (bool): ``True`` if this op represents a model parameter or + buffer rather than a computation. State ops have an empty + ``module_stack`` and do not appear in module tree or boundary lists. """ op_name: str op_type: str | None module_stack: tuple[ModuleContext, ...] source_frames: tuple[SourceFrame, ...] - inputs: tuple[OpInfo, ...] - outputs: tuple[OpInfo, ...] + inputs: tuple[InputEdge, ...] + outputs: dict[int, tuple[OpInfo, ...]] + is_state: bool _IMMUTABLE_FIELDS: ClassVar[frozenset[str]] = frozenset({"op_name"}) + @property + def _display_name(self) -> str: + """Name for user-facing output such as ``format_summary``. + + Equal to :attr:`op_name` except for state ops, where only the last + dotted component is returned because state-matching configs are + currently keyed on that suffix (e.g., a parameter with name ``"conv.weight"`` + can only be matched with ``"weight"`` in the config). Remove this property + once full-FQN state matching lands (rdar://177076777). + """ + if self.is_state: + return _get_local_state_name(self.op_name) or self.op_name + return self.op_name + def __repr__(self) -> str: return f"OpInfo(op_name={self.op_name!r}, op_type={self.op_type!r})" @@ -116,6 +183,26 @@ def __delattr__(self, name: str) -> None: super().__delattr__(name) +@dataclass(frozen=True) +class BoundaryEdge: + """A single data-flow edge crossing a module boundary. + + Each entry in :attr:`ModuleInfo.input_ops` or :attr:`ModuleInfo.output_ops` + corresponds to one quantizer configuration point at the boundary. + + Attributes: + op (OpInfo): The op inside the module at the boundary. + index (int): For input boundaries: the input slot of ``op`` that + receives external data, matching the index used in + ``module_input_spec``. 
For output boundaries: the output slot of + ``op`` whose tensor leaves the module, matching the index used in + ``module_output_spec``. + """ + + op: OpInfo + index: int + + @dataclass class ModuleInfo: """A node in the ``nn.Module`` hierarchy with its directly-owned ops. @@ -133,18 +220,27 @@ class ModuleInfo: ``module_name``, in insertion order. ops (list[OpInfo]): Ops directly owned by this module, in graph order. - input_ops (list[OpInfo]): Ops owned by this module, that receive data from - outside this module. - output_ops (list[OpInfo]): Ops owned by this module, that send data outside - this module. + input_ops (dict[int, list[BoundaryEdge]]): Boundary edges where data enters this + module from outside. Keys are module input spec indices (positions in the + flattened module forward arguments). Values are lists of all + ``(op, input_slot)`` pairs that the tensor at that position feeds into inside + the module — a single module input tensor can fan out to multiple ops. Keys + are absent for positions occupied by state tensors, untracked tensors, or + unused arguments. The key is what the user passes to ``module_input_spec``. + output_ops (dict[int, BoundaryEdge]): Boundary edges where data leaves this + module. Keys are module output spec indices (positions in the flattened + module return value). Each key maps to the ``(op, output_slot)`` pair that + produces the tensor at that position. Keys are absent for positions occupied + by state tensors or untracked tensors. The key is what the user passes to + ``module_output_spec``. 
""" module_name: str module_type: str child_modules: dict[str, ModuleInfo] ops: list[OpInfo] - input_ops: list[OpInfo] - output_ops: list[OpInfo] + input_ops: dict[int, list[BoundaryEdge]] + output_ops: dict[int, BoundaryEdge] def children(self) -> Iterator[ModuleInfo]: """Yield direct child modules in insertion order.""" diff --git a/src/coreai_opt/palettization/kmeans/kmeans_support_mixins.py b/src/coreai_opt/palettization/kmeans/kmeans_support_mixins.py index 4d21bf2..d1061e9 100644 --- a/src/coreai_opt/palettization/kmeans/kmeans_support_mixins.py +++ b/src/coreai_opt/palettization/kmeans/kmeans_support_mixins.py @@ -92,3 +92,14 @@ def reshape_to_original( ) .transpose(0, 1) ) + + +class _ConvTransposePalettizationMixin(_ConvPalettizationMixin): + """Mixin providing palettization support for transposed convolution operations. + + ``ConvTranspose`` weights are shaped ``[in_channels, out_channels, *kernel]``, + so the output-channel axis is 1. The reshape logic from + ``_ConvPalettizationMixin`` works unchanged. 
+ """ + + default_axis: ClassVar[int] = 1 diff --git a/src/coreai_opt/palettization/kmeans/palettizer.py b/src/coreai_opt/palettization/kmeans/palettizer.py index e33a8aa..f0411d4 100644 --- a/src/coreai_opt/palettization/kmeans/palettizer.py +++ b/src/coreai_opt/palettization/kmeans/palettizer.py @@ -143,6 +143,11 @@ def __init__(self, model: torch.nn.Module, config: KMeansPalettizerConfig | None self._num_workers = 1 + @classmethod + def get_op_type_resolver(cls) -> Callable[[Callable], str | None]: + """Return a function that maps a torch function to its palettizable op type.""" + return _KMeansPalettizerSupportedOpsRegistry.get_func_type + def prepare( self, example_inputs: tuple[torch.Tensor], diff --git a/src/coreai_opt/palettization/kmeans/supported_ops_registry.py b/src/coreai_opt/palettization/kmeans/supported_ops_registry.py index 492f0c5..96a4dd0 100644 --- a/src/coreai_opt/palettization/kmeans/supported_ops_registry.py +++ b/src/coreai_opt/palettization/kmeans/supported_ops_registry.py @@ -15,6 +15,7 @@ ) from coreai_opt.palettization.kmeans.kmeans_support_mixins import ( _ConvPalettizationMixin, + _ConvTransposePalettizationMixin, _LinearPalettizationMixin, _PalettizationSupportMixin, ) @@ -75,6 +76,21 @@ class _Conv3dSupport(_ConvPalettizationMixin): ops = [F.conv3d] +@_KMeansPalettizerSupportedOpsRegistry.register("conv_transpose1d") +class _ConvTranspose1dSupport(_ConvTransposePalettizationMixin): + ops = [F.conv_transpose1d] + + +@_KMeansPalettizerSupportedOpsRegistry.register("conv_transpose2d") +class _ConvTranspose2dSupport(_ConvTransposePalettizationMixin): + ops = [F.conv_transpose2d] + + +@_KMeansPalettizerSupportedOpsRegistry.register("conv_transpose3d") +class _ConvTranspose3dSupport(_ConvTransposePalettizationMixin): + ops = [F.conv_transpose3d] + + @_KMeansPalettizerSupportedOpsRegistry.register("linear") class _LinearSupport(_LinearPalettizationMixin): ops = [F.linear] diff --git a/src/coreai_opt/palettization/spec/factory.py 
b/src/coreai_opt/palettization/spec/factory.py index fd658e7..792aa3d 100644 --- a/src/coreai_opt/palettization/spec/factory.py +++ b/src/coreai_opt/palettization/spec/factory.py @@ -13,7 +13,6 @@ class _PalettizationComponentFactory(CompressionComponentFactoryBase): - @classmethod def construct( cls, diff --git a/src/coreai_opt/palettization/spec/fake_palettize.py b/src/coreai_opt/palettization/spec/fake_palettize.py index 709873f..ee4f7e3 100644 --- a/src/coreai_opt/palettization/spec/fake_palettize.py +++ b/src/coreai_opt/palettization/spec/fake_palettize.py @@ -27,6 +27,7 @@ logger = logging.getLogger(__name__) + class _FakePalettizeImplBase(CompressionSimulatorBase, nn.Module): """Base class for fake palettization implementations with clustering and reconstruction methods. @@ -57,9 +58,7 @@ def __init__( self.cluster_dim = cluster_dim self.enable_per_channel_scale = enable_per_channel_scale - self.register_buffer( - "fake_palett_enabled", torch.tensor([1], dtype=torch.uint8) - ) + self.register_buffer("fake_palett_enabled", torch.tensor([1], dtype=torch.uint8)) self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.uint8)) self._disabled = False @@ -111,9 +110,7 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: return tensor if self.fake_palett_enabled[0] == 1: - return self._palettize( - lut=self.lut, indices=self.indices, original_weights=tensor - ) + return self._palettize(lut=self.lut, indices=self.indices, original_weights=tensor) return tensor @@ -125,9 +122,7 @@ def _palettize( raise NotImplementedError() @abstractmethod - def _calculate_centroids( - self, weight: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + def _calculate_centroids(self, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Cluster weights and return lookup table (LUT) and corresponding indices. 
If tensor is incompatible with the specified granularity, this method @@ -167,14 +162,7 @@ def with_args(cls, **kwargs: dict) -> _PartialConstructor[_FakePalettizeImplBase return fake_palett_constructor def _load_from_state_dict( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): """Custom state dict loading for palettization-specific buffers. @@ -205,8 +193,7 @@ def _load_from_state_dict( unexpected_keys.remove(prefixed_key) super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, - unexpected_keys, error_msgs + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) def enable_fake_palett(self, enabled: bool = True) -> None: diff --git a/src/coreai_opt/palettization/spec/granularity.py b/src/coreai_opt/palettization/spec/granularity.py index eec35f1..1e166c9 100644 --- a/src/coreai_opt/palettization/spec/granularity.py +++ b/src/coreai_opt/palettization/spec/granularity.py @@ -118,7 +118,8 @@ class PerGroupedChannelGranularity(PalettizationGranularity): This applies palettization to a specific channel which is selected through the ``axis`` argument. ``axis`` defaults to ``None``, in which case the default - axis for the consuming op is used (e.g. 0 for ``Linear``/``Conv``). + axis for the consuming op is used (e.g. 0 for ``Linear``/``Conv``, 1 for + ``ConvTranspose``). 
""" axis: Annotated[int | None, Field(default=None, ge=0, le=1)] diff --git a/src/coreai_opt/pruning/spec/spec.py b/src/coreai_opt/pruning/spec/spec.py index 188bff2..0afb662 100644 --- a/src/coreai_opt/pruning/spec/spec.py +++ b/src/coreai_opt/pruning/spec/spec.py @@ -11,7 +11,6 @@ from pydantic import BeforeValidator, Field, PrivateAttr, field_validator, model_validator -from coreai_opt._utils.registry_utils import ClassRegistryMixin from coreai_opt.common import CompressionType from coreai_opt.config.spec import CompressionSpec @@ -46,50 +45,11 @@ class PruningSpec(CompressionSpec): ] = Field(default_factory=Unstructured) pruning_algo: type[PruneImplBase] = Field(default="default", validate_default=True) - @staticmethod - def _convert_with_registry(data: str | type, registry_class: type[ClassRegistryMixin]) -> type: - """Convert string or type to a registered class from the given registry. - - Args: - data (str | type): Either a string key or a class type. - registry_class (type[ClassRegistryMixin]): The registry class to look up. - - Returns: - type: The registered class type. - - Raises: - ValueError: If the key is not found in registry or type is not registered. - TypeError: If *data* is neither string nor type. - """ - if isinstance(data, str): - try: - return registry_class.get_class(data) # type: ignore[no-any-return] - except KeyError as err: - available_keys = registry_class.list_registry_keys() - raise ValueError( - f"No class is registered with key: '{data}' " - f"in registry {registry_class.__name__}. " - f"Available keys: {sorted(available_keys)}" - ) from err - elif isinstance(data, type): - if data in registry_class.list_registry_values(): - return data - available_classes = [cls.__name__ for cls in registry_class.list_registry_values()] - raise ValueError( - f"Class {data.__name__} is not registered in " - f"{registry_class.__name__}. 
" - f"Available classes: {sorted(available_classes)}" - ) - else: - raise TypeError( - f"Expected str or type for registry lookup, got {type(data).__name__}: {data}" - ) - @field_validator("pruning_algo", mode="before") @classmethod def convert_pruning_algo(cls, data: Any) -> type[PruneImplBase]: """Resolve string keys to registered pruning implementation classes.""" - return cls._convert_with_registry(data, PruneImplBase) + return PruneImplBase.resolve(data) @model_validator(mode="before") @classmethod diff --git a/src/coreai_opt/quantization/_eager/quantizer.py b/src/coreai_opt/quantization/_eager/quantizer.py index fbcc4f6..084deaa 100644 --- a/src/coreai_opt/quantization/_eager/quantizer.py +++ b/src/coreai_opt/quantization/_eager/quantizer.py @@ -6,7 +6,7 @@ import itertools import warnings from collections import defaultdict -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import contextmanager from os import PathLike @@ -89,21 +89,6 @@ class EagerQuantizer(_BaseQuantizer, _EagerCompressionComponentBuilderMixin): >>> final_model = quantizer.finalize() """ - @classmethod - def get_compressible_op_names( - cls, - model: nn.Module, - ) -> set[str]: - """Return op names in *model* that this quantizer can target. - - Args: - model (nn.Module): The model to inspect. - - Returns: - set[str]: Op names that can be compressed via quantization. - """ - return set() # TODO: Implement eager mode op discovery. 
- def __init__( self, model: nn.Module, @@ -125,6 +110,11 @@ def __init__( optimization_type_name="quantize", ) + @classmethod + def get_op_type_resolver(cls) -> Callable[[Callable], str | None]: + """Return a function that maps a torch function to its quantizable op type.""" + return EagerQuantizerSupportedOpsRegistry.get_func_type + @staticmethod def _fill_op_input_output_specs_to_check_for_module_config( op_input_specs_to_check: list[_ACTIVATION_SPEC_DICT], diff --git a/src/coreai_opt/quantization/_eager/supported_ops_registry.py b/src/coreai_opt/quantization/_eager/supported_ops_registry.py index d33804e..cad2ceb 100644 --- a/src/coreai_opt/quantization/_eager/supported_ops_registry.py +++ b/src/coreai_opt/quantization/_eager/supported_ops_registry.py @@ -58,11 +58,13 @@ class ConvTranspose3dQuantizationSupport: class LinearQuantizationSupport: ops = [F.linear] + # Register embedding operations @EagerQuantizerSupportedOpsRegistry.register("embedding") class EmbeddingQuantizationSupport: ops = [F.embedding] + # Register pooling operations (no weight parameter) @EagerQuantizerSupportedOpsRegistry.register("max_pool2d") class MaxPool2dQuantizationSupport: diff --git a/src/coreai_opt/quantization/_graph/_annotation_config.py b/src/coreai_opt/quantization/_graph/_annotation_config.py index 386396f..bec1c2d 100644 --- a/src/coreai_opt/quantization/_graph/_annotation_config.py +++ b/src/coreai_opt/quantization/_graph/_annotation_config.py @@ -5,6 +5,10 @@ from __future__ import annotations +from collections.abc import Mapping, Set +from dataclasses import dataclass + +import torch.fx from torchao.quantization.pt2e.quantizer import ( QuantizationSpec as TorchAOQuantizationSpec, ) @@ -14,6 +18,31 @@ from coreai_opt.quantization.spec import QuantizationComponentFactory, QuantizationSpec +@dataclass(frozen=True) +class AnnotationContext: + """Pass-invariant inputs an annotator may need. + + Held constant across all matches in a single annotation pass. 
Constructed + once when ``_AnnotationHandler.annotate`` begins and shared by every + annotator invocation during that pass. + + Distinct from :class:`AnnotationConfig`, which carries per-op specs that + vary per match. + + Attributes: + module_name_to_state_names_map (Mapping[str, Mapping[str, list[str]]]): + For each module name, a mapping from each state target (FQN) to the + list of local names the module uses for that state. Used during + state-input annotation to translate a state node's target into the + consumer module's local name(s). + shared_observer_nodes (Set[torch.fx.Node]): Nodes whose output annotations + are shared with their input annotations if any. + """ + + module_name_to_state_names_map: Mapping[str, Mapping[str, list[str]]] + shared_observer_nodes: Set[torch.fx.Node] + + class AnnotationConfig: """ Configuration class for PT2E quantization annotations using TorchAO QuantizationSpec diff --git a/src/coreai_opt/quantization/_graph/_annotation_pattern_registry.py b/src/coreai_opt/quantization/_graph/_annotation_pattern_registry.py index 3a9315e..abbf466 100644 --- a/src/coreai_opt/quantization/_graph/_annotation_pattern_registry.py +++ b/src/coreai_opt/quantization/_graph/_annotation_pattern_registry.py @@ -18,6 +18,7 @@ from . import _annotation_utils from ._annotation_config import ( AnnotationConfig as _AnnotationConfig, + AnnotationContext as _AnnotationContext, ) from ._annotation_utils import ( OpsListPattern as _OpsListPattern, @@ -26,18 +27,18 @@ # Generic type variable for match results MatchType = TypeVar("MatchType") -# Generic annotator function type +# Generic annotator function type. # The function is expected to take exactly 3 inputs: # 1. Matched nodes to annotate. The type of this entity is flexible depending -# on the implementation of the AnnotationPattern subclass. Whatever entity -# is returned in the subclass's match_single_pattern dictionary values will -# be passed into this function as the first input. -# 2. 
Quantization Config to use when annotating the matched nodes. -# 3. Typically when annotating matched nodes, if any immediate child nodes -# are shared observer nodes, the annotation will propagate through the -# child nodes. This argument allows the annotator function to know which -# nodes are shared observer nodes. -AnnotatorFunc: TypeAlias = Callable[[MatchType, _AnnotationConfig, set[torch.fx.Node]], Any] +# on the implementation of the AnnotationPattern subclass. Whatever entity +# is returned in the subclass's match_single_pattern dictionary values will +# be passed into this function as the first input. +# 2. Quantization Config to use when annotating the matched nodes (per-match, +# derived from OpQuantizerConfig). +# 3. Annotation pass context. Holds pass-invariant inputs the annotator may +# need (the model's module-name-to-state-names map and the set of shared +# observer nodes computed at the start of this annotation pass). +AnnotatorFunc: TypeAlias = Callable[[MatchType, _AnnotationConfig, _AnnotationContext], Any] @dataclass(frozen=True) diff --git a/src/coreai_opt/quantization/_graph/_annotation_utils.py b/src/coreai_opt/quantization/_graph/_annotation_utils.py index 1205c5a..fe304ff 100644 --- a/src/coreai_opt/quantization/_graph/_annotation_utils.py +++ b/src/coreai_opt/quantization/_graph/_annotation_utils.py @@ -30,6 +30,12 @@ from coreai_opt._utils.config_utils import ( ALL_TENSORS as _ALL_TENSORS, ConfigLevel as _ConfigLevel, + get_last_matching_spec, +) +from coreai_opt._utils.fx_utils import ( + get_local_state_name, + get_module_boundary_nodes, + is_coreai_compressed_state_node, ) from coreai_opt._utils.python_utils import get_fn_arg_names from coreai_opt._utils.version_utils import version_ge as _version_ge @@ -42,7 +48,7 @@ ) from coreai_opt.quantization.spec import QuantizationSpec -from ._annotation_config import AnnotationConfig +from ._annotation_config import AnnotationConfig, AnnotationContext logger = logging.getLogger(__name__) @@ 
-630,79 +636,22 @@ def _is_fx_node_floating_point(node: torch.fx.Node) -> bool: return False -def _is_coreai_compressed_state_node(node: torch.fx.Node) -> bool: - """Check if a call_function node represents a coreai state-producing op. - - Recognized patterns: - - - ``coreai.lut_to_dense``: always state (palettized weight decompression). - - ``coreai.constexpr_blockwise_shift_scale``: always state (this op is only - intended for weights). - - NOTE: Update this function if new coreai ops are introduced that produce state - tensors from compressed representations or if the names of existing ops changes. - - Args: - node (torch.fx.Node): An FX graph node with ``op == "call_function"``. - - Returns: - bool: True if the node is a recognized coreai state-producing op. - """ - target = node.target - if not isinstance(target, torch._ops.OpOverload) or target.namespace != "coreai": - return False - - return target._opname in ( - "lut_to_dense", - "constexpr_blockwise_shift_scale", - ) - - -def _is_state_node(node: torch.fx.Node) -> bool: - """Return True if the node represents model state, False otherwise. - - A node is considered state if it is: - - 1. A ``get_attr`` node - 2. A ``call_function`` node targeting a recognized coreai state-producing op - (``lut_to_dense`` for palettized weights, or - ``constexpr_blockwise_shift_scale`` for block shift/scale on weights). - - All other nodes (placeholders, unrecognized call_function ops, call_module, - etc.) are not state. - - Args: - node (torch.fx.Node): The FX graph node to check. - - Returns: - bool: True if the node is a state node, False otherwise. - """ - if node.op == "get_attr": - return True - return node.op == "call_function" and _is_coreai_compressed_state_node(node) - - -def _get_local_state_name(state_node: torch.fx.Node) -> str | None: - """ - Return the local state name by taking the final section of the name after the last - period. 
The string from target is the torch.nn.Module given state name, not the - torch exported getattr node name. - - For call_function nodes that are identified as state (e.g., lut_to_dense outputs - from palettization), returns None since they don't have a traditional state name. +def _get_state_aliases( + state_node: torch.fx.Node, + module_name_to_state_names_map: Mapping[str, Mapping[str, list[str]]], +) -> set[str]: + """Return all local names any module uses for the state tensor at ``state_node.target``. - Examples: - - Top level model parameter name "model_weight" has local state name - "model_weight" (no period in name) - - Multiple level nested parameter name "model.mod1.mod2.weight" has local state - name "weight", taking the last part of the name after the last period - - call_function state node (lut_to_dense) returns None + A single state tensor may be aliased under different attribute names by different + modules. This collects every such name across all modules so that spec lookups and + warning checks are not limited to a single module's perspective. """ - if state_node.op != "get_attr": - # call_function nodes identified as state (e.g., lut_to_dense from palettization) - # don't have a traditional state name - they are already compressed - return None - return state_node.target.rsplit(".", 1)[-1] + return { + name + for module_states in module_name_to_state_names_map.values() + if state_node.target in module_states + for name in module_states[state_node.target] + } def _warn_non_quantizable_tensor_setting( @@ -748,7 +697,7 @@ def _validate_state_referenced_as_input( Raise error if the user attempts to set a state tensor using input idx in op_input_spec. """ - if _is_state_node(node) and input_idx in op_input_spec: + if is_coreai_compressed_state_node(node) and input_idx in op_input_spec: raise RuntimeError( f"Config is attempting to set op_input_spec idx {input_idx}, but the input " f"is a state tensor (node: {node.name}). 
Use op_state_spec to configure " @@ -758,7 +707,9 @@ def _validate_state_referenced_as_input( def _get_input_qspec_map( - input_and_state_nodes: list[torch.fx.Node], quantization_config: AnnotationConfig + input_and_state_nodes: list[torch.fx.Node], + quantization_config: AnnotationConfig, + context: AnnotationContext, ) -> dict[torch.fx.Node, TorchAOQuantizationSpec | None]: """ Get input_qspec_map for a node according to the settings in quantization_config. @@ -773,40 +724,31 @@ def _get_input_qspec_map( # warning (settings using "*" will not be flagged) if idx in op_input_spec: _warn_non_quantizable_tensor_setting(node, "input", idx, op_input_spec) - state_name = _get_local_state_name(node) if _is_state_node(node) else None - if state_name is not None and state_name in op_state_spec: - _warn_non_quantizable_tensor_setting(node, "state", state_name, op_state_spec) + if is_coreai_compressed_state_node(node): + state_names = _get_state_aliases(node, context.module_name_to_state_names_map) + matching_keys = [key for key in op_state_spec if key in state_names] + if matching_keys: + _warn_non_quantizable_tensor_setting( + node, "state", matching_keys[-1], op_state_spec + ) + input_qspec_map[node] = None continue _validate_state_referenced_as_input(node, idx, op_input_spec) - if _is_state_node(node): - _fill_input_qspec_map_for_state(input_qspec_map, node, op_state_spec) + + if is_coreai_compressed_state_node(node): + _fill_input_qspec_map_for_state(input_qspec_map, node, op_state_spec, context) else: _fill_input_qspec_map_for_input(input_qspec_map, node, idx, op_input_spec) return input_qspec_map -def _get_spec_for_tensor( - idx_or_name: int | str, op_spec: dict[int | str, TorchAOQuantizationSpec | None] -) -> TorchAOQuantizationSpec | None: - """ - Get the spec for a tensor from op_spec. - - First check for an exact identifier match (index or state name). If there is not - one, use the spec for "*" if possible. Return None if no applicable match is found. 
- """ - if idx_or_name in op_spec: - return op_spec[idx_or_name] - if _ALL_TENSORS in op_spec: - return op_spec[_ALL_TENSORS] - return None - - def _fill_input_qspec_map_for_state( input_qspec_map: dict[torch.fx.Node, TorchAOQuantizationSpec | None], state_node: torch.fx.Node, op_state_spec: dict[str, TorchAOQuantizationSpec | None], + context: AnnotationContext, ) -> None: """ Fill input_qspec_map with state_node as the key. @@ -817,12 +759,13 @@ def _fill_input_qspec_map_for_state( """ found, spec = _get_state_node_shared_spec(state_node) if not found: - state_name = _get_local_state_name(state_node) + state_name = get_local_state_name(state_node) if state_name is None: # Already compressed state (e.g., lut_to_dense from palettization) - don't quantize spec = None else: - spec = _get_spec_for_tensor(state_name, op_state_spec) + state_names = _get_state_aliases(state_node, context.module_name_to_state_names_map) + spec, _ = get_last_matching_spec(state_names, op_state_spec) input_qspec_map[state_node] = spec @@ -862,7 +805,8 @@ def _fill_input_qspec_map_for_input( # Check if any qspec is already set from a parent node output. If so, simply # use that spec. if not is_node_annotated(input_node) or input_node.meta[Q_ANNOTATION_KEY].output_qspec is None: - input_qspec_map[input_node] = _get_spec_for_tensor(idx, op_input_spec) + spec, _ = get_last_matching_spec([idx], op_input_spec) + input_qspec_map[input_node] = spec else: input_qspec_map[input_node] = input_node.meta[Q_ANNOTATION_KEY].output_qspec @@ -890,7 +834,7 @@ def _get_output_qspec( return None # First read qspec from config without applying it yet. - qspec_from_config = _get_spec_for_tensor(0, op_output_spec) + qspec_from_config, _ = get_last_matching_spec([0], op_output_spec) # Don't set output qspec if it is specified to be None. 
If the op has multiple child # ops where a subset of child ops don't have input quantization, we should not @@ -1039,7 +983,7 @@ def match_pattern_with_subgraph_matcher( def annotate_weighted_mod_match( annotator_match: InternalMatch, quantization_config: AnnotationConfig, - shared_observer_nodes: set[torch.fx.Node], + context: AnnotationContext, ) -> None: """ Try to annotate specific nodes in the model designated by ``annotator_match`` using @@ -1070,7 +1014,12 @@ def annotate_weighted_mod_match( if is_any_annotated(partition): return - input_qspec_map = _get_input_qspec_map(mod_node.all_input_nodes, quantization_config) + shared_observer_nodes = context.shared_observer_nodes + input_qspec_map = _get_input_qspec_map( + mod_node.all_input_nodes, + quantization_config, + context, + ) output_qspec = _get_output_qspec( output_node or mod_node, quantization_config, shared_observer_nodes ) @@ -1093,7 +1042,7 @@ def annotate_weighted_mod_match( def annotate_n_ary_act_match( annotator_match: tuple[SourcePartition], quantization_config: AnnotationConfig, - shared_observer_nodes: set[torch.fx.Node], + context: AnnotationContext, ) -> None: """ Try to annotate specific nodes in the model designated by ``annotator_match`` using @@ -1116,7 +1065,12 @@ def annotate_n_ary_act_match( # TODO: skip partition if any intermediate node output is used by an op outside the pattern. 
- input_qspec_map = _get_input_qspec_map(first_op_node.all_input_nodes, quantization_config) + shared_observer_nodes = context.shared_observer_nodes + input_qspec_map = _get_input_qspec_map( + first_op_node.all_input_nodes, + quantization_config, + context, + ) output_qspec = _get_output_qspec(last_op_node, quantization_config, shared_observer_nodes) if len(nodes_to_annotate) == 1: first_op_node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( @@ -1161,7 +1115,7 @@ def _adjust_input_qspec_map_for_shared_observers( def annotate_shared_observer_match( annotator_match: tuple[SourcePartition], quantization_config: AnnotationConfig, - shared_observer_nodes: set[torch.fx.Node], + context: AnnotationContext, ) -> None: """ Try to annotate specific nodes in the model designated by ``annotator_match`` using @@ -1178,7 +1132,12 @@ def annotate_shared_observer_match( if is_node_annotated(op_node): return - input_qspec_map = _get_input_qspec_map(op_node.all_input_nodes, quantization_config) + shared_observer_nodes = context.shared_observer_nodes + input_qspec_map = _get_input_qspec_map( + op_node.all_input_nodes, + quantization_config, + context, + ) output_qspec = _adjust_input_qspec_map_for_shared_observers(op_node, input_qspec_map) if output_qspec is None: @@ -1276,7 +1235,7 @@ def forward(self, inp): configured with outer_param's spec. 
""" for node in model.graph.nodes: - if _is_state_node(node): + if is_coreai_compressed_state_node(node): _match_and_annotate_state_node(node, module_configs, module_name_to_state_names_map) @@ -1409,7 +1368,7 @@ def _annotate_nodes_for_module_level_input_output_spec( module_config: The module quantizer config containing module-level specs nodes_in_module: List of nodes present in the module being annotated """ - (input_consumer_tuples, outputs) = _get_module_boundary_nodes(nodes_in_module) + (input_consumer_tuples, outputs) = get_module_boundary_nodes(nodes_in_module) # Annotate module inputs if module_config.module_input_spec: @@ -1456,38 +1415,3 @@ def _find_and_apply_module_level_spec( _annotate_node_input_qspec(node_to_annotate, input_node, converted_spec) else: _annotate_node_output_qspec(node_to_annotate, converted_spec) - - -def _get_module_boundary_nodes( - nodes_in_module: list[torch.fx.Node], -) -> tuple[list[tuple[torch.fx.Node, torch.fx.Node]], list[torch.fx.Node]]: - """ - Get all input and output nodes for a module. 
- - Args: - nodes_in_module: List of nodes in the module - - Returns: - A tuple of (input_consumer_tuples, outputs) - where: - - input_consumer_tuples: List of (input_node, consumer_node) for inputs from - outside the module - - outputs: List of output nodes leading to nodes outside the module - """ - input_consumer_tuples: list[tuple[torch.fx.Node, torch.fx.Node]] = [] - outputs: list[torch.fx.Node] = [] - nodes_in_module_set = set(nodes_in_module) - - for node in nodes_in_module: - # Processing for module inputs and states - for input_node in node.all_input_nodes: - if not _is_state_node(input_node) and input_node not in nodes_in_module_set: - input_consumer_tuples.append((input_node, node)) - - # Processing for module outputs - for user in node.users: - if user not in nodes_in_module_set: - outputs.append(node) - break - - return input_consumer_tuples, outputs diff --git a/src/coreai_opt/quantization/_graph/quantizer.py b/src/coreai_opt/quantization/_graph/quantizer.py index 9080cd0..fe6de4b 100644 --- a/src/coreai_opt/quantization/_graph/quantizer.py +++ b/src/coreai_opt/quantization/_graph/quantizer.py @@ -21,7 +21,7 @@ from contextlib import contextmanager from enum import Enum, auto from os import PathLike -from typing import Any, TypeAlias +from typing import Any, NamedTuple, TypeAlias import torch import torchao @@ -40,12 +40,14 @@ ALL_TENSORS as _ALL_TENSORS, ConfigLevel as _ConfigLevel, ) +from coreai_opt._utils.fx_utils import ( + get_node_type as _get_node_type, + normalize_module_fqn, +) from coreai_opt._utils.torch_utils import ( export_model as _export_model, - get_node_type as _get_node_type, move_model_to_eval, move_model_to_train, - normalize_module_fqn, ) from coreai_opt._utils.version_utils import version_ge as _version_ge from coreai_opt.common import ExportBackend @@ -65,7 +67,7 @@ FakeQuantizeImplBase, ) -from ._annotation_config import AnnotationConfig +from ._annotation_config import AnnotationConfig, AnnotationContext from 
._annotation_pattern_registry import ( AnnotatorMatchInfo as _AnnotatorMatchInfo, SharedObserverModulePattern as _SharedObserverModulePattern, @@ -116,7 +118,41 @@ def priority_order(cls) -> list[_OpConfigLevel]: return list(cls) -NodeConfigDict: TypeAlias = dict[_OpConfigLevel, dict[torch.fx.Node, OpQuantizerConfig]] +class _NodePriorityConfig(NamedTuple): + """Config attached to a node, paired with its priority within a config level. + + Attributes: + config (OpQuantizerConfig): The op-level config to apply at this node. + priority (int): Position of the matching module in + ``module_config_dict[level]``. Lower = higher precedence within + the level (matches eager-mode ``module_priority_dict`` semantics: + ``build_module_config_dict`` processes user configs in reverse, so + the last-listed user config claims modules first and gets the + smallest index). + """ + + config: OpQuantizerConfig + priority: int + + +class _RankedAnnotation(NamedTuple): + """One annotator match ranked for the priority sort. + + Attributes: + node (torch.fx.Node): The node being annotated. + config (OpQuantizerConfig): The op-level config to apply. + match (_AnnotatorMatchInfo): The annotator match info for ``node``. + priority (int): Within-level priority carried over from + :class:`_NodePriorityConfig`. + """ + + node: torch.fx.Node + config: OpQuantizerConfig + match: _AnnotatorMatchInfo + priority: int + + +NodeConfigDict: TypeAlias = dict[_OpConfigLevel, dict[torch.fx.Node, _NodePriorityConfig]] class _AnnotationHandler(TorchPT2EQuantizer): @@ -208,12 +244,19 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: # Annotation phase - go through sorted nodes with matches list to annotate shared_observer_nodes = self._get_shared_observer_nodes(model) + # Build pass-invariant context once; shared by all annotator invocations. 
+ context = AnnotationContext( + module_name_to_state_names_map=self._module_name_to_state_names_map, + shared_observer_nodes=shared_observer_nodes, + ) for node, config, annotator_match_info in sorted_nodes_with_annotation_match_info: if is_node_annotated(node): continue annotation_config = AnnotationConfig.from_quantizer_config(config) annotator_match_info.annotator_func( - annotator_match_info.annotator_match, annotation_config, shared_observer_nodes + annotator_match_info.annotator_match, + annotation_config, + context, ) _annotate_module_level_specs( @@ -266,6 +309,7 @@ def _sort_nodes_in_annotation_order( The list is sorted with the following criteria, in decreasing priority: - Config type (module_name > module_type > global) - Pattern length (Longer pattern > shorter pattern) + - Config index within a config level (Later config > earlier config) - Topological ordering in the model (Earlier in the graph > later in the graph) """ config_level_node_dicts = self._get_config_level_node_dicts( @@ -306,6 +350,13 @@ def _get_config_level_node_dicts( module_type_node_config_dict: NodeConfigDict = {level: {} for level in _OpConfigLevel} module_name_node_config_dict: NodeConfigDict = {level: {} for level in _OpConfigLevel} + # Precompute {canonical_key: insertion_index} once per config level so that + # _set_config_to_use_for_node can look up a key's priority in O(1) + config_key_index: dict[_ConfigLevel, dict[object, int]] = { + level: {key: idx for idx, key in enumerate(self._module_configs[level].keys())} + for level in (_ConfigLevel.MODULE_NAME, _ConfigLevel.MODULE_TYPE) + } + # Iterating through the nodes in topological ordering guarantees that when # sorting nodes later, any nodes with identical config priority and pattern # lengths will remain ordered by topological ordering (essentially the last @@ -314,13 +365,19 @@ def _get_config_level_node_dicts( if node in node_to_annotator_match_info_dict: # Try to find a config to set for the node for module_name config 
level if self._set_config_to_use_for_node( - node, module_name_node_config_dict, _ConfigLevel.MODULE_NAME + node, + module_name_node_config_dict, + _ConfigLevel.MODULE_NAME, + config_key_index[_ConfigLevel.MODULE_NAME], ): continue # Try to find a config to set for the node for module_type config level if self._set_config_to_use_for_node( - node, module_type_node_config_dict, _ConfigLevel.MODULE_TYPE + node, + module_type_node_config_dict, + _ConfigLevel.MODULE_TYPE, + config_key_index[_ConfigLevel.MODULE_TYPE], ): continue @@ -329,16 +386,28 @@ def _get_config_level_node_dicts( global_config = list(self._module_configs[_ConfigLevel.GLOBAL].values())[0] config, op_config_level = self._get_config_for_node(node, global_config) - global_node_config_dict[op_config_level][node] = config + # GLOBAL level has a single config, so priority is trivially 0. + global_node_config_dict[op_config_level][node] = _NodePriorityConfig( + config, priority=0 + ) return (module_name_node_config_dict, module_type_node_config_dict, global_node_config_dict) def _set_config_to_use_for_node( - self, node: torch.fx.Node, node_config_dict: NodeConfigDict, config_level: _ConfigLevel + self, + node: torch.fx.Node, + node_config_dict: NodeConfigDict, + config_level: _ConfigLevel, + config_key_index: dict[object, int], ) -> bool: """ Add a node to config entry for node_config_dict for the given config_level if applicable. Returns True if a config was set, False otherwise. + + The stored entry pairs the matched config with its position in + ``self._module_configs[config_level]``. That position is used as a + within-level priority during the sort phase: lower index = higher + precedence. 
""" qualified_name = _get_source_module_name(node) if qualified_name is None: @@ -352,8 +421,9 @@ def _set_config_to_use_for_node( return False config_to_use = self._module_configs[config_level][canonical] + config_idx = config_key_index[canonical] config_to_use, op_config_level = self._get_config_for_node(node, config_to_use) - node_config_dict[op_config_level][node] = config_to_use + node_config_dict[op_config_level][node] = _NodePriorityConfig(config_to_use, config_idx) return True @staticmethod @@ -391,7 +461,7 @@ def _get_config_for_node( def _expand_and_sort_nodes_for_pattern_length( self, - node_to_config_dict: dict[torch.fx.Node, OpQuantizerConfig], + node_to_config_dict: dict[torch.fx.Node, _NodePriorityConfig], node_to_annotator_match_info_dict: dict[torch.fx.Node, list[_AnnotatorMatchInfo]], ) -> list[tuple[torch.fx.Node, OpQuantizerConfig, _AnnotatorMatchInfo]]: """ @@ -444,19 +514,18 @@ def _expand_and_sort_nodes_for_pattern_length( A list of lists of (node, config, annotation match info) ordered by priority. """ - nodes_with_annotation_info: list[ - tuple[torch.fx.Node, OpQuantizerConfig, _AnnotatorMatchInfo] - ] = [ - (node, config, annotator_match_info) - for node, config in node_to_config_dict.items() - for annotator_match_info in node_to_annotator_match_info_dict[node] + nodes_with_annotation_info: list[_RankedAnnotation] = [ + _RankedAnnotation(node, entry.config, match, entry.priority) + for node, entry in node_to_config_dict.items() + for match in node_to_annotator_match_info_dict[node] ] - # Sort node - nodes_with_annotation_info = sorted( - nodes_with_annotation_info, key=lambda item: item[-1].pattern_length, reverse=True - ) + # Higher pattern_length wins; within equal length, lower priority wins + # (later-listed user configs claim modules first, so they get the + # smaller index in module_config_dict). Stable sort preserves + # topological order as the final tiebreaker. 
+ nodes_with_annotation_info.sort(key=lambda r: (-r.match.pattern_length, r.priority)) - return nodes_with_annotation_info + return [(r.node, r.config, r.match) for r in nodes_with_annotation_info] def validate(self, model: torch.fx.GraphModule) -> None: """ @@ -645,24 +714,24 @@ def _get_module_name_to_state_names_map( # inner_model_1.c = inner_model_2.a # Then we would have: - # module_name_to_state_names[inner_model_1]["inner_model_1.a"] = ["a", "b"] - # module_name_to_state_names[inner_model_1]["inner_model_1.b"] = ["a", "b"] - # module_name_to_state_names[inner_model_1]["inner_model_2.b"] = ["a", "b"] - # module_name_to_state_names[inner_model_1]["inner_model_1.c"] = ["c"] - # module_name_to_state_names[inner_model_1]["inner_model_2.a"] = ["c"] - - # module_name_to_state_names[inner_model_2]["inner_model_2.b"] = ["b"] - # module_name_to_state_names[inner_model_2]["inner_model_1.a"] = ["b"] - # module_name_to_state_names[inner_model_2]["inner_model_1.b"] = ["b"] - # module_name_to_state_names[inner_model_2]["inner_model_1.c"] = ["a"] - # module_name_to_state_names[inner_model_2]["inner_model_2.a"] = ["a"] + # module_name_to_state_names["inner_model_1"]["inner_model_1.a"] = ["a", "b"] + # module_name_to_state_names["inner_model_1"]["inner_model_1.b"] = ["a", "b"] + # module_name_to_state_names["inner_model_1"]["inner_model_2.b"] = ["a", "b"] + # module_name_to_state_names["inner_model_1"]["inner_model_1.c"] = ["c"] + # module_name_to_state_names["inner_model_1"]["inner_model_2.a"] = ["c"] + + # module_name_to_state_names["inner_model_2"]["inner_model_2.b"] = ["b"] + # module_name_to_state_names["inner_model_2"]["inner_model_1.a"] = ["b"] + # module_name_to_state_names["inner_model_2"]["inner_model_1.b"] = ["b"] + # module_name_to_state_names["inner_model_2"]["inner_model_1.c"] = ["a"] + # module_name_to_state_names["inner_model_2"]["inner_model_2.a"] = ["a"] # Observe that since "inner_model_1.a", "inner_model_1.b", and "inner_model_2.a" # all refer to the same 
parameter object, both inner_model_1 and inner_model_2 # contain all 3 of these full names as keys. However, from the perspective of # inner_model_1, there are only two local names which would point to this param: # "a" and "b". Thus all 3 full names are associated with ["a", "b"] in - # module_name_to_state_names[inner_model_1]. + # module_name_to_state_names["inner_model_1"]. # From inner_model_2's perspective, the same parameter would be referenced by # local name "b" only, so all 3 full names map to ["b"]. diff --git a/src/coreai_opt/quantization/_utils.py b/src/coreai_opt/quantization/_utils.py index b50cc7f..4d72b6b 100644 --- a/src/coreai_opt/quantization/_utils.py +++ b/src/coreai_opt/quantization/_utils.py @@ -5,7 +5,6 @@ """Quantization utilities and helper functions.""" - import torch from torchao.quantization.quant_primitives import _get_reduction_params @@ -41,9 +40,7 @@ def get_quantization_shapes( """ original_shape = tensor.shape - blockwise_shape, reduction_dims = _get_reduction_params( - block_size, tensor.size() - ) + blockwise_shape, reduction_dims = _get_reduction_params(block_size, tensor.size()) reduced_shape = list(blockwise_shape) for i in reduction_dims: reduced_shape[i] = 1 diff --git a/src/coreai_opt/quantization/quantizer.py b/src/coreai_opt/quantization/quantizer.py index 1283247..a6de40f 100644 --- a/src/coreai_opt/quantization/quantizer.py +++ b/src/coreai_opt/quantization/quantizer.py @@ -10,7 +10,6 @@ from os import PathLike from typing import Any -import torch import torch.nn as nn from torch import fx from torchao.quantization.pt2e import ( @@ -35,6 +34,7 @@ QuantizerConfig, ) from coreai_opt.quantization.spec.fake_quantize import FakeQuantizeImplBase +from coreai_opt.quantization.spec.qparams_calculator import StatelessQParamsCalculatorBase class Quantizer(_BaseQuantizer): @@ -180,32 +180,6 @@ def _get_fake_quantize_modules(self) -> dict[str, list]: """Delegate to the underlying execution-mode quantizer.""" return 
self._quantizer._get_fake_quantize_modules() - @classmethod - def get_compressible_op_names( - cls, - model: nn.Module | torch.fx.GraphModule, - execution_mode: ExecutionMode, - ) -> set[str]: - """Return op names in *model* that this quantizer can target. - - Dispatches to the appropriate underlying quantizer based on - *execution_mode*. - - Args: - model (nn.Module): The model to get compressible op names for. - execution_mode (ExecutionMode): The execution mode. - - Returns: - set[str]: Op names that can be compressed via quantization. - """ - if execution_mode == ExecutionMode.GRAPH: - return _GraphQuantizer.get_compressible_op_names(model) - if execution_mode == ExecutionMode.EAGER: - return _EagerQuantizer.get_compressible_op_names(model) - - msg = f"Unknown execution_mode {execution_mode}. Expected 'graph' or 'eager'." - raise ValueError(msg) - def _resolve_schedule(self, module_name: str) -> QATSchedule | None: """Look up the QAT schedule for a module via the config hierarchy.""" for level in _ConfigLevel.priority_order(): @@ -432,6 +406,31 @@ def _validate_mmap_dir_constraints( model_to_check = model if model is not None else self._model _validate_mmap_backend_and_device(model_to_check, backend, mmap_dir) + def _validate_no_persistent_observer_calculators( + self, + model: nn.Module | fx.GraphModule | None, + backend: ExportBackend, + ) -> None: + """Reject CoreAI/CoreML export when any qparams calculator is a + ``StatelessQParamsCalculatorBase`` (e.g. dynamic quantization). + """ + if backend == ExportBackend._TORCH: + return + model_to_check = model if model is not None else self._model + stateless_fq_names = [ + name + for name, mod in model_to_check.named_modules() + if isinstance(mod, FakeQuantizeImplBase) + and isinstance(mod.qparams_calculator, StatelessQParamsCalculatorBase) + ] + if stateless_fq_names: + raise NotImplementedError( + f"backend={backend} does not yet support qparams calculators that " + f"recompute every forward (e.g. 
dynamic quantization). " + f"Affected FakeQuantize modules: {stateless_fq_names}. Use " + f"backend=ExportBackend._TORCH for torch-only inference." + ) + def finalize( self, model: nn.Module | fx.GraphModule | None = None, @@ -480,6 +479,7 @@ def finalize( finalize frees the original dense weights. """ self._validate_mmap_dir_constraints(model, backend, mmap_dir) + self._validate_no_persistent_observer_calculators(model, backend) return self._quantizer.finalize(model, backend, mmap_dir=mmap_dir) @contextmanager diff --git a/src/coreai_opt/quantization/spec/__init__.py b/src/coreai_opt/quantization/spec/__init__.py index 0af6b80..fb329f3 100644 --- a/src/coreai_opt/quantization/spec/__init__.py +++ b/src/coreai_opt/quantization/spec/__init__.py @@ -14,10 +14,13 @@ ) from .qformulation import QuantizationFormulation from .qparams_calculator import ( + DynamicQParamsCalculator, GlobalMinMaxQParamsCalculator, MovingAverageQParamsCalculator, QParamsCalculatorBase, RunningRangeMixin, + StatefulQParamsCalculatorBase, + StatelessQParamsCalculatorBase, StaticQParamsCalculator, ) from .qscheme import QuantizationScheme @@ -29,6 +32,7 @@ ) __all__ = [ + "DynamicQParamsCalculator", "GlobalMinMaxQParamsCalculator", "MinMaxRangeCalculator", "MovingAverageQParamsCalculator", @@ -43,6 +47,8 @@ "QuantizationSpec", "RangeCalculatorBase", "RunningRangeMixin", + "StatefulQParamsCalculatorBase", + "StatelessQParamsCalculatorBase", "StaticQParamsCalculator", "default_activation_quantization_spec", "default_weight_quantization_spec", diff --git a/src/coreai_opt/quantization/spec/errors.py b/src/coreai_opt/quantization/spec/errors.py index ebe99ea..2d537e3 100644 --- a/src/coreai_opt/quantization/spec/errors.py +++ b/src/coreai_opt/quantization/spec/errors.py @@ -3,5 +3,6 @@ # Use of this source code is governed by a BSD-3-Clause license that can # be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + class _BlockSizeMismatchError(ValueError): """Raised when a 
tensor dimension is not divisible by the block size.""" diff --git a/src/coreai_opt/quantization/spec/factory.py b/src/coreai_opt/quantization/spec/factory.py index 30f2ef4..9a449b0 100644 --- a/src/coreai_opt/quantization/spec/factory.py +++ b/src/coreai_opt/quantization/spec/factory.py @@ -13,6 +13,7 @@ from .fake_quantize import FakeQuantizeImplBase from .qparams_calculator import ( + DynamicQParamsCalculator, MovingAverageQParamsCalculator, QParamsCalculatorBase, StaticQParamsCalculator, @@ -83,6 +84,15 @@ def create_qparams_calculator( f"Expected WEIGHT, ACTIVATION, or LUT." ) + if ( + qparam_calculator_cls is DynamicQParamsCalculator + and quantization_target != CompressionTargetTensor.ACTIVATION + ): + raise ValueError( + f"DynamicQParamsCalculator is only supported for activation " + f"quantization, got quantization_target={quantization_target}." + ) + # Create range calculator first range_calculator = cls.create_range_calculator(spec) diff --git a/src/coreai_opt/quantization/spec/fake_quantize.py b/src/coreai_opt/quantization/spec/fake_quantize.py index ed07ced..aafd476 100644 --- a/src/coreai_opt/quantization/spec/fake_quantize.py +++ b/src/coreai_opt/quantization/spec/fake_quantize.py @@ -31,7 +31,7 @@ from .granularity import QuantizationGranularity from .qformulation import QuantizationFormulation -from .qparams_calculator import QParamsCalculatorBase +from .qparams_calculator import QParamsCalculatorBase, StatelessQParamsCalculatorBase from .qscheme import QuantizationScheme __all__ = ["FakeQuantizeImplBase"] @@ -98,6 +98,30 @@ def is_disabled(self) -> bool: """Return True if fake quantization has been disabled.""" return self._disabled.item() + def disable_observer(self) -> None: + """Disable the observer, unless the qparams calculator is stateless. + + Applies to **any** caller (direct, ``apply(disable_observer)``, + ``convert_pt2e``, QAT scheduling). 
Stateless calculators recompute per + forward and need ``observer_enabled=1`` permanently — ``forward`` uses + that flag to route between live recompute and the stateful + ``get_qparams()`` cache (which stateless doesn't have). + """ + if isinstance(self.qparams_calculator, StatelessQParamsCalculatorBase): + return + super().disable_observer() + + def enable_observer(self, enabled: bool = True) -> None: + """Inverse of ``disable_observer``: ignore ``enabled=False`` when the + qparams calculator is stateless. Covers callers that invoke + ``enable_observer(False)`` directly (e.g. the QAT scheduler at + ``quantizer.py:_maybe_apply_qat_schedule``); ``disable_observer()`` + itself routes through the override above. + """ + if not enabled and isinstance(self.qparams_calculator, StatelessQParamsCalculatorBase): + return + super().enable_observer(enabled) + def _warn_and_disable(self, error: _BlockSizeMismatchError) -> None: """Log a warning and permanently disable this module.""" logger.warning( diff --git a/src/coreai_opt/quantization/spec/qparams_calculator.py b/src/coreai_opt/quantization/spec/qparams_calculator.py index 48286fa..0e9faa7 100644 --- a/src/coreai_opt/quantization/spec/qparams_calculator.py +++ b/src/coreai_opt/quantization/spec/qparams_calculator.py @@ -29,14 +29,19 @@ class QParamsCalculatorBase(_ClassRegistryMixin, nn.Module): - """ - Base class for implementing logic to calculate quantization parameters - (scale, zero_point, minval) given min/max values. - """ + """Abstract base for qparams calculators — common configuration and helpers. - scale: torch.Tensor - zero_point: torch.Tensor | None - minval: torch.Tensor | None + Concrete subclasses inherit from either: + + - ``StatefulQParamsCalculatorBase`` — has scale/zp/minval buffers; used by + Static, MovingAverage, GlobalMinMax. + - ``StatelessQParamsCalculatorBase`` — no buffers, qparams recomputed per + forward; used by Dynamic. 
``FakeQuantizeImplBase`` and ``Quantizer`` + detect this subclass to keep the observer always on and to reject + export. + + Subclasses must implement ``forward(tensor) -> (scale, zero_point, minval)``. + """ def __init__( self, @@ -62,17 +67,7 @@ def __init__( self.range_calculator = range_calculator self.float_range = float_range - self.register_buffer("scale", torch.empty(0)) - - if dtype.is_floating_point: - self.register_buffer("zero_point", None) - self.register_buffer("minval", None) - else: - self.register_buffer("zero_point", torch.empty(0, dtype=torch.int32)) - self.register_buffer("minval", torch.empty(0)) - self._initialized = False - self._export_mode = False # This is added to address MLIR limitation where # tensor after q-dq op is not casted to incoming tensor dtype @@ -258,7 +253,7 @@ def _compute_scale_zero_point_minval( minval = torch.min(min_val, torch.zeros_like(min_val)) # For FP dtypes, neither zero_point nor minval is used (symmetric - # quantization with no offset). The buffers are registered as None. + # quantization with no offset). if self.dtype.is_floating_point: zero_point = None minval = None @@ -285,19 +280,50 @@ def compute_qparams( ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: """Given the observed min/max range, return ``(scale, zero_point, minval)``. - The default implementation directly computes qparams from the given - range via ``_compute_scale_zero_point_minval``. This is the correct behavior - for stateless calculators (e.g. ``StaticQParamsCalculator``). - - Stateful calculators override this via ``RunningRangeMixin`` to update - running-range buffers before computing qparams from the smoothed range. + Default implementation: pure function of the supplied range, no running state. + ``RunningRangeMixin`` overrides this to apply a running-range smoothing rule + before computing qparams. 
""" return self._compute_scale_zero_point_minval(tensor, min_val, max_val) + @abstractmethod def forward( self, tensor: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: - """Compute scale, zero point, and minval from the input tensor. + """Compute and return ``(scale, zero_point, minval)`` for ``tensor``.""" + + +class StatefulQParamsCalculatorBase(QParamsCalculatorBase): + """Stateful base: maintains scale/zero_point/minval as nn.Module buffers + across forwards. + + Buffer shapes are allocated on first forward and must remain stable + (``copy_`` requires shape compatibility) — use + ``StatelessQParamsCalculatorBase`` for variable-shape scales. + """ + + scale: torch.Tensor + zero_point: torch.Tensor | None + minval: torch.Tensor | None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.register_buffer("scale", torch.empty(0)) + + if self.dtype.is_floating_point: + self.register_buffer("zero_point", None) + self.register_buffer("minval", None) + else: + self.register_buffer("zero_point", torch.empty(0, dtype=torch.int32)) + self.register_buffer("minval", torch.empty(0)) + + self._export_mode = False + + def forward( + self, tensor: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + """Compute qparams from ``tensor``; cache to buffers; return. On the first forward pass, initializes internal buffers using the observed tensor shape and device. Delegates the actual qparams @@ -377,6 +403,64 @@ def extra_repr(self) -> str: ) +class StatelessQParamsCalculatorBase(QParamsCalculatorBase): + """Stateless base: no cached qparams; recomputed every forward. + + Used for dynamic quantization where activations vary per inference and the + scale shape may change across forwards (e.g. LLM token-wise with variable + sequence length). 
``self.scale``/``zero_point``/``minval`` are assigned in + forward as plain Python attributes (not buffers) for debugging visibility + only — they reflect the most recent forward and are not in ``state_dict``. + + - ``FakeQuantizeImplBase`` keeps ``observer_enabled = 1`` for this subclass + so the recompute path stays live. + - ``get_qparams`` is undefined; ``set_export_mode(True)`` raises; + ``float_range=[None, None]`` is required. + """ + + def __init__(self, **kwargs): + float_range = kwargs.get("float_range", (None, None)) + if float_range[0] is not None or float_range[1] is not None: + raise ValueError( + f"StatelessQParamsCalculatorBase requires float_range=[None, None]; " + f"got {float_range}. Bounded ranges contradict the per-forward " + f"recompute contract." + ) + super().__init__(**kwargs) + + def forward( + self, tensor: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + """Compute qparams from ``tensor`` and return; no buffer state.""" + self._resolve_axis(tensor.ndim) + + min_val, max_val = self._get_min_and_max_val(tensor) + + if not self._initialized: + self._compute_dtype_for_export = tensor.dtype + self._initialized = True + + scale, zero_point, minval = self.compute_qparams(tensor, min_val, max_val) + + # Plain-attribute (not buffer) assignment for debugging visibility. + # Shape-mobile across forwards; not in state_dict. + self.scale = scale.detach() + self.zero_point = zero_point.detach() if zero_point is not None else None + self.minval = minval.detach() if minval is not None else None + + return scale, zero_point, minval + + def set_export_mode(self, enabled: bool = True) -> None: + if enabled: + raise NotImplementedError( + "Stateless quantization (e.g. dynamic) does not support export mode; " + "qparams are input-dependent and cannot be frozen for export." 
+ ) + # ``enabled=False`` is a deliberate no-op: stateless calculators have no + # ``_export_mode`` attribute (defined only on StatefulQParamsCalculatorBase), + # and there is nothing to disable. + + @QParamsCalculatorBase.register("default") class _DefaultQParamsCalculator(QParamsCalculatorBase): """ @@ -399,11 +483,16 @@ def __init__(self, **kwargs): "based on quantization target (weight or activation)." ) + def forward( + self, tensor: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + raise RuntimeError("_DefaultQParamsCalculator.forward() should never be called") + @QParamsCalculatorBase.register("static") -class StaticQParamsCalculator(QParamsCalculatorBase): +class StaticQParamsCalculator(StatefulQParamsCalculatorBase): """ - Computes scale and zero point using min/max values from the current tensor. + Computes scale/zero-point/minval using min/max values from the current tensor. This QParamsCalculator directly uses the min/max range from each forward pass to compute quantization parameters. So in that sense, it does not maintain any "history" and @@ -421,6 +510,19 @@ class StaticQParamsCalculator(QParamsCalculatorBase): # which directly computes qparams from current min/max with no running state. +@QParamsCalculatorBase.register("dynamic") +class DynamicQParamsCalculator(StatelessQParamsCalculatorBase): + """ + Dynamically computes scale/zero-point/minval from the current tensor every forward. + + Typically used for activation quantization where activations vary per + inference and there is no calibration phase. Supports variable-shape scales + (e.g. LLM token-wise quantization with variable sequence length) since no + nn.Module buffers are allocated — see ``StatelessQParamsCalculatorBase`` + for the full stateless contract. 
+ """ + + # ``# type: ignore`` comments are used where the mixin accesses # attributes and methods provided by ``QParamsCalculatorBase`` / # ``nn.Module``, which mypy cannot resolve from the mixin class alone. @@ -436,7 +538,7 @@ class RunningRangeMixin: parameters but with different ways of updating the running statistics can override the ``update_running_range`` method. - Must appear before ``QParamsCalculatorBase`` in the MRO so that its + Must appear before ``StatefulQParamsCalculatorBase`` in the MRO so that its ``compute_qparams`` and ``_initialize_state`` take precedence over the base-class defaults. """ @@ -492,7 +594,7 @@ def extra_repr(self) -> str: @QParamsCalculatorBase.register("moving_average") -class MovingAverageQParamsCalculator(RunningRangeMixin, QParamsCalculatorBase): +class MovingAverageQParamsCalculator(RunningRangeMixin, StatefulQParamsCalculatorBase): """ Computes the scale and zero point using a moving average of the range. @@ -546,7 +648,7 @@ def update_running_range( @QParamsCalculatorBase.register("global_minmax") -class GlobalMinMaxQParamsCalculator(RunningRangeMixin, QParamsCalculatorBase): +class GlobalMinMaxQParamsCalculator(RunningRangeMixin, StatefulQParamsCalculatorBase): """Computes scale and zero point by tracking the running min/max. Maintains ``running_min`` and ``running_max`` buffers that are updated each diff --git a/src/coreai_opt/quantization/spec/range_calculator.py b/src/coreai_opt/quantization/spec/range_calculator.py index 488ec72..66a5adc 100644 --- a/src/coreai_opt/quantization/spec/range_calculator.py +++ b/src/coreai_opt/quantization/spec/range_calculator.py @@ -42,11 +42,9 @@ def _reshape_min_max(self, range_tensor: torch.Tensor, input_shape: torch.Size): # 2 to get [1, 5, 1, 1]. # In the end, each dimension in scale should have size equal to the number of # blocks for that dimension. 
- range_tensor_shape = \ - [input_shape[i] // block_size_list[i] for i in range(len(input_shape))] + range_tensor_shape = [input_shape[i] // block_size_list[i] for i in range(len(input_shape))] return range_tensor.reshape(range_tensor_shape) - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Compute range statistics on an input and return the min/max bounds. @@ -61,7 +59,6 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: max_tensor = self._reshape_min_max(max_tensor, x.shape) return min_tensor, max_tensor - @abstractmethod def _generate_min_max(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Compute the lower and upper bound of the range. @@ -79,12 +76,9 @@ class MinMaxRangeCalculator(RangeCalculatorBase): values of the tensor. """ - def _generate_min_max(self, tensor: torch.Tensor) -> \ - tuple[torch.Tensor, torch.Tensor]: + def _generate_min_max(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: block_size_list = self.granularity.get_block_size(tensor.shape) - shape_for_reduction, reduction_dims = _get_reduction_params( - block_size_list, tensor.size() - ) + shape_for_reduction, reduction_dims = _get_reduction_params(block_size_list, tensor.size()) # If tensor is already the shape required, no minmaxing is needed. 
if len(reduction_dims) == 0: diff --git a/src/coreai_opt/quantization/spec/spec.py b/src/coreai_opt/quantization/spec/spec.py index 4a19c6d..52a8372 100644 --- a/src/coreai_opt/quantization/spec/spec.py +++ b/src/coreai_opt/quantization/spec/spec.py @@ -18,7 +18,6 @@ model_validator, ) -from coreai_opt._utils.registry_utils import ClassRegistryMixin from coreai_opt._utils.torch_utils import ( get_n_bits_from_dtype, is_float4_dtype as _is_float4_dtype, @@ -243,6 +242,9 @@ class QuantizationSpec(CompressionSpec): most recent calibration sample only - "moving_average": Uses exponential moving average for stability - "global_minmax": Tracks running min/max across all calibration samples + - "dynamic": Computes scale/zero/minval point on each forward pass from the + current tensor — no calibration. Only valid for activation quantization + (rejected by the factory for weights/LUT). - Custom registered class string name - coreai_opt.quantization.qparams_calculator.QParamsCalculatorBase class type: StaticQParamsCalculator, @@ -421,51 +423,10 @@ def validate_dtype(cls, dtype: torch.dtype) -> torch.dtype: raise ValueError(error_msg) return dtype - @staticmethod - def _convert_with_registry(data: str | type, registry_class: type[ClassRegistryMixin]) -> type: - """ - Convert string or type to a registered class from the given registry. - - Args: - data: Either a string key or a class type - registry_class: The registry class to look up the key/type in - - Returns: - The registered class type - - Raises: - ValueError: If the key is not found in registry or type is not registered - TypeError: If data is neither string nor type - """ - if isinstance(data, str): - try: - return registry_class.get_class(data) - except KeyError as err: - available_keys = registry_class.list_registry_keys() - raise ValueError( - f"No class is registered with key: '{data}' " - f"in registry {registry_class.__name__}. 
" - f"Available keys: {sorted(available_keys)}" - ) from err - elif isinstance(data, type): - if data in registry_class.list_registry_values(): - return data - else: - available_classes = [cls.__name__ for cls in registry_class.list_registry_values()] - raise ValueError( - f"Class {data.__name__} is not registered in " - f"{registry_class.__name__}. " - f"Available classes: {sorted(available_classes)}" - ) - else: - raise TypeError( - f"Expected str or type for registry lookup, got {type(data).__name__}: {data}" - ) - @field_validator("range_calculator_cls", mode="before") @classmethod def convert_range_calculator(cls, data: Any) -> type[RangeCalculatorBase]: - return cls._convert_with_registry(data, RangeCalculatorBase) + return RangeCalculatorBase.resolve(data) @field_validator("float_range", mode="before") @classmethod @@ -499,12 +460,12 @@ def validate_float_range( @field_validator("qparam_calculator_cls", mode="before") @classmethod def convert_qparam_calculator(cls, data: Any) -> type[QParamsCalculatorBase]: - return cls._convert_with_registry(data, QParamsCalculatorBase) + return QParamsCalculatorBase.resolve(data) @field_validator("fake_quantize_cls", mode="before") @classmethod def convert_fake_quantize(cls, data: Any) -> type[FakeQuantizeImplBase]: - return cls._convert_with_registry(data, FakeQuantizeImplBase) + return FakeQuantizeImplBase.resolve(data) @model_validator(mode="before") @classmethod diff --git a/tests/conftest.py b/tests/conftest.py index f0b1ed8..e4f685d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,121 +7,26 @@ import random import tempfile -from dataclasses import dataclass -from typing import Any, Literal import numpy as np import pytest import torch -from coreai_opt import ExportBackend -from coreai_opt.palettization import ( - KMeansPalettizerConfig, - ModuleKMeansPalettizerConfig, -) -from coreai_opt.palettization.spec import ( - PalettizationSpec, - PerGroupedChannelGranularity, - PerTensorGranularity as 
PalettizationPerTensorGranularity, -) -from coreai_opt.palettization.spec.spec import _SUPPORTED_LUT_DTYPES -from coreai_opt.pruning import MagnitudePrunerConfig, ModuleMagnitudePrunerConfig, PruningSpec -from coreai_opt.pruning.spec import ChannelStructured, PruningScheme, Unstructured -from coreai_opt.quantization import ModuleQuantizerConfig, QuantizerConfig -from coreai_opt.quantization.spec import ( - PerBlockGranularity, - PerChannelGranularity, - PerTensorGranularity, - QuantizationScheme, - QuantizationSpec, -) -from coreai_opt.quantization.spec.fake_quantize import _DefaultFakeQuantizeImpl -from coreai_opt.quantization.spec.qparams_calculator import StaticQParamsCalculator -from coreai_opt.quantization.spec.range_calculator import MinMaxRangeCalculator -from tests.models.mnist import ( # noqa: F401 - custom_test_mnist_model, - mnist_data, - mnist_dataset, - mnist_example_input, - mnist_example_output, -) -from tests.models.resnet import ( # noqa: F401 - resnet18_model, - resnet50_model, - resnet_example_input, -) -from tests.models.simple import ( # noqa: F401 - gated_mlp_model, - gated_mlp_model_input, - shared_params_model, - shared_params_model_input, - simple_conv_linear_model, - simple_linear_model, - simple_linear_model_input, - simple_mha_model, - simple_mha_model_input, - simple_model_input, -) from tests.utils import test_artifact_path -_DEFAULT_SEED: int = 42 - - -# Quantization dtypes that CoreML export must reject. Weight dtypes include both -# torch dtype objects and string aliases. 
-COREML_WEIGHT_REJECT_DTYPES = [ - pytest.param(torch.float8_e4m3fn, id="fp8-torch-e4m3fn"), - pytest.param("float8_e4m3fn", id="fp8-str-e4m3fn"), - pytest.param(torch.float8_e5m2, id="fp8-torch-e5m2"), - pytest.param("float4_e2m1fn", id="fp4-str"), - pytest.param(torch.int2, id="int2-torch"), - pytest.param(torch.uint2, id="uint2-torch"), +pytest_plugins = [ + "tests.fixtures.quantization", + "tests.fixtures.palettization", + "tests.fixtures.fp8", + "tests.fixtures.fp4", + "tests.fixtures.compression", + "tests.fixtures.pruning", + "tests.models.mnist", + "tests.models.resnet", + "tests.models.simple", ] -COREML_ACT_REJECT_DTYPES = [ - pytest.param(torch.float8_e4m3fn, id="e4m3fn"), - pytest.param(torch.float8_e5m2, id="e5m2"), - pytest.param(torch.int4, id="int4"), - pytest.param(torch.uint4, id="uint4"), - pytest.param(torch.int2, id="int2"), - pytest.param(torch.uint2, id="uint2"), -] - - -def make_quant_config( - *, - weight_dtype: torch.dtype | str | None, - act_dtype: torch.dtype | str | None, - execution_mode: str, -) -> QuantizerConfig: - """Build a per-tensor symmetric QuantizerConfig for export tests. - - Args: - weight_dtype (torch.dtype | str | None): Weight dtype, or None to disable. - act_dtype (torch.dtype | str | None): Activation dtype, or None to disable. - execution_mode (str): Either "eager" or "graph". - - Returns: - QuantizerConfig: Config with the requested per-tensor symmetric specs. 
- """ - - def _spec(dtype: torch.dtype | str) -> QuantizationSpec: - return QuantizationSpec( - dtype=dtype, - qscheme=QuantizationScheme.SYMMETRIC, - granularity=PerTensorGranularity(), - ) - - weight_spec = _spec(weight_dtype) if weight_dtype is not None else None - act_spec = _spec(act_dtype) if act_dtype is not None else None - return QuantizerConfig( - global_config=ModuleQuantizerConfig( - op_state_spec={"weight": weight_spec} if weight_spec is not None else None, - op_input_spec={"*": act_spec}, - op_output_spec={"*": act_spec}, - ), - execution_mode=execution_mode, - ) +_DEFAULT_SEED: int = 42 @pytest.fixture(autouse=True) @@ -178,1054 +83,10 @@ def temp_dir(): @pytest.fixture(scope="function") -def mnist_pretrained_model(custom_test_mnist_model): # noqa: F811 +def mnist_pretrained_model(custom_test_mnist_model): """Load the committed 1-epoch MNIST checkpoint into a fresh model.""" model = custom_test_mnist_model model.load_state_dict( torch.load(test_artifact_path("mnist/mnist_pretrained_1epoch_09032025.pt")) ) return model - - -@dataclass -class ParametrizedQuantConfigs: - """Container for parametrized Eager and PT2E quantization configs. - - Used by the parametrized_quant_config test fixture to provide both config - types with identical quantization parameters. 
- - Attributes: - eager: QuantizerConfig with eager execution mode - pt2e: QuantizerConfig with pt2e execution mode - model_dtype: Model dtype (float16, float32, bfloat16, or None for no conversion) - - """ - - eager: QuantizerConfig - pt2e: QuantizerConfig - model_dtype: torch.dtype | None - - @classmethod - def from_quant_params( - cls, - weight_dtype: torch.dtype, - act_dtype: torch.dtype | None, - qscheme: QuantizationScheme, - w_granularity: PerTensorGranularity | PerChannelGranularity | PerBlockGranularity, - model_dtype: torch.dtype | None, - act_granularity: PerTensorGranularity | PerChannelGranularity | None = None, - ) -> "ParametrizedQuantConfigs": - """Create ParametrizedQuantConfigs from quantization parameters. - - Args: - weight_dtype: Weight quantization dtype - act_dtype: Activation quantization dtype (None to disable) - qscheme: Quantization scheme - w_granularity: Weight Quantization granularity - model_dtype: Model dtype - act_granularity: Activation Quantization granularity - - Returns: - ParametrizedQuantConfigs instance - - """ - activation_qspec = None - if act_dtype is not None: - activation_qspec = QuantizationSpec( - dtype=act_dtype, - qscheme=QuantizationScheme.SYMMETRIC, - granularity=act_granularity or PerTensorGranularity(), - fake_quantize_cls=_DefaultFakeQuantizeImpl, - qparam_calculator_cls=StaticQParamsCalculator, - range_calculator_cls=MinMaxRangeCalculator, - ) - - weight_qspec = QuantizationSpec( - dtype=weight_dtype, - qscheme=qscheme, - granularity=w_granularity, - fake_quantize_cls=_DefaultFakeQuantizeImpl, - qparam_calculator_cls=StaticQParamsCalculator, - range_calculator_cls=MinMaxRangeCalculator, - ) - - eager_config = QuantizerConfig( - global_config=ModuleQuantizerConfig( - op_state_spec={"weight": weight_qspec}, - op_input_spec={"*": activation_qspec}, - op_output_spec={"*": activation_qspec}, - ), - execution_mode="eager", - ) - - pt2e_config = QuantizerConfig( - global_config=ModuleQuantizerConfig( - 
op_state_spec={"weight": weight_qspec}, - op_input_spec={"*": activation_qspec}, - op_output_spec={"*": activation_qspec}, - ), - execution_mode="graph", - ) - - return cls( - eager=eager_config, - pt2e=pt2e_config, - model_dtype=model_dtype, - ) - - @property - def has_activation_quantization(self) -> bool: - """Check if activation quantization is enabled in this config. - - Returns: - True if activation quantization is enabled - - """ - # Eager and pt2e configs have identical quantization settings. - # could use self.pt2e here as well - return ( - self.eager.global_config.op_input_spec != {"*": None} - if self.eager.global_config - else False - ) - - def skip_if_unsupported( - self, - mode: Literal["eager", "graph"], - backend: ExportBackend, - unsupported_configs: dict[str, Any] | list[dict[str, Any]] | None = None, - reason: str = "", - ) -> None: - """Skip test if this config matches unsupported constraints. - - Args: - mode: Quantization mode to check - backend: Export backend to check - unsupported_configs: Dictionary or list of dictionaries of constraints that - make this config unsupported. 
Constraint keys: - - "backend": ExportBackend value to match - - "act_dtype": torch dtype for activation quantization (torch.int8, - torch.uint8, None for disabled) - - "weight_dtype": torch dtype for weight quantization - - "granularity_type": String name of granularity class - ("PerTensorGranularity", "PerChannelGranularity", - "PerBlockGranularity") - - "act_granularity_axis": int axis value on activation granularity - - Example: {"backend": ExportBackend.CoreML, "act_dtype": torch.int8} - Example: [{"granularity_type": "PerChannelGranularity"}, - {"granularity_type": "PerBlockGranularity"}] - - Raises: - pytest.skip: If config matches any unsupported constraints - - """ - if unsupported_configs is None: - return - - config = self.eager if mode == "eager" else self.pt2e - - # Normalize to list - configs_to_check = ( - unsupported_configs if isinstance(unsupported_configs, list) else [unsupported_configs] - ) - - # Check each unsupported config - for constraints in configs_to_check: - if "backend" in constraints and backend != constraints["backend"]: - continue - if self._matches_constraints(config, constraints): - pytest.skip( - reason or f"{mode.upper()} + {backend.value} does not support this config", - ) - - def xfail_if_unsupported( - self, - mode: Literal["eager", "graph"], - backend: ExportBackend, - unsupported_config: dict[str, Any] | list[dict[str, Any]] | None = None, - reason: str = "", - ) -> None: - """Mark test as expected failure if this config matches unsupported constraints. 
- - Args: - mode: Quantization mode to check - backend: Export backend to check - unsupported_config: Dictionary or list of dictionaries of constraints - reason: Reason for the expected failure - - """ - if unsupported_config is None: - return - - config = self.eager if mode == "eager" else self.pt2e - - # Normalize to list - configs_to_check = ( - unsupported_config if isinstance(unsupported_config, list) else [unsupported_config] - ) - - # Check each unsupported config - for constraints in configs_to_check: - if self._matches_constraints(config, constraints): - pytest.xfail( - reason or f"{mode.upper()} + {backend.value} does not support this config", - ) - - def _matches_constraints( - self, - config: QuantizerConfig, - constraints: dict[str, Any], - ) -> bool: - """Check if config matches all specified constraints. - - Args: - config: Config to check - constraints: Dictionary of constraints to match. Valid keys: - - backend: ExportBackend value (checked by caller, ignored here) - - act_dtype: torch dtype for activation quantization - - weight_dtype: torch dtype for weight quantization - - granularity_type: String name of granularity class - - model_dtype: torch dtype for model - - act_granularity_axis: int axis value on activation granularity - - Returns: - True if all constraints match - - Raises: - ValueError: If constraints contain unknown keys - - Note: - The 'backend' key is checked by the caller before this method is called, - so it's included in valid_keys but ignored in the constraint matching logic. 
- - """ - if not config.global_config: - return False - weight_qspec = config.global_config.op_state_spec.get("weight") - act_qspec = config.global_config.op_input_spec.get("*") - # Validate constraint keys to catch typos - valid_keys = { - "backend", - "act_dtype", - "weight_dtype", - "granularity_type", - "model_dtype", - "act_granularity_axis", - } - invalid_keys = set(constraints.keys()) - valid_keys - if invalid_keys: - msg = f"Unknown constraint keys: {invalid_keys}. Valid keys: {valid_keys}" - raise ValueError(msg) - - for key, value in constraints.items(): - if key == "act_dtype": - if act_qspec is None: - if value is not None: - return False - elif act_qspec.dtype != value: - return False - elif key == "weight_dtype": - if weight_qspec is None: - if value is not None: - return False - elif weight_qspec.dtype != value: - return False - elif key == "granularity_type": - if weight_qspec is None: - if value is not None: - return False - elif weight_qspec.granularity.__class__.__name__ != value: - return False - elif key == "model_dtype" and self.model_dtype != value: - return False - elif key == "act_granularity_axis": - if ( - act_qspec is None - or not hasattr(act_qspec.granularity, "axis") - or act_qspec.granularity.axis != value - ): - return False - - return True - - -@dataclass -class ParametrizedPalettConfigs: - """Container for parametrized palettization configs. - - Used by the parametrized_palett_config test fixture to provide KMeans - palettization configuration with parameterized settings. 
- - Attributes: - config: KMeansPalettizerConfig instance - n_bits: Number of palette bits - granularity: Palettization granularity - enable_per_channel_scale: Whether per-channel scaling is enabled - cluster_dim: Cluster dimension (1 for scalar, >1 for vector palettization) - lut_qspec: LUT quantization spec (None if LUT is not quantized) - - """ - - config: KMeansPalettizerConfig - n_bits: int - granularity: PalettizationPerTensorGranularity | PerGroupedChannelGranularity - enable_per_channel_scale: bool - cluster_dim: int = 1 - lut_qspec: QuantizationSpec | None = None - - @classmethod - def from_palett_params( - cls, - n_bits: int, - granularity: PalettizationPerTensorGranularity | PerGroupedChannelGranularity, - enable_per_channel_scale: bool, - cluster_dim: int = 1, - lut_qspec: QuantizationSpec | None = None, - ) -> "ParametrizedPalettConfigs": - """Create ParametrizedPalettConfigs from palettization parameters. - - Args: - n_bits: Number of palette bits - granularity: Palettization granularity - enable_per_channel_scale: Whether to enable per-channel scaling - cluster_dim: Cluster dimension (1 for scalar, >1 for vector) - lut_qspec: LUT quantization spec - - Returns: - ParametrizedPalettConfigs instance - - """ - palett_spec = PalettizationSpec( - n_bits=n_bits, - lut_qspec=lut_qspec, - granularity=granularity, - cluster_dim=cluster_dim, - enable_per_channel_scale=enable_per_channel_scale, - ) - - config = KMeansPalettizerConfig( - global_config=ModuleKMeansPalettizerConfig( - op_state_spec={ - "weight": palett_spec, - }, - enable_fast_kmeans_mode=cluster_dim == 1, - ), - ) - - return cls( - config=config, - n_bits=n_bits, - granularity=granularity, - enable_per_channel_scale=enable_per_channel_scale, - cluster_dim=cluster_dim, - lut_qspec=lut_qspec, - ) - - -@dataclass -class ParametrizedFP8Configs: - """Container for parametrized FP8 quantization configs. 
- - Used by the parametrized_fp8_config test fixture to provide FP8 quantization - configurations for both Eager and PT2E quantizers. - - Attributes: - eager: QuantizerConfig instance with FP8 quantization - pt2e: QuantizerConfig instance with FP8 quantization - fp8_dtype: FP8 dtype (float8_e4m3fn or float8_e5m2) - with_activation_quant: Whether activation quantization is enabled - - """ - - eager: QuantizerConfig - pt2e: QuantizerConfig - fp8_dtype: torch.dtype - with_activation_quant: bool - model_dtype: torch.dtype - - @classmethod - def from_fp8_params( - cls, - fp8_dtype: torch.dtype, - with_activation_quant: bool, - model_dtype: torch.dtype = torch.float32, - per_channel_activations: bool = False, - per_channel_activations_axis: int = 0, - ) -> "ParametrizedFP8Configs": - """Create ParametrizedFP8Configs from FP8 parameters. - - FP8 quantization requires symmetric scheme and per-tensor granularity. - - Args: - fp8_dtype: FP8 dtype (float8_e4m3fn or float8_e5m2) - with_activation_quant: Whether to enable activation quantization - model_dtype: Model dtype for the test (default: float32) - per_channel_activations: [default=False] Whether activations are to be - quantized per-channel. - per_channel_activations_axis: [default=0] If per_channel_activations is set, - this value specifies the axis for per-channel quantization. 
- - Returns: - ParametrizedFP8Configs instance - - """ - weight_qspec = QuantizationSpec( - dtype=fp8_dtype, - qscheme=QuantizationScheme.SYMMETRIC, - granularity=PerTensorGranularity(), - fake_quantize_cls="default", - qparam_calculator_cls="default", - range_calculator_cls="minmax", - ) - - activation_qspec = None - if with_activation_quant: - activation_qspec = QuantizationSpec( - dtype=fp8_dtype, - qscheme=QuantizationScheme.SYMMETRIC, - granularity=PerChannelGranularity(axis=per_channel_activations_axis) - if per_channel_activations - else PerTensorGranularity(), - fake_quantize_cls="default", - qparam_calculator_cls="moving_average", - range_calculator_cls="minmax", - ) - - eager_config = QuantizerConfig( - global_config=ModuleQuantizerConfig( - op_state_spec={"weight": weight_qspec}, - op_input_spec={"*": activation_qspec}, - op_output_spec={"*": activation_qspec}, - ), - execution_mode="eager", - ) - - pt2e_config = QuantizerConfig( - global_config=ModuleQuantizerConfig( - op_state_spec={"weight": weight_qspec}, - op_input_spec={"*": activation_qspec}, - op_output_spec={"*": activation_qspec}, - ), - execution_mode="graph", - ) - - return cls( - eager=eager_config, - pt2e=pt2e_config, - fp8_dtype=fp8_dtype, - with_activation_quant=with_activation_quant, - model_dtype=model_dtype, - ) - - -@dataclass -class ParametrizedFP4Configs: - """Container for parametrized FP4 quantization configs. - - Used by the parametrized_fp4_config test fixture to provide FP4 quantization - configurations for both Eager and PT2E quantizers. 
- - Attributes: - eager: QuantizerConfig instance with FP4 quantization - pt2e: QuantizerConfig instance with FP4 quantization - with_activation_quant: Whether activation quantization is enabled - model_dtype: Model dtype for the test (default: float32) - """ - - eager: QuantizerConfig - pt2e: QuantizerConfig - with_activation_quant: bool - model_dtype: torch.dtype - - @classmethod - def from_fp4_params( - cls, - with_activation_quant: bool, - model_dtype: torch.dtype = torch.float32, - weight_dtype: torch.dtype | str = "float4_e2m1fn", - per_block_weights: bool = False, - weight_block_size: int = 32, - activation_dtype: torch.dtype | str = "float4_e2m1fn", - per_block_activations: bool = False, - activation_block_size: int = 32, - ) -> "ParametrizedFP4Configs": - """Create ParametrizedFP4Configs from FP4 parameters. - - FP4 quantization requires symmetric scheme and per-block granularity with block_size=32. - - Args: - with_activation_quant: Whether to enable activation quantization. - model_dtype: Model dtype for the test (default: float32). - weight_dtype: Weight dtype for quantization (default: float4_e2m1fn). - per_block_weights: Whether weights are to be quantized per-block. - weight_block_size: Block size for weight quantization. - activation_dtype: Activation dtype for quantization (default: float4_e2m1fn). - per_block_activations: Whether activations are to be quantized per-block. - activation_block_size: Block size for activation quantization. 
- - Returns: - ParametrizedFP4Configs instance - - """ - weight_qspec = QuantizationSpec( - dtype=weight_dtype, - qscheme=QuantizationScheme.SYMMETRIC, - granularity=PerBlockGranularity(axis=1, block_size=weight_block_size) - if per_block_weights - else PerTensorGranularity(), - ) - - activation_qspec = None - if with_activation_quant: - activation_qspec = QuantizationSpec( - dtype=activation_dtype, - qscheme=QuantizationScheme.SYMMETRIC, - granularity=PerBlockGranularity(axis=1, block_size=activation_block_size) - if per_block_activations - else PerTensorGranularity(), - ) - - eager_config = QuantizerConfig( - global_config=ModuleQuantizerConfig( - op_state_spec={"weight": weight_qspec}, - op_input_spec={"*": activation_qspec}, - op_output_spec={"*": activation_qspec}, - ), - execution_mode="eager", - ) - - pt2e_config = QuantizerConfig( - global_config=ModuleQuantizerConfig( - op_state_spec={"weight": weight_qspec}, - op_input_spec={"*": activation_qspec}, - op_output_spec={"*": activation_qspec}, - ), - execution_mode="graph", - ) - - return cls( - eager=eager_config, - pt2e=pt2e_config, - with_activation_quant=with_activation_quant, - model_dtype=model_dtype, - ) - - -@pytest.fixture( - params=[ - (weight_dtype, act_dtype, qscheme, w_granularity, act_granularity) - for weight_dtype in [ - torch.int8, - torch.uint8, - torch.int4, - torch.uint4, - ] - for act_dtype in [torch.int8, torch.uint8, None] - for qscheme in list(QuantizationScheme) - for w_granularity in [ - PerTensorGranularity(), - PerChannelGranularity(axis=1), - PerBlockGranularity(axis=0, block_size=2), - ] - for act_granularity in [ - PerTensorGranularity(), - PerChannelGranularity(axis=0), - PerChannelGranularity(axis=-1), - ] - # Weight-only configs (act_dtype=None) produce identical results regardless of - # act_granularity. Only include 1 combination (with PerTensorGranularity) for - # weight-only to avoid running redundant identical tests across all - # act_granularity values. 
- if act_dtype is not None or isinstance(act_granularity, PerTensorGranularity) - ], - ids=lambda p: ( - f"wt:{str(p[0]).split('.')[-1]}--" - f"act:{str(p[1]).split('.')[-1] if p[1] else 'disabled'}--" - f"qs:{p[2].value}--" - f"wg:{p[3].__class__.__name__.replace('Granularity', '')}--" - f"ag:{p[4].__class__.__name__.replace('Granularity', '')}--" - f"axis:{p[4].axis}" - ), -) -def parametrized_quant_config_general( - request: pytest.FixtureRequest, -) -> ParametrizedQuantConfigs: - """Fixture for general quantization configs without model dtype conversion. - - Sets model_dtype=None to skip dtype conversion. - Generates 252 parameter combinations. - Weight-only configs use only PerTensorGranularity for act_granularity. - - Returns: - ParametrizedQuantConfigs with model_dtype=None - - """ - weight_dtype, act_dtype, qscheme, w_granularity, act_granularity = request.param - return ParametrizedQuantConfigs.from_quant_params( - weight_dtype, - act_dtype, - qscheme, - w_granularity, - None, - act_granularity, - ) - - -@pytest.fixture( - params=[ - (weight_dtype, act_dtype, qscheme, w_granularity, model_dtype, act_granularity) - for weight_dtype in [ - torch.int8, - torch.uint8, - torch.int4, - torch.uint4, - ] - for act_dtype in [torch.int8, torch.uint8, None] - for qscheme in list(QuantizationScheme) - for w_granularity in [ - PerTensorGranularity(), - PerChannelGranularity(axis=1), - PerBlockGranularity(axis=0, block_size=2), - ] - for model_dtype in [ - torch.float16, - torch.float32, - torch.bfloat16, - ] - for act_granularity in [ - PerTensorGranularity(), - PerChannelGranularity(axis=0), - PerChannelGranularity(axis=-1), - ] - # Weight-only configs (act_dtype=None) produce identical results regardless of - # act_granularity. Only include 1 combination (with PerTensorGranularity) for - # weight-only to avoid running redundant identical tests across all - # act_granularity values. 
- if act_dtype is not None or isinstance(act_granularity, PerTensorGranularity) - ], - ids=lambda p: ( - f"wt:{str(p[0]).split('.')[-1]}--" - f"act:{str(p[1]).split('.')[-1] if p[1] else 'disabled'}--" - f"qs:{p[2].value}--" - f"wg:{p[3].__class__.__name__.replace('Granularity', '')}--" - f"m_dtype:{str(p[4]).split('.')[-1]}--" - f"ag:{p[5].__class__.__name__.replace('Granularity', '')}--" - f"axis:{p[5].axis}" - ), -) -def parametrized_quant_config_mlir( - request: pytest.FixtureRequest, -) -> ParametrizedQuantConfigs: - """Fixture for MLIR backend quantization configs. - - MLIR backend supports multiple model dtypes. - Generates 756 parameter combinations. - Weight-only configs use only PerTensorGranularity for act_granularity. - - Returns: - ParametrizedQuantConfigs with model_dtype varying across - float16/float32/bfloat16 - - """ - weight_dtype, act_dtype, qscheme, w_granularity, model_dtype, act_granularity = request.param - return ParametrizedQuantConfigs.from_quant_params( - weight_dtype, - act_dtype, - qscheme, - w_granularity, - model_dtype, - act_granularity, - ) - - -@pytest.fixture( - params=[ - (qscheme, act_granularity) - for qscheme in list(QuantizationScheme) - for act_granularity in [ - PerTensorGranularity(), - PerChannelGranularity(axis=0), - PerChannelGranularity(axis=1), - PerChannelGranularity(axis=2), - PerChannelGranularity(axis=-1), - PerChannelGranularity(axis=-2), - PerChannelGranularity(axis=-3), - ] - ], - ids=lambda p: ( - f"qs:{p[0].value}--" - f"ag:{p[1].__class__.__name__.replace('Granularity', '')}--" - f"axis:{p[1].axis}" - ), -) -def parametrized_quant_config_perchannel_act_axis_coverage( - request: pytest.FixtureRequest, -) -> ParametrizedQuantConfigs: - """Fixture for per-channel activation quantization axis testing. - - Uses fixed values for weight dtype (int8), activation dtype (uint8), - weight granularity (PerTensor), and model dtype (None) to isolate - per-channel activation axis behavior. 
- Compatible with both CoreML and CoreAI backends. Intended for use with - GatedMLPModel which has uniform rank-3 activations supporting all - axes in [-3, 3). - - Generates 21 parameter combinations (3 qschemes x 7 act granularities). - - Returns: - ParametrizedQuantConfigs with varied activation granularity axes - - """ - qscheme, act_granularity = request.param - return ParametrizedQuantConfigs.from_quant_params( - torch.int8, - torch.uint8, - qscheme, - PerTensorGranularity(), - None, - act_granularity, - ) - - -@pytest.fixture( - params=[ - (n_bits, granularity, enable_per_channel_scale, cluster_dim, lut_qspec) - for n_bits in [1, 2, 4] - for granularity in [ - PalettizationPerTensorGranularity(), - PerGroupedChannelGranularity(axis=0, group_size=2), - PerGroupedChannelGranularity(axis=1, group_size=2), - ] - for enable_per_channel_scale in [True, False] - for cluster_dim in [1, 2] - for lut_qspec in [ - None, - *( - QuantizationSpec( - dtype=dtype, - qscheme=QuantizationScheme.SYMMETRIC, - ) - for dtype in sorted(_SUPPORTED_LUT_DTYPES, key=str) - ), - ] - # cluster_dim=2 (vector palettization) is slow; only test with n_bits=4 - if cluster_dim == 1 or n_bits == 4 - ], - ids=lambda p: ( - f"n_bits:{p[0]}-" - f"granularity:{p[1].__class__.__name__.replace('Granularity', '')}" - + ( - f"_axis{p[1].axis}_gs{p[1].group_size}" - if isinstance(p[1], PerGroupedChannelGranularity) - else "" - ) - + f"-pcs:{'enabled' if p[2] else 'disabled'}" - + (f"-cd:{p[3]}" if p[3] > 1 else "") - + (f"-lut:{p[4].dtype}" if p[4] is not None else "") - ), -) -def parametrized_palett_config( - request: pytest.FixtureRequest, -) -> ParametrizedPalettConfigs: - """Fixture for palettization configs. 
- - Generates parameter combinations across: - - 3 n_bits values: [1, 2, 4] - - 3 granularities: [PerTensor, PerGroupedChannel(axis=0), PerGroupedChannel(axis=1)] - - 2 enable_per_channel_scale values: [True, False] - - 2 cluster_dim values: [1, 2] - - N+1 lut_qspec values: [None, + one symmetric spec per dtype in _SUPPORTED_LUT_DTYPES] - - cluster_dim=2 (vector palettization) is only combined with n_bits=4 to reduce - test runtime. - - Returns: - ParametrizedPalettConfigs instance - - """ - n_bits, granularity, enable_per_channel_scale, cluster_dim, lut_qspec = request.param - return ParametrizedPalettConfigs.from_palett_params( - n_bits, - granularity, - enable_per_channel_scale, - cluster_dim, - lut_qspec, - ) - - -@pytest.fixture( - params=[ - pytest.param( - (torch.float8_e4m3fn, True, torch.float32, True, -1), - id="wt:float8_e4m3fn-act:float8_e4m3fn-qs:symmetric-wg:PerTensor-ag:PerChannel-axis:-1", - ), - pytest.param( - (torch.float8_e4m3fn, False, torch.float32, False, 0), - id="wt:float8_e4m3fn-act:disabled-qs:symmetric-wg:PerTensor", - ), - pytest.param( - (torch.float8_e4m3fn, False, torch.float16, False, 0), - id="wt:float8_e4m3fn-act:disabled-qs:symmetric-wg:PerTensor-m_dtype:float16", - ), - pytest.param( - (torch.float8_e4m3fn, True, torch.float32, False, 0), - id="wt:float8_e4m3fn-act:float8_e4m3fn-qs:symmetric-wg:PerTensor", - ), - pytest.param( - (torch.float8_e5m2, False, torch.float32, False, 0), - id="wt:float8_e5m2-act:disabled-qs:symmetric-wg:PerTensor", - ), - pytest.param( - (torch.float8_e5m2, False, torch.float16, False, 0), - id="wt:float8_e5m2-act:disabled-qs:symmetric-wg:PerTensor-m_dtype:float16", - ), - pytest.param( - (torch.float8_e5m2, True, torch.float32, False, 0), - id="wt:float8_e5m2-act:float8_e5m2-qs:symmetric-wg:PerTensor", - ), - ], -) -def parametrized_fp8_config( - request: pytest.FixtureRequest, -) -> ParametrizedFP8Configs: - """Fixture for FP8 quantization configs. 
- - Generates 7 parameter combinations: - - 2 FP8 dtypes: [float8_e4m3fn, float8_e5m2] - - 2 activation quantization modes: [False (weight-only), True (with activation)] - - Weight-only configs also include float16 model dtype to verify scale casting - - a per channel activation quantization configs with axis=-1 - - All combinations are marked as xfail pending COREAI updates and output verification. - - Returns: - ParametrizedFP8Configs instance - - """ - ( - fp8_dtype, - with_activation_quant, - model_dtype, - per_channel_activations, - per_channel_activations_axis, - ) = request.param - - return ParametrizedFP8Configs.from_fp8_params( - fp8_dtype, - with_activation_quant, - model_dtype, - per_channel_activations, - per_channel_activations_axis, - ) - - -@dataclass -class ParametrizedP4A8CompressionConfigs: - """Container for parametrized P4-A8 compression (palettization + quantization) configs. - - Attributes: - palett_config (KMeansPalettizerConfig): Palettization configuration. - quant_config (QuantizerConfig): Activation quantization configuration. - has_lut_quantization (bool): Whether LUT quantization is enabled. - - """ - - palett_config: KMeansPalettizerConfig - quant_config: QuantizerConfig - has_lut_quantization: bool - - @classmethod - def from_params( - cls, - lut_qspec: QuantizationSpec | None = None, - ) -> "ParametrizedP4A8CompressionConfigs": - """Create config pair for P4-A8 joint compression. - - Palettization: 4-bit, per-tensor granularity. - Activation quantization: int8 symmetric per-tensor (input + output). - Weight quantization: disabled (weights are palettized). - - Args: - lut_qspec (QuantizationSpec | None): LUT quantization spec. - None for unquantized LUT, or a QuantizationSpec for quantized LUT. - - Returns: - ParametrizedP4A8CompressionConfigs: Config pair. 
- - """ - palett_spec = PalettizationSpec( - n_bits=4, - lut_qspec=lut_qspec, - ) - palett_config = KMeansPalettizerConfig( - global_config=ModuleKMeansPalettizerConfig( - op_state_spec={"weight": palett_spec}, - ), - ) - - act_spec = QuantizationSpec( - dtype=torch.int8, - qscheme=QuantizationScheme.SYMMETRIC, - ) - quant_config = QuantizerConfig( - global_config=ModuleQuantizerConfig( - op_state_spec=None, - op_input_spec={"*": act_spec}, - op_output_spec={"*": act_spec}, - ), - ) - - return cls( - palett_config=palett_config, - quant_config=quant_config, - has_lut_quantization=lut_qspec is not None, - ) - - -@pytest.fixture( - params=[ - pytest.param( - QuantizationSpec( - dtype=torch.int8, - qscheme=QuantizationScheme.SYMMETRIC, - ), - id="P4-A8-int8lut", - ), - pytest.param(None, id="P4-A8-nolut"), - ], -) -def parametrized_p4a8_compression_config( - request: pytest.FixtureRequest, -) -> ParametrizedP4A8CompressionConfigs: - """Fixture for P4-A8 compression (palettization + activation quantization) configs. - - Generates 2 parameter combinations: - - P4-A8-int8lut: 4-bit palettization with int8 symmetric LUT quantization - - P4-A8-nolut: 4-bit palettization without LUT quantization - - Both use int8 symmetric per-tensor activation quantization. - - Returns: - ParametrizedP4A8CompressionConfigs: P4-A8 compression config pair. - - """ - return ParametrizedP4A8CompressionConfigs.from_params(lut_qspec=request.param) - - -@pytest.fixture( - params=[ - pytest.param( - ("float4_e2m1fn", False, None, torch.float16, True, 32, False, 32), - id="wt:float4_e2m1fn-act:disabled-wg:PerBlock-wbs:32", - # TODO: handle float4 export with torch >=2.8. - marks=pytest.mark.xfail( - reason="Requires fix to handle float4 export with torch >=2.8." - ), - ), - pytest.param( - ("float4_e2m1fn", True, "float8_e4m3fn", torch.float16, True, 32, False, 32), - id="wt:float4_e2m1fn-act:float8_e4m3fn-wg:PerBlock-wbs:32-ag:PerTensor", - # TODO: handle float4 export with torch >=2.8. 
- marks=pytest.mark.xfail( - reason="Requires fix to handle float4 export with torch >=2.8." - ), - ), - ], -) -def parametrized_fp4_config( - request: pytest.FixtureRequest, -) -> ParametrizedFP4Configs: - """ - Fixture for FP4 quantization configs. - - Testing following combinations for weight and activation quantization: - - Weight Quantization dtype: torch.float4_e2m1fn_x2 - - Activation Quantization dtype: {torch.float4_e2m1fn_x2, torch.float8_e4m3fn} - - Weight quantization torch.float4_e2m1fn_x2: MLIR export only supported with - per-block granularity and block_size=32 - - Activation quantization torch.float4_e2m1fn_x2: MLIR export not supported - - Returns: - ParametrizedFP4Configs instance - """ - ( - weight_dtype, - with_activation_quant, - activation_dtype, - model_dtype, - per_block_weights, - weight_block_size, - per_block_activations, - activation_block_size, - ) = request.param - - return ParametrizedFP4Configs.from_fp4_params( - with_activation_quant, - model_dtype, - weight_dtype, - per_block_weights, - weight_block_size, - activation_dtype, - per_block_activations, - activation_block_size, - ) - - -# ─── Pruning parametrized configs ──────────────────────────────────────────── - - -def _get_pruning_schemes() -> list: - return [Unstructured(), ChannelStructured(axis=0)] - - -@dataclass -class ParametrizedPruneConfigs: - """Container for parametrized pruning configs. - - Attributes: - config: MagnitudePrunerConfig instance. - target_sparsity: Target sparsity fraction. - pruning_scheme: PruningScheme instance (Unstructured or ChannelStructured). - backend: Export backend (CoreML or CoreAI). 
- """ - - config: MagnitudePrunerConfig - target_sparsity: float - pruning_scheme: PruningScheme | str - backend: ExportBackend - - @classmethod - def from_prune_params( - cls, - target_sparsity: float, - pruning_scheme: PruningScheme | str, - backend: ExportBackend, - ) -> "ParametrizedPruneConfigs": - spec = PruningSpec(target_sparsity=target_sparsity, pruning_scheme=pruning_scheme) - config = MagnitudePrunerConfig( - global_config=ModuleMagnitudePrunerConfig(op_state_spec={"weight": spec}) - ) - return cls( - config=config, - target_sparsity=target_sparsity, - pruning_scheme=pruning_scheme, - backend=backend, - ) - - -@pytest.fixture( - params=[ - (target_sparsity, pruning_scheme, backend) - for target_sparsity in [0.25, 0.5, 0.75] - for pruning_scheme in _get_pruning_schemes() - for backend in [ExportBackend.CoreML, ExportBackend.CoreAI] - ], - ids=lambda p: f"sparsity:{p[0]}-scheme:{p[1].__class__.__name__}-backend:{p[2].value}", -) -def parametrized_prune_config( - request: pytest.FixtureRequest, -) -> ParametrizedPruneConfigs: - """Fixture for pruning configs parametrized across sparsity, scheme, and backend.""" - target_sparsity, pruning_scheme, backend = request.param - return ParametrizedPruneConfigs.from_prune_params(target_sparsity, pruning_scheme, backend) diff --git a/tests/coreai_utils/test_sparse_utils.py b/tests/coreai_utils/test_sparse_utils.py new file mode 100644 index 0000000..fcdd30e --- /dev/null +++ b/tests/coreai_utils/test_sparse_utils.py @@ -0,0 +1,193 @@ +# Copyright 2026 Apple Inc. 
+# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""Tests for coreai_opt.coreai_utils._utils.sparse_utils.""" + +import numpy as np +import pytest + +from coreai_opt.coreai_utils._utils.sparse_utils import ( + SparseParams, + _compress_by_magnitude, + _compress_by_nm_sparsity, + _produce_sparse_param, +) + + +class TestProduceSparseParam: + def test_basic(self) -> None: + """Correct nonzero_data and uint8 mask for a mixed-zero array.""" + val = np.array([[1.0, 0.0], [0.0, 2.0]]) + result = _produce_sparse_param(val) + assert isinstance(result, SparseParams) + np.testing.assert_array_equal(result.nonzero_data, [1.0, 2.0]) + np.testing.assert_array_equal(result.mask, [[1, 0], [0, 1]]) + assert result.mask.dtype == np.uint8 + + def test_all_zeros(self) -> None: + """All-zero input yields empty nonzero_data and an all-zero mask.""" + val = np.zeros((3, 3)) + result = _produce_sparse_param(val) + assert len(result.nonzero_data) == 0 + np.testing.assert_array_equal(result.mask, np.zeros((3, 3), dtype=np.uint8)) + + def test_all_nonzero(self) -> None: + """No-zero input yields nonzero_data equal to the flattened input.""" + val = np.array([[1.0, 2.0], [3.0, 4.0]]) + result = _produce_sparse_param(val) + np.testing.assert_array_equal(result.nonzero_data, val.flatten()) + np.testing.assert_array_equal(result.mask, np.ones((2, 2), dtype=np.uint8)) + + def test_mask_shape_matches_input(self) -> None: + """Mask shape matches the original (possibly multi-dimensional) input.""" + val = np.arange(1.0, 25.0).reshape(2, 3, 4) + result = _produce_sparse_param(val) + assert result.mask.shape == val.shape + + def test_mask_dtype_is_uint8(self) -> None: + val = np.array([1.0, 0.0, 2.0]) + result = _produce_sparse_param(val) + assert result.mask.dtype == np.uint8 + + +class TestCompressByMagnitude: + def test_zero_sparsity_keeps_all(self) -> None: + """target_sparsity=0 
leaves every element intact.""" + val = np.array([[1.0, 2.0], [3.0, 4.0]]) + result = _compress_by_magnitude(val, target_sparsity=0.0) + assert result.mask.sum() == val.size + + def test_full_sparsity_zeros_all(self) -> None: + """target_sparsity=1 zeros every element.""" + val = np.array([[1.0, 2.0], [3.0, 4.0]]) + result = _compress_by_magnitude(val, target_sparsity=1.0) + assert result.mask.sum() == 0 + assert len(result.nonzero_data) == 0 + + def test_half_sparsity_nonzero_count(self) -> None: + """target_sparsity=0.5 yields half the elements nonzero.""" + val = np.array([1.0, 2.0, 3.0, 4.0]) + result = _compress_by_magnitude(val, target_sparsity=0.5) + assert result.mask.sum() == 2 + + def test_smallest_magnitude_zeroed(self) -> None: + """The n smallest-magnitude elements are zeroed, not the largest.""" + val = np.array([10.0, 1.0, 20.0, 2.0]) + result = _compress_by_magnitude(val, target_sparsity=0.5) + np.testing.assert_array_equal(result.nonzero_data, [10.0, 20.0]) + + def test_returns_sparse_params(self) -> None: + val = np.ones((4, 4)) + result = _compress_by_magnitude(val, target_sparsity=0.25) + assert isinstance(result, SparseParams) + + def test_block_sparsity_zeros_entire_blocks(self) -> None: + """Block sparsity assigns the same mask to all rows in a block. + + val has two blocks of rows (0-1 and 2-3). Block 0 has smaller L2 + norms per column, so it is zeroed at 50% sparsity. 
+ """ + val = np.array([[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0]], dtype=np.float32) + result = _compress_by_magnitude(val, target_sparsity=0.5, block_size=2, dim=0) + np.testing.assert_array_equal(result.mask, [[0, 0], [0, 0], [1, 1], [1, 1]]) + np.testing.assert_array_equal(result.nonzero_data, [3.0, 4.0, 3.0, 4.0]) + + def test_block_sparsity_dim1(self) -> None: + """Block sparsity along dim=1 returns a SparseParams (not None).""" + val = np.ones((4, 4), dtype=np.float32) + result = _compress_by_magnitude(val, target_sparsity=0.5, block_size=2, dim=1) + assert isinstance(result, SparseParams) + + def test_block_size_without_dim_raises(self) -> None: + val = np.ones((4, 4)) + with pytest.raises(ValueError, match="`dim` must be provided"): + _compress_by_magnitude(val, target_sparsity=0.5, block_size=2, dim=None) + + def test_block_size_larger_than_half_channel_returns_none(self) -> None: + """block_size > channel/2 is not applicable; function returns None.""" + # channel=4 along dim=0, block_size=3: 3 > 4/2=2 + val = np.ones((4, 4), dtype=np.float32) + result = _compress_by_magnitude(val, target_sparsity=0.5, block_size=3, dim=0) + assert result is None + + def test_invalid_dim_raises(self) -> None: + # Use a 3D array so dim=2 is a valid shape index; the ValueError is + # then raised inside _apply_block_sparsity which checks dim in [0, 1]. 
+ val = np.ones((4, 4, 4)) + with pytest.raises(ValueError, match="block sparsity pruning only supports dim"): + _compress_by_magnitude(val, target_sparsity=0.5, block_size=2, dim=2) + + def test_invalid_rank_raises(self) -> None: + val = np.ones((8,)) + with pytest.raises(ValueError, match="block sparsity only supports weights of rank"): + _compress_by_magnitude(val, target_sparsity=0.5, block_size=2, dim=0) + + +class TestCompressByNmSparsity: + def test_1_2_dim1_zeros_smallest_per_pair(self) -> None: + """1:2 pruning along dim=1 zeros the smaller of each consecutive pair.""" + val = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], dtype=np.float32) + result = _compress_by_nm_sparsity(val, n_m_ratio=(1, 2), dim=1) + np.testing.assert_array_equal(result.nonzero_data, [2.0, 4.0, 6.0, 8.0]) + np.testing.assert_array_equal(result.mask, [[0, 1, 0, 1], [0, 1, 0, 1]]) + + def test_1_2_dim0_zeros_smaller_row_per_block(self) -> None: + """1:2 pruning along dim=0 zeros the smaller row in each pair of rows.""" + val = np.array( + [ + [1.0, 3.0, 5.0, 7.0], + [2.0, 4.0, 6.0, 8.0], + [9.0, 11.0, 13.0, 15.0], + [10.0, 12.0, 14.0, 16.0], + ], + dtype=np.float32, + ) + result = _compress_by_nm_sparsity(val, n_m_ratio=(1, 2), dim=0) + np.testing.assert_array_equal( + result.mask, [[0, 0, 0, 0], [1, 1, 1, 1], [0, 0, 0, 0], [1, 1, 1, 1]] + ) + + def test_n_zero_keeps_all(self) -> None: + """n=0 means nothing is pruned; all elements are nonzero.""" + val = np.arange(1.0, 9.0, dtype=np.float32).reshape(2, 4) + result = _compress_by_nm_sparsity(val, n_m_ratio=(0, 2), dim=1) + assert result.mask.sum() == val.size + + def test_n_equals_m_zeros_all(self) -> None: + """n=m means all elements are pruned.""" + val = np.arange(1.0, 9.0, dtype=np.float32).reshape(2, 4) + result = _compress_by_nm_sparsity(val, n_m_ratio=(2, 2), dim=1) + assert result.mask.sum() == 0 + assert len(result.nonzero_data) == 0 + + def test_channel_not_divisible_by_m(self) -> None: + """Padding zeros consume 
pruning slots in the last group when channel % m != 0. + + With n=1, m=2, channel=5 the last group per row is [real_elem, 0(pad)]. + The padded zero has magnitude 0, so it takes the single pruning slot and + the last real element is kept instead of being zeroed. + """ + val = np.array([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]], dtype=np.float32) + result = _compress_by_nm_sparsity(val, n_m_ratio=(1, 2), dim=1) + # Groups [1,2] and [3,4] each zero their smaller element. + # Last group [5, 0(pad)]: padded zero takes the slot, so 5 and 10 survive. + np.testing.assert_array_equal(result.mask, [[0, 1, 0, 1, 1], [0, 1, 0, 1, 1]]) + + def test_m_larger_than_half_channel_returns_none(self) -> None: + """m > channel/2 is not applicable; function returns None.""" + # channel along dim=1 is 4, m=3: 3 > 4/2=2 + val = np.ones((4, 4), dtype=np.float32) + result = _compress_by_nm_sparsity(val, n_m_ratio=(1, 3), dim=1) + assert result is None + + def test_invalid_dim_raises(self) -> None: + val = np.ones((4, 4)) + with pytest.raises(ValueError, match="n:m pruning only supports dim"): + _compress_by_nm_sparsity(val, n_m_ratio=(1, 2), dim=2) + + def test_invalid_rank_raises(self) -> None: + val = np.ones((8,)) + with pytest.raises(ValueError, match="n:m pruning only supports weights of rank"): + _compress_by_nm_sparsity(val, n_m_ratio=(1, 2), dim=0) diff --git a/tests/coreai_utils/test_sparsify_weights.py b/tests/coreai_utils/test_sparsify_weights.py index f05e9b6..8b3d7bc 100644 --- a/tests/coreai_utils/test_sparsify_weights.py +++ b/tests/coreai_utils/test_sparsify_weights.py @@ -3,7 +3,10 @@ # Use of this source code is governed by a BSD-3-Clause license that can # be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause +import numpy as np import pytest +import torch +import torch.nn as nn from coreai_opt.coreai_utils import DType, sparsify_weights from tests.export.export_utils import MLIRConverter @@ -209,3 +212,112 @@ def 
test_mlir_weight_sparsification_validation( coreai_program, _, _ = _coreai_program with pytest.raises(ValueError, match=error_match): sparsify_weights(coreai_program=coreai_program, **kwargs) + + +class _MatmulModel(nn.Module): + """Model whose fp16 weight buffer lowers to broadcasting_batch_matmul.""" + + def __init__(self) -> None: + super().__init__() + val = np.array( + [ + [1, 3, 4, -3], + [-6, -7, 2, 4], + [0, 3, 4, 1], + [-9, 2, -1, 8], + ], + dtype=np.float16, + ) + self.register_buffer("weight", torch.from_numpy(val)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.matmul(self.weight, x) + + +def _make_coreai_program() -> object: + x = torch.eye(4, dtype=torch.float16) + exported = MLIRConverter().trace(_MatmulModel(), x, {}) + return MLIRConverter._lower_to_coreai(exported) + + +def test_sparsify_weights_n_m_ratio_e2e() -> None: + """Compressed model output matches expected 1:2 pruned values. + + The weight is a fp16 (4, 4) matrix consumed by broadcasting_batch_matmul, + so the pass prunes along input_channel_axis (dim=1). With n:m=(1,2), the + smaller of each consecutive pair of columns is zeroed. Running with an + identity input isolates the weight in the output. + """ + compressed = sparsify_weights( + coreai_program=_make_coreai_program(), + target_sparsity=None, + n_m_ratio=(1, 2), + weight_num_threshold=0, + in_place=False, + ) + assert "coreai.build_sparse_with_bitmask" in str(compressed) + (output,) = MLIRConverter()._run_inference(compressed, torch.eye(4, dtype=torch.float16)) + expected = np.array( + [ + [0, 3, 4, 0], + [0, -7, 0, 4], + [0, 3, 4, 0], + [-9, 0, 0, 8], + ], + dtype=np.float16, + ) + np.testing.assert_array_equal(output, expected) + + +def test_sparsify_weights_magnitude_e2e() -> None: + """Compressed model output matches expected 50%-sparsity pruned values. + + With target_sparsity=0.5, the lowest-magnitude elements are zeroed. + Running with an identity input isolates the weight in the output. 
+ """ + compressed = sparsify_weights( + coreai_program=_make_coreai_program(), + target_sparsity=0.5, + weight_num_threshold=0, + in_place=False, + ) + assert "coreai.build_sparse_with_bitmask" in str(compressed) + (output,) = MLIRConverter()._run_inference(compressed, torch.eye(4, dtype=torch.float16)) + expected = np.array( + [ + [0, 0, 4, 0], + [-6, -7, 0, 4], + [0, 0, 4, 0], + [-9, 0, 0, 8], + ], + dtype=np.float16, + ) + np.testing.assert_array_equal(output, expected) + + +def test_sparsify_weights_block_e2e() -> None: + """Compressed model output matches expected block-sparsity pruned values. + + The weight is a fp16 (4, 4) matrix. With block_size=2 and target_sparsity=0.5, + row pairs are treated as blocks (output_channel_axis=0) and pruned by L2 norm. + Running with an identity input isolates the weight in the output. + """ + compressed = sparsify_weights( + coreai_program=_make_coreai_program(), + target_sparsity=0.5, + block_size=2, + weight_num_threshold=0, + in_place=False, + ) + assert "coreai.build_sparse_with_bitmask" in str(compressed) + (output,) = MLIRConverter()._run_inference(compressed, torch.eye(4, dtype=torch.float16)) + expected = np.array( + [ + [1, 3, 0, 0], + [-6, -7, 0, 0], + [0, 0, 0, 1], + [-9, 0, 0, 8], + ], + dtype=np.float16, + ) + np.testing.assert_array_equal(output, expected) diff --git a/tests/export/export_utils.py b/tests/export/export_utils.py index 4460835..aab7e55 100644 --- a/tests/export/export_utils.py +++ b/tests/export/export_utils.py @@ -413,7 +413,12 @@ async def _run_inference_async( inputs={input_name: NDArray(input_data.cpu())}, ) - return tuple(torch.from_numpy(v.numpy()) for v in coreai_outputs.values()) + # TODO(rdar://180563027): replace this private-attribute DLPack workaround with + # coreai's public NDArray torch() conversion once that API is available. 
+ return tuple( + torch.from_dlpack(v._tensor.to_dlpack()) # noqa: SLF001 + for v in coreai_outputs.values() + ) def _get_op_counts( self, diff --git a/tests/export/test_eager_mil_export.py b/tests/export/test_eager_mil_export.py index b357324..10835f1 100644 --- a/tests/export/test_eager_mil_export.py +++ b/tests/export/test_eager_mil_export.py @@ -10,7 +10,7 @@ from coreai_opt import CoreMLExportError, ExportBackend from coreai_opt.quantization import Quantizer, QuantizerConfig -from tests.conftest import ( +from tests.fixtures.quantization import ( COREML_ACT_REJECT_DTYPES, COREML_WEIGHT_REJECT_DTYPES, ParametrizedQuantConfigs, diff --git a/tests/export/test_eager_mlir_export.py b/tests/export/test_eager_mlir_export.py index 9db26d3..6671bee 100644 --- a/tests/export/test_eager_mlir_export.py +++ b/tests/export/test_eager_mlir_export.py @@ -22,11 +22,9 @@ QuantizationFormulation, QuantizationScheme, ) -from tests.conftest import ( - ParametrizedFP4Configs, - ParametrizedFP8Configs, - ParametrizedQuantConfigs, -) +from tests.fixtures.fp4 import ParametrizedFP4Configs +from tests.fixtures.fp8 import ParametrizedFP8Configs +from tests.fixtures.quantization import ParametrizedQuantConfigs from . import export_utils @@ -105,11 +103,31 @@ def test_simple_model_export( """Test eager Core AI export with various quantization configurations.""" has_act_quant = parametrized_quant_config_mlir.has_activation_quantization + # 4-bit-weight and int8 weight+activation per-tensor bfloat16 configs abort the + # CoreAI interpreter (SIGABRT); xfail them without running so the native crash + # cannot abort the session. 
+ parametrized_quant_config_mlir.xfail_if_unsupported( + "eager", + ExportBackend.CoreAI, + unsupported_config=[ + {"model_dtype": torch.bfloat16, "weight_dtype": torch.int4}, + {"model_dtype": torch.bfloat16, "weight_dtype": torch.uint4}, + { + "model_dtype": torch.bfloat16, + "weight_dtype": torch.int8, + "act_dtype": torch.int8, + "granularity_type": "PerTensorGranularity", + }, + ], + reason="CoreAI interpreter aborts on this bfloat16 config.", + ) + if parametrized_quant_config_mlir.model_dtype == torch.bfloat16: request.applymarker( - # TODO: add coreai conv2d kernel for bfloat16. - # TODO: add coreai round kernel for bfloat16. - pytest.mark.xfail(reason="coreai interpreter has missing kernels for bfloat16.") + pytest.mark.xfail( + reason="bfloat16 CoreAI export not yet reliable (flaky SNR).", + strict=False, + ) ) _run_eager_mlir_export_test( diff --git a/tests/export/test_graph_mode_mlir_export.py b/tests/export/test_graph_mode_mlir_export.py index deeba1c..210fa93 100644 --- a/tests/export/test_graph_mode_mlir_export.py +++ b/tests/export/test_graph_mode_mlir_export.py @@ -24,12 +24,10 @@ QuantizationFormulation, QuantizationScheme, ) -from tests.conftest import ( - ParametrizedFP4Configs, - ParametrizedFP8Configs, - ParametrizedP4A8CompressionConfigs, - ParametrizedQuantConfigs, -) +from tests.fixtures.compression import ParametrizedP4A8CompressionConfigs +from tests.fixtures.fp4 import ParametrizedFP4Configs +from tests.fixtures.fp8 import ParametrizedFP8Configs +from tests.fixtures.quantization import ParametrizedQuantConfigs from . import export_utils @@ -132,9 +130,31 @@ def test_simple_model_export( """Test graph-mode Core AI export with various quantization configurations.""" has_act_quant = parametrized_quant_config_mlir.has_activation_quantization + # 4-bit-weight and int8 weight+activation per-tensor bfloat16 configs abort the + # CoreAI interpreter (SIGABRT); xfail them without running so the native crash + # cannot abort the session. 
+ parametrized_quant_config_mlir.xfail_if_unsupported( + "graph", + ExportBackend.CoreAI, + unsupported_config=[ + {"model_dtype": torch.bfloat16, "weight_dtype": torch.int4}, + {"model_dtype": torch.bfloat16, "weight_dtype": torch.uint4}, + { + "model_dtype": torch.bfloat16, + "weight_dtype": torch.int8, + "act_dtype": torch.int8, + "granularity_type": "PerTensorGranularity", + }, + ], + reason="CoreAI interpreter aborts on this bfloat16 config.", + ) + if parametrized_quant_config_mlir.model_dtype == torch.bfloat16: request.applymarker( - pytest.mark.xfail(reason="coreai interpreter has missing kernels for bfloat16: , ") + pytest.mark.xfail( + reason="bfloat16 CoreAI export not yet reliable (flaky SNR).", + strict=False, + ) ) _run_graph_mode_mlir_export_test( @@ -350,8 +370,8 @@ def test_mnist_p4a8_compression_export( input_data=mnist_example_input, config=parametrized_p4a8_compression_config, expected_ops={ - "lut_to_dense": 4, - "constexpr_blockwise_shift_scale": 4 if has_lut else 0, + "lut_to_dense": 6, + "constexpr_blockwise_shift_scale": 6 if has_lut else 0, "quantize": 12, "dequantize": 12, }, diff --git a/tests/export/test_kmeans_export.py b/tests/export/test_kmeans_export.py index f9a5af1..7af7651 100644 --- a/tests/export/test_kmeans_export.py +++ b/tests/export/test_kmeans_export.py @@ -8,7 +8,7 @@ from coreai_opt import ExportBackend from coreai_opt.palettization import KMeansPalettizer, KMeansPalettizerConfig -from tests.conftest import ParametrizedPalettConfigs +from tests.fixtures.palettization import ParametrizedPalettConfigs from . import export_utils @@ -181,8 +181,10 @@ def test_mnist_export( _skip_unsupported_mil_configs(backend, parametrized_palett_config) - # For axis = 1, group_size is not divisible for conv1 layer - expected_count = 3 if granularity.axis == 1 else 4 + # The MNIST model has 6 weight-bearing layers (conv1, conv2, conv_transpose1, + # conv_transpose2, dense1, dense2). 
For axis=1 with group_size=2, conv1's + # axis-1 (in_channels=1) is not divisible, so palettization is skipped there. + expected_count = 5 if granularity.axis == 1 else 6 _run_kmeans_export_test( model=custom_test_mnist_model, diff --git a/tests/export/test_pruning_export.py b/tests/export/test_pruning_export.py index a80308b..d63815d 100644 --- a/tests/export/test_pruning_export.py +++ b/tests/export/test_pruning_export.py @@ -12,7 +12,7 @@ from coreai_opt import ExportBackend from coreai_opt.pruning import MagnitudePruner, MagnitudePrunerConfig from coreai_opt.pruning.spec import ChannelStructured -from tests.conftest import ParametrizedPruneConfigs +from tests.fixtures.pruning import ParametrizedPruneConfigs from . import export_utils diff --git a/tests/export/test_pt2e_mil_export.py b/tests/export/test_pt2e_mil_export.py index a8b2834..a12b0e6 100644 --- a/tests/export/test_pt2e_mil_export.py +++ b/tests/export/test_pt2e_mil_export.py @@ -23,7 +23,7 @@ StaticQParamsCalculator, ) from coreai_opt.quantization.spec.range_calculator import MinMaxRangeCalculator -from tests.conftest import ( +from tests.fixtures.quantization import ( COREML_ACT_REJECT_DTYPES, COREML_WEIGHT_REJECT_DTYPES, make_quant_config, diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 0000000..30b83ae --- /dev/null +++ b/tests/fixtures/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/tests/fixtures/compression.py b/tests/fixtures/compression.py new file mode 100644 index 0000000..931375a --- /dev/null +++ b/tests/fixtures/compression.py @@ -0,0 +1,112 @@ +# Copyright 2026 Apple Inc. 
+# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""P4-A8 compression parametrization config and the fixture that provides it.""" + +from dataclasses import dataclass + +import pytest +import torch + +from coreai_opt.palettization import ( + KMeansPalettizerConfig, + ModuleKMeansPalettizerConfig, +) +from coreai_opt.palettization.spec import PalettizationSpec +from coreai_opt.quantization import ModuleQuantizerConfig, QuantizerConfig +from coreai_opt.quantization.spec import QuantizationScheme, QuantizationSpec + + +@dataclass +class ParametrizedP4A8CompressionConfigs: + """Container for parametrized P4-A8 compression (palettization + quantization) configs. + + Attributes: + palett_config (KMeansPalettizerConfig): Palettization configuration. + quant_config (QuantizerConfig): Activation quantization configuration. + has_lut_quantization (bool): Whether LUT quantization is enabled. + + """ + + palett_config: KMeansPalettizerConfig + quant_config: QuantizerConfig + has_lut_quantization: bool + + @classmethod + def from_params( + cls, + lut_qspec: QuantizationSpec | None = None, + ) -> "ParametrizedP4A8CompressionConfigs": + """Create config pair for P4-A8 joint compression. + + Palettization: 4-bit, per-tensor granularity. + Activation quantization: int8 symmetric per-tensor (input + output). + Weight quantization: disabled (weights are palettized). + + Args: + lut_qspec (QuantizationSpec | None): LUT quantization spec. + None for unquantized LUT, or a QuantizationSpec for quantized LUT. + + Returns: + ParametrizedP4A8CompressionConfigs: Config pair. 
+ + """ + palett_spec = PalettizationSpec( + n_bits=4, + lut_qspec=lut_qspec, + ) + palett_config = KMeansPalettizerConfig( + global_config=ModuleKMeansPalettizerConfig( + op_state_spec={"weight": palett_spec}, + ), + ) + + act_spec = QuantizationSpec( + dtype=torch.int8, + qscheme=QuantizationScheme.SYMMETRIC, + ) + quant_config = QuantizerConfig( + global_config=ModuleQuantizerConfig( + op_state_spec=None, + op_input_spec={"*": act_spec}, + op_output_spec={"*": act_spec}, + ), + ) + + return cls( + palett_config=palett_config, + quant_config=quant_config, + has_lut_quantization=lut_qspec is not None, + ) + + +@pytest.fixture( + params=[ + pytest.param( + QuantizationSpec( + dtype=torch.int8, + qscheme=QuantizationScheme.SYMMETRIC, + ), + id="P4-A8-int8lut", + ), + pytest.param(None, id="P4-A8-nolut"), + ], +) +def parametrized_p4a8_compression_config( + request: pytest.FixtureRequest, +) -> ParametrizedP4A8CompressionConfigs: + """Fixture for P4-A8 compression (palettization + activation quantization) configs. + + Generates 2 parameter combinations: + - P4-A8-int8lut: 4-bit palettization with int8 symmetric LUT quantization + - P4-A8-nolut: 4-bit palettization without LUT quantization + + Both use int8 symmetric per-tensor activation quantization. + + Returns: + ParametrizedP4A8CompressionConfigs: P4-A8 compression config pair. + + """ + return ParametrizedP4A8CompressionConfigs.from_params(lut_qspec=request.param) diff --git a/tests/fixtures/fp4.py b/tests/fixtures/fp4.py new file mode 100644 index 0000000..0b1628e --- /dev/null +++ b/tests/fixtures/fp4.py @@ -0,0 +1,163 @@ +# Copyright 2026 Apple Inc. 
+# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""FP4 quantization parametrization config and the fixture that provides it.""" + +from dataclasses import dataclass + +import pytest +import torch + +from coreai_opt.quantization import ModuleQuantizerConfig, QuantizerConfig +from coreai_opt.quantization.spec import ( + PerBlockGranularity, + PerTensorGranularity, + QuantizationScheme, + QuantizationSpec, +) + + +@dataclass +class ParametrizedFP4Configs: + """Container for parametrized FP4 quantization configs. + + Used by the parametrized_fp4_config test fixture to provide FP4 quantization + configurations for both Eager and PT2E quantizers. + + Attributes: + eager: QuantizerConfig instance with FP4 quantization + pt2e: QuantizerConfig instance with FP4 quantization + with_activation_quant: Whether activation quantization is enabled + model_dtype: Model dtype for the test (default: float32) + """ + + eager: QuantizerConfig + pt2e: QuantizerConfig + with_activation_quant: bool + model_dtype: torch.dtype + + @classmethod + def from_fp4_params( + cls, + with_activation_quant: bool, + model_dtype: torch.dtype = torch.float32, + weight_dtype: torch.dtype | str = "float4_e2m1fn", + per_block_weights: bool = False, + weight_block_size: int = 32, + activation_dtype: torch.dtype | str = "float4_e2m1fn", + per_block_activations: bool = False, + activation_block_size: int = 32, + ) -> "ParametrizedFP4Configs": + """Create ParametrizedFP4Configs from FP4 parameters. + + FP4 quantization requires symmetric scheme and per-block granularity with block_size=32. + + Args: + with_activation_quant: Whether to enable activation quantization. + model_dtype: Model dtype for the test (default: float32). + weight_dtype: Weight dtype for quantization (default: float4_e2m1fn). + per_block_weights: Whether weights are to be quantized per-block. 
+ weight_block_size: Block size for weight quantization. + activation_dtype: Activation dtype for quantization (default: float4_e2m1fn). + per_block_activations: Whether activations are to be quantized per-block. + activation_block_size: Block size for activation quantization. + + Returns: + ParametrizedFP4Configs instance + + """ + weight_qspec = QuantizationSpec( + dtype=weight_dtype, + qscheme=QuantizationScheme.SYMMETRIC, + granularity=PerBlockGranularity(axis=1, block_size=weight_block_size) + if per_block_weights + else PerTensorGranularity(), + ) + + activation_qspec = None + if with_activation_quant: + activation_qspec = QuantizationSpec( + dtype=activation_dtype, + qscheme=QuantizationScheme.SYMMETRIC, + granularity=PerBlockGranularity(axis=1, block_size=activation_block_size) + if per_block_activations + else PerTensorGranularity(), + ) + + eager_config = QuantizerConfig( + global_config=ModuleQuantizerConfig( + op_state_spec={"weight": weight_qspec}, + op_input_spec={"*": activation_qspec}, + op_output_spec={"*": activation_qspec}, + ), + execution_mode="eager", + ) + + pt2e_config = QuantizerConfig( + global_config=ModuleQuantizerConfig( + op_state_spec={"weight": weight_qspec}, + op_input_spec={"*": activation_qspec}, + op_output_spec={"*": activation_qspec}, + ), + execution_mode="graph", + ) + + return cls( + eager=eager_config, + pt2e=pt2e_config, + with_activation_quant=with_activation_quant, + model_dtype=model_dtype, + ) + + +@pytest.fixture( + params=[ + pytest.param( + ("float4_e2m1fn", False, None, torch.float16, True, 32, False, 32), + id="wt:float4_e2m1fn-act:disabled-wg:PerBlock-wbs:32", + ), + pytest.param( + ("float4_e2m1fn", True, "float8_e4m3fn", torch.float16, True, 32, False, 32), + id="wt:float4_e2m1fn-act:float8_e4m3fn-wg:PerBlock-wbs:32-ag:PerTensor", + ), + ], +) +def parametrized_fp4_config( + request: pytest.FixtureRequest, +) -> ParametrizedFP4Configs: + """ + Fixture for FP4 quantization configs. 
+ + Testing following combinations for weight and activation quantization: + - Weight Quantization dtype: torch.float4_e2m1fn_x2 + - Activation Quantization dtype: {torch.float4_e2m1fn_x2, torch.float8_e4m3fn} + - Weight quantization torch.float4_e2m1fn_x2: MLIR export only supported with + per-block granularity and block_size=32 + - Activation quantization torch.float4_e2m1fn_x2: MLIR export not supported + + Returns: + ParametrizedFP4Configs instance + """ + ( + weight_dtype, + with_activation_quant, + activation_dtype, + model_dtype, + per_block_weights, + weight_block_size, + per_block_activations, + activation_block_size, + ) = request.param + + return ParametrizedFP4Configs.from_fp4_params( + with_activation_quant, + model_dtype, + weight_dtype, + per_block_weights, + weight_block_size, + activation_dtype, + per_block_activations, + activation_block_size, + ) diff --git a/tests/fixtures/fp8.py b/tests/fixtures/fp8.py new file mode 100644 index 0000000..47937fb --- /dev/null +++ b/tests/fixtures/fp8.py @@ -0,0 +1,179 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""FP8 quantization parametrization config and the fixture that provides it.""" + +from dataclasses import dataclass + +import pytest +import torch + +from coreai_opt.quantization import ModuleQuantizerConfig, QuantizerConfig +from coreai_opt.quantization.spec import ( + PerChannelGranularity, + PerTensorGranularity, + QuantizationScheme, + QuantizationSpec, +) + + +@dataclass +class ParametrizedFP8Configs: + """Container for parametrized FP8 quantization configs. + + Used by the parametrized_fp8_config test fixture to provide FP8 quantization + configurations for both Eager and PT2E quantizers. 
+ + Attributes: + eager: QuantizerConfig instance with FP8 quantization + pt2e: QuantizerConfig instance with FP8 quantization + fp8_dtype: FP8 dtype (float8_e4m3fn or float8_e5m2) + with_activation_quant: Whether activation quantization is enabled + + """ + + eager: QuantizerConfig + pt2e: QuantizerConfig + fp8_dtype: torch.dtype + with_activation_quant: bool + model_dtype: torch.dtype + + @classmethod + def from_fp8_params( + cls, + fp8_dtype: torch.dtype, + with_activation_quant: bool, + model_dtype: torch.dtype = torch.float32, + per_channel_activations: bool = False, + per_channel_activations_axis: int = 0, + ) -> "ParametrizedFP8Configs": + """Create ParametrizedFP8Configs from FP8 parameters. + + FP8 quantization requires symmetric scheme and per-tensor granularity. + + Args: + fp8_dtype: FP8 dtype (float8_e4m3fn or float8_e5m2) + with_activation_quant: Whether to enable activation quantization + model_dtype: Model dtype for the test (default: float32) + per_channel_activations: [default=False] Whether activations are to be + quantized per-channel. + per_channel_activations_axis: [default=0] If per_channel_activations is set, + this value specifies the axis for per-channel quantization. 
+ + Returns: + ParametrizedFP8Configs instance + + """ + weight_qspec = QuantizationSpec( + dtype=fp8_dtype, + qscheme=QuantizationScheme.SYMMETRIC, + granularity=PerTensorGranularity(), + fake_quantize_cls="default", + qparam_calculator_cls="default", + range_calculator_cls="minmax", + ) + + activation_qspec = None + if with_activation_quant: + activation_qspec = QuantizationSpec( + dtype=fp8_dtype, + qscheme=QuantizationScheme.SYMMETRIC, + granularity=PerChannelGranularity(axis=per_channel_activations_axis) + if per_channel_activations + else PerTensorGranularity(), + fake_quantize_cls="default", + qparam_calculator_cls="moving_average", + range_calculator_cls="minmax", + ) + + eager_config = QuantizerConfig( + global_config=ModuleQuantizerConfig( + op_state_spec={"weight": weight_qspec}, + op_input_spec={"*": activation_qspec}, + op_output_spec={"*": activation_qspec}, + ), + execution_mode="eager", + ) + + pt2e_config = QuantizerConfig( + global_config=ModuleQuantizerConfig( + op_state_spec={"weight": weight_qspec}, + op_input_spec={"*": activation_qspec}, + op_output_spec={"*": activation_qspec}, + ), + execution_mode="graph", + ) + + return cls( + eager=eager_config, + pt2e=pt2e_config, + fp8_dtype=fp8_dtype, + with_activation_quant=with_activation_quant, + model_dtype=model_dtype, + ) + + +@pytest.fixture( + params=[ + pytest.param( + (torch.float8_e4m3fn, True, torch.float32, True, -1), + id="wt:float8_e4m3fn-act:float8_e4m3fn-qs:symmetric-wg:PerTensor-ag:PerChannel-axis:-1", + ), + pytest.param( + (torch.float8_e4m3fn, False, torch.float32, False, 0), + id="wt:float8_e4m3fn-act:disabled-qs:symmetric-wg:PerTensor", + ), + pytest.param( + (torch.float8_e4m3fn, False, torch.float16, False, 0), + id="wt:float8_e4m3fn-act:disabled-qs:symmetric-wg:PerTensor-m_dtype:float16", + ), + pytest.param( + (torch.float8_e4m3fn, True, torch.float32, False, 0), + id="wt:float8_e4m3fn-act:float8_e4m3fn-qs:symmetric-wg:PerTensor", + ), + pytest.param( + (torch.float8_e5m2, 
False, torch.float32, False, 0), + id="wt:float8_e5m2-act:disabled-qs:symmetric-wg:PerTensor", + ), + pytest.param( + (torch.float8_e5m2, False, torch.float16, False, 0), + id="wt:float8_e5m2-act:disabled-qs:symmetric-wg:PerTensor-m_dtype:float16", + ), + pytest.param( + (torch.float8_e5m2, True, torch.float32, False, 0), + id="wt:float8_e5m2-act:float8_e5m2-qs:symmetric-wg:PerTensor", + ), + ], +) +def parametrized_fp8_config( + request: pytest.FixtureRequest, +) -> ParametrizedFP8Configs: + """Fixture for FP8 quantization configs. + + Generates 7 parameter combinations: + - 2 FP8 dtypes: [float8_e4m3fn, float8_e5m2] + - 2 activation quantization modes: [False (weight-only), True (with activation)] + - Weight-only configs also include float16 model dtype to verify scale casting + - a per channel activation quantization configs with axis=-1 + + Returns: + ParametrizedFP8Configs instance + + """ + ( + fp8_dtype, + with_activation_quant, + model_dtype, + per_channel_activations, + per_channel_activations_axis, + ) = request.param + + return ParametrizedFP8Configs.from_fp8_params( + fp8_dtype, + with_activation_quant, + model_dtype, + per_channel_activations, + per_channel_activations_axis, + ) diff --git a/tests/fixtures/palettization.py b/tests/fixtures/palettization.py new file mode 100644 index 0000000..20e7194 --- /dev/null +++ b/tests/fixtures/palettization.py @@ -0,0 +1,161 @@ +# Copyright 2026 Apple Inc. 
+# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""Palettization parametrization config and the fixture that provides it.""" + +from dataclasses import dataclass + +import pytest + +from coreai_opt.palettization import ( + KMeansPalettizerConfig, + ModuleKMeansPalettizerConfig, +) +from coreai_opt.palettization.spec import ( + PalettizationSpec, + PerGroupedChannelGranularity, + PerTensorGranularity as PalettizationPerTensorGranularity, +) +from coreai_opt.palettization.spec.spec import _SUPPORTED_LUT_DTYPES +from coreai_opt.quantization.spec import QuantizationScheme, QuantizationSpec + + +@dataclass +class ParametrizedPalettConfigs: + """Container for parametrized palettization configs. + + Used by the parametrized_palett_config test fixture to provide KMeans + palettization configuration with parameterized settings. + + Attributes: + config: KMeansPalettizerConfig instance + n_bits: Number of palette bits + granularity: Palettization granularity + enable_per_channel_scale: Whether per-channel scaling is enabled + cluster_dim: Cluster dimension (1 for scalar, >1 for vector palettization) + lut_qspec: LUT quantization spec (None if LUT is not quantized) + + """ + + config: KMeansPalettizerConfig + n_bits: int + granularity: PalettizationPerTensorGranularity | PerGroupedChannelGranularity + enable_per_channel_scale: bool + cluster_dim: int = 1 + lut_qspec: QuantizationSpec | None = None + + @classmethod + def from_palett_params( + cls, + n_bits: int, + granularity: PalettizationPerTensorGranularity | PerGroupedChannelGranularity, + enable_per_channel_scale: bool, + cluster_dim: int = 1, + lut_qspec: QuantizationSpec | None = None, + ) -> "ParametrizedPalettConfigs": + """Create ParametrizedPalettConfigs from palettization parameters. 
+ + Args: + n_bits: Number of palette bits + granularity: Palettization granularity + enable_per_channel_scale: Whether to enable per-channel scaling + cluster_dim: Cluster dimension (1 for scalar, >1 for vector) + lut_qspec: LUT quantization spec + + Returns: + ParametrizedPalettConfigs instance + + """ + palett_spec = PalettizationSpec( + n_bits=n_bits, + lut_qspec=lut_qspec, + granularity=granularity, + cluster_dim=cluster_dim, + enable_per_channel_scale=enable_per_channel_scale, + ) + + config = KMeansPalettizerConfig( + global_config=ModuleKMeansPalettizerConfig( + op_state_spec={ + "weight": palett_spec, + }, + enable_fast_kmeans_mode=cluster_dim == 1, + ), + ) + + return cls( + config=config, + n_bits=n_bits, + granularity=granularity, + enable_per_channel_scale=enable_per_channel_scale, + cluster_dim=cluster_dim, + lut_qspec=lut_qspec, + ) + + +@pytest.fixture( + params=[ + (n_bits, granularity, enable_per_channel_scale, cluster_dim, lut_qspec) + for n_bits in [1, 2, 4] + for granularity in [ + PalettizationPerTensorGranularity(), + PerGroupedChannelGranularity(axis=0, group_size=2), + PerGroupedChannelGranularity(axis=1, group_size=2), + ] + for enable_per_channel_scale in [True, False] + for cluster_dim in [1, 2] + for lut_qspec in [ + None, + *( + QuantizationSpec( + dtype=dtype, + qscheme=QuantizationScheme.SYMMETRIC, + ) + for dtype in sorted(_SUPPORTED_LUT_DTYPES, key=str) + ), + ] + # cluster_dim=2 (vector palettization) is slow; only test with n_bits=4 + if cluster_dim == 1 or n_bits == 4 + ], + ids=lambda p: ( + f"n_bits:{p[0]}-" + f"granularity:{p[1].__class__.__name__.replace('Granularity', '')}" + + ( + f"_axis{p[1].axis}_gs{p[1].group_size}" + if isinstance(p[1], PerGroupedChannelGranularity) + else "" + ) + + f"-pcs:{'enabled' if p[2] else 'disabled'}" + + (f"-cd:{p[3]}" if p[3] > 1 else "") + + (f"-lut:{p[4].dtype}" if p[4] is not None else "") + ), +) +def parametrized_palett_config( + request: pytest.FixtureRequest, +) -> 
ParametrizedPalettConfigs: + """Fixture for palettization configs. + + Generates parameter combinations across: + - 3 n_bits values: [1, 2, 4] + - 3 granularities: [PerTensor, PerGroupedChannel(axis=0), PerGroupedChannel(axis=1)] + - 2 enable_per_channel_scale values: [True, False] + - 2 cluster_dim values: [1, 2] + - N+1 lut_qspec values: [None, + one symmetric spec per dtype in _SUPPORTED_LUT_DTYPES] + + cluster_dim=2 (vector palettization) is only combined with n_bits=4 to reduce + test runtime. + + Returns: + ParametrizedPalettConfigs instance + + """ + n_bits, granularity, enable_per_channel_scale, cluster_dim, lut_qspec = request.param + return ParametrizedPalettConfigs.from_palett_params( + n_bits, + granularity, + enable_per_channel_scale, + cluster_dim, + lut_qspec, + ) diff --git a/tests/fixtures/pruning.py b/tests/fixtures/pruning.py new file mode 100644 index 0000000..f4392dc --- /dev/null +++ b/tests/fixtures/pruning.py @@ -0,0 +1,66 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""Pruning parametrization config and the fixture that provides it.""" + +from dataclasses import dataclass + +import pytest + +from coreai_opt import ExportBackend +from coreai_opt.pruning import MagnitudePrunerConfig, ModuleMagnitudePrunerConfig, PruningSpec +from coreai_opt.pruning.spec import ChannelStructured, PruningScheme, Unstructured + + +@dataclass +class ParametrizedPruneConfigs: + """Container for parametrized pruning configs. + + Attributes: + config: MagnitudePrunerConfig instance. + target_sparsity: Target sparsity fraction. + pruning_scheme: PruningScheme instance (Unstructured or ChannelStructured). + backend: Export backend (CoreML or CoreAI). 
+ """ + + config: MagnitudePrunerConfig + target_sparsity: float + pruning_scheme: PruningScheme | str + backend: ExportBackend + + @classmethod + def from_prune_params( + cls, + target_sparsity: float, + pruning_scheme: PruningScheme | str, + backend: ExportBackend, + ) -> "ParametrizedPruneConfigs": + spec = PruningSpec(target_sparsity=target_sparsity, pruning_scheme=pruning_scheme) + config = MagnitudePrunerConfig( + global_config=ModuleMagnitudePrunerConfig(op_state_spec={"weight": spec}) + ) + return cls( + config=config, + target_sparsity=target_sparsity, + pruning_scheme=pruning_scheme, + backend=backend, + ) + + +@pytest.fixture( + params=[ + (target_sparsity, pruning_scheme, backend) + for target_sparsity in [0.25, 0.5, 0.75] + for pruning_scheme in [Unstructured(), ChannelStructured(axis=0)] + for backend in [ExportBackend.CoreML, ExportBackend.CoreAI] + ], + ids=lambda p: f"sparsity:{p[0]}-scheme:{p[1].__class__.__name__}-backend:{p[2].value}", +) +def parametrized_prune_config( + request: pytest.FixtureRequest, +) -> ParametrizedPruneConfigs: + """Fixture for pruning configs parametrized across sparsity, scheme, and backend.""" + target_sparsity, pruning_scheme, backend = request.param + return ParametrizedPruneConfigs.from_prune_params(target_sparsity, pruning_scheme, backend) diff --git a/tests/fixtures/quantization.py b/tests/fixtures/quantization.py new file mode 100644 index 0000000..82d17ae --- /dev/null +++ b/tests/fixtures/quantization.py @@ -0,0 +1,522 @@ +# Copyright 2026 Apple Inc. 
+# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""Quantization parametrization config and the fixtures that provide it.""" + +from dataclasses import dataclass +from typing import Any, Literal + +import pytest +import torch + +from coreai_opt import ExportBackend +from coreai_opt.quantization import ModuleQuantizerConfig, QuantizerConfig +from coreai_opt.quantization.spec import ( + PerBlockGranularity, + PerChannelGranularity, + PerTensorGranularity, + QuantizationScheme, + QuantizationSpec, +) +from coreai_opt.quantization.spec.fake_quantize import _DefaultFakeQuantizeImpl +from coreai_opt.quantization.spec.qparams_calculator import StaticQParamsCalculator +from coreai_opt.quantization.spec.range_calculator import MinMaxRangeCalculator + +# Quantization dtypes that CoreML export must reject. Weight dtypes include both +# torch dtype objects and string aliases. +COREML_WEIGHT_REJECT_DTYPES = [ + pytest.param(torch.float8_e4m3fn, id="fp8-torch-e4m3fn"), + pytest.param("float8_e4m3fn", id="fp8-str-e4m3fn"), + pytest.param(torch.float8_e5m2, id="fp8-torch-e5m2"), + pytest.param("float4_e2m1fn", id="fp4-str"), + pytest.param(torch.int2, id="int2-torch"), + pytest.param(torch.uint2, id="uint2-torch"), +] + +COREML_ACT_REJECT_DTYPES = [ + pytest.param(torch.float8_e4m3fn, id="e4m3fn"), + pytest.param(torch.float8_e5m2, id="e5m2"), + pytest.param(torch.int4, id="int4"), + pytest.param(torch.uint4, id="uint4"), + pytest.param(torch.int2, id="int2"), + pytest.param(torch.uint2, id="uint2"), +] + + +def make_quant_config( + *, + weight_dtype: torch.dtype | str | None, + act_dtype: torch.dtype | str | None, + execution_mode: str, +) -> QuantizerConfig: + """Build a per-tensor symmetric QuantizerConfig for export tests. + + Args: + weight_dtype (torch.dtype | str | None): Weight dtype, or None to disable. 
+ act_dtype (torch.dtype | str | None): Activation dtype, or None to disable. + execution_mode (str): Either "eager" or "graph". + + Returns: + QuantizerConfig: Config with the requested per-tensor symmetric specs. + """ + + def _spec(dtype: torch.dtype | str) -> QuantizationSpec: + return QuantizationSpec( + dtype=dtype, + qscheme=QuantizationScheme.SYMMETRIC, + granularity=PerTensorGranularity(), + ) + + weight_spec = _spec(weight_dtype) if weight_dtype is not None else None + act_spec = _spec(act_dtype) if act_dtype is not None else None + return QuantizerConfig( + global_config=ModuleQuantizerConfig( + op_state_spec={"weight": weight_spec} if weight_spec is not None else None, + op_input_spec={"*": act_spec}, + op_output_spec={"*": act_spec}, + ), + execution_mode=execution_mode, + ) + + +@dataclass +class ParametrizedQuantConfigs: + """Container for parametrized Eager and PT2E quantization configs. + + Used by the parametrized_quant_config test fixture to provide both config + types with identical quantization parameters. + + Attributes: + eager: QuantizerConfig with eager execution mode + pt2e: QuantizerConfig with pt2e execution mode + model_dtype: Model dtype (float16, float32, bfloat16, or None for no conversion) + + """ + + eager: QuantizerConfig + pt2e: QuantizerConfig + model_dtype: torch.dtype | None + + @classmethod + def from_quant_params( + cls, + weight_dtype: torch.dtype, + act_dtype: torch.dtype | None, + qscheme: QuantizationScheme, + w_granularity: PerTensorGranularity | PerChannelGranularity | PerBlockGranularity, + model_dtype: torch.dtype | None, + act_granularity: PerTensorGranularity | PerChannelGranularity | None = None, + ) -> "ParametrizedQuantConfigs": + """Create ParametrizedQuantConfigs from quantization parameters. 
+ + Args: + weight_dtype: Weight quantization dtype + act_dtype: Activation quantization dtype (None to disable) + qscheme: Quantization scheme + w_granularity: Weight Quantization granularity + model_dtype: Model dtype + act_granularity: Activation Quantization granularity + + Returns: + ParametrizedQuantConfigs instance + + """ + activation_qspec = None + if act_dtype is not None: + activation_qspec = QuantizationSpec( + dtype=act_dtype, + qscheme=QuantizationScheme.SYMMETRIC, + granularity=act_granularity or PerTensorGranularity(), + fake_quantize_cls=_DefaultFakeQuantizeImpl, + qparam_calculator_cls=StaticQParamsCalculator, + range_calculator_cls=MinMaxRangeCalculator, + ) + + weight_qspec = QuantizationSpec( + dtype=weight_dtype, + qscheme=qscheme, + granularity=w_granularity, + fake_quantize_cls=_DefaultFakeQuantizeImpl, + qparam_calculator_cls=StaticQParamsCalculator, + range_calculator_cls=MinMaxRangeCalculator, + ) + + eager_config = QuantizerConfig( + global_config=ModuleQuantizerConfig( + op_state_spec={"weight": weight_qspec}, + op_input_spec={"*": activation_qspec}, + op_output_spec={"*": activation_qspec}, + ), + execution_mode="eager", + ) + + pt2e_config = QuantizerConfig( + global_config=ModuleQuantizerConfig( + op_state_spec={"weight": weight_qspec}, + op_input_spec={"*": activation_qspec}, + op_output_spec={"*": activation_qspec}, + ), + execution_mode="graph", + ) + + return cls( + eager=eager_config, + pt2e=pt2e_config, + model_dtype=model_dtype, + ) + + @property + def has_activation_quantization(self) -> bool: + """Check if activation quantization is enabled in this config. + + Returns: + True if activation quantization is enabled + + """ + # Eager and pt2e configs have identical quantization settings. 
+ # could use self.pt2e here as well + return ( + self.eager.global_config.op_input_spec != {"*": None} + if self.eager.global_config + else False + ) + + def skip_if_unsupported( + self, + mode: Literal["eager", "graph"], + backend: ExportBackend, + unsupported_configs: dict[str, Any] | list[dict[str, Any]] | None = None, + reason: str = "", + ) -> None: + """Skip test if this config matches unsupported constraints. + + Args: + mode: Quantization mode to check + backend: Export backend to check + unsupported_configs: Dictionary or list of dictionaries of constraints that + make this config unsupported. Constraint keys: + - "backend": ExportBackend value to match + - "act_dtype": torch dtype for activation quantization (torch.int8, + torch.uint8, None for disabled) + - "weight_dtype": torch dtype for weight quantization + - "granularity_type": String name of granularity class + ("PerTensorGranularity", "PerChannelGranularity", + "PerBlockGranularity") + - "act_granularity_axis": int axis value on activation granularity + + Example: {"backend": ExportBackend.CoreML, "act_dtype": torch.int8} + Example: [{"granularity_type": "PerChannelGranularity"}, + {"granularity_type": "PerBlockGranularity"}] + + Raises: + pytest.skip: If config matches any unsupported constraints + + """ + if unsupported_configs is None: + return + + config = self.eager if mode == "eager" else self.pt2e + + # Normalize to list + configs_to_check = ( + unsupported_configs if isinstance(unsupported_configs, list) else [unsupported_configs] + ) + + # Check each unsupported config + for constraints in configs_to_check: + if "backend" in constraints and backend != constraints["backend"]: + continue + if self._matches_constraints(config, constraints): + pytest.skip( + reason or f"{mode.upper()} + {backend.value} does not support this config", + ) + + def xfail_if_unsupported( + self, + mode: Literal["eager", "graph"], + backend: ExportBackend, + unsupported_config: dict[str, Any] | list[dict[str, Any]] 
| None = None, + reason: str = "", + ) -> None: + """Mark test as expected failure if this config matches unsupported constraints. + + Args: + mode: Quantization mode to check + backend: Export backend to check + unsupported_config: Dictionary or list of dictionaries of constraints + reason: Reason for the expected failure + + """ + if unsupported_config is None: + return + + config = self.eager if mode == "eager" else self.pt2e + + # Normalize to list + configs_to_check = ( + unsupported_config if isinstance(unsupported_config, list) else [unsupported_config] + ) + + # Check each unsupported config + for constraints in configs_to_check: + if "backend" in constraints and backend != constraints["backend"]: + continue + if self._matches_constraints(config, constraints): + pytest.xfail( + reason or f"{mode.upper()} + {backend.value} does not support this config", + ) + + def _matches_constraints( + self, + config: QuantizerConfig, + constraints: dict[str, Any], + ) -> bool: + """Check if config matches all specified constraints. + + Args: + config: Config to check + constraints: Dictionary of constraints to match. Valid keys: + - backend: ExportBackend value (checked by caller, ignored here) + - act_dtype: torch dtype for activation quantization + - weight_dtype: torch dtype for weight quantization + - granularity_type: String name of granularity class + - model_dtype: torch dtype for model + - act_granularity_axis: int axis value on activation granularity + + Returns: + True if all constraints match + + Raises: + ValueError: If constraints contain unknown keys + + Note: + The 'backend' key is checked by the caller before this method is called, + so it's included in valid_keys but ignored in the constraint matching logic. 
+ + """ + if not config.global_config: + return False + weight_qspec = config.global_config.op_state_spec.get("weight") + act_qspec = config.global_config.op_input_spec.get("*") + # Validate constraint keys to catch typos + valid_keys = { + "backend", + "act_dtype", + "weight_dtype", + "granularity_type", + "model_dtype", + "act_granularity_axis", + } + invalid_keys = set(constraints.keys()) - valid_keys + if invalid_keys: + msg = f"Unknown constraint keys: {invalid_keys}. Valid keys: {valid_keys}" + raise ValueError(msg) + + for key, value in constraints.items(): + if key == "act_dtype": + if act_qspec is None: + if value is not None: + return False + elif act_qspec.dtype != value: + return False + elif key == "weight_dtype": + if weight_qspec is None: + if value is not None: + return False + elif weight_qspec.dtype != value: + return False + elif key == "granularity_type": + if weight_qspec is None: + if value is not None: + return False + elif weight_qspec.granularity.__class__.__name__ != value: + return False + elif key == "model_dtype" and self.model_dtype != value: + return False + elif key == "act_granularity_axis": + if ( + act_qspec is None + or not hasattr(act_qspec.granularity, "axis") + or act_qspec.granularity.axis != value + ): + return False + + return True + + +@pytest.fixture( + params=[ + (weight_dtype, act_dtype, qscheme, w_granularity, act_granularity) + for weight_dtype in [ + torch.int8, + torch.uint8, + torch.int4, + torch.uint4, + ] + for act_dtype in [torch.int8, torch.uint8, None] + for qscheme in list(QuantizationScheme) + for w_granularity in [ + PerTensorGranularity(), + PerChannelGranularity(axis=1), + PerBlockGranularity(axis=0, block_size=2), + ] + for act_granularity in [ + PerTensorGranularity(), + PerChannelGranularity(axis=0), + PerChannelGranularity(axis=-1), + ] + # Weight-only configs (act_dtype=None) produce identical results regardless of + # act_granularity. 
Only include 1 combination (with PerTensorGranularity) for + # weight-only to avoid running redundant identical tests across all + # act_granularity values. + if act_dtype is not None or isinstance(act_granularity, PerTensorGranularity) + ], + ids=lambda p: ( + f"wt:{str(p[0]).split('.')[-1]}--" + f"act:{str(p[1]).split('.')[-1] if p[1] else 'disabled'}--" + f"qs:{p[2].value}--" + f"wg:{p[3].__class__.__name__.replace('Granularity', '')}--" + f"ag:{p[4].__class__.__name__.replace('Granularity', '')}--" + f"axis:{p[4].axis}" + ), +) +def parametrized_quant_config_general( + request: pytest.FixtureRequest, +) -> ParametrizedQuantConfigs: + """Fixture for general quantization configs without model dtype conversion. + + Sets model_dtype=None to skip dtype conversion. + Generates 252 parameter combinations. + Weight-only configs use only PerTensorGranularity for act_granularity. + + Returns: + ParametrizedQuantConfigs with model_dtype=None + + """ + weight_dtype, act_dtype, qscheme, w_granularity, act_granularity = request.param + return ParametrizedQuantConfigs.from_quant_params( + weight_dtype, + act_dtype, + qscheme, + w_granularity, + None, + act_granularity, + ) + + +@pytest.fixture( + params=[ + (weight_dtype, act_dtype, qscheme, w_granularity, model_dtype, act_granularity) + for weight_dtype in [ + torch.int8, + torch.uint8, + torch.int4, + torch.uint4, + ] + for act_dtype in [torch.int8, torch.uint8, None] + for qscheme in list(QuantizationScheme) + for w_granularity in [ + PerTensorGranularity(), + PerChannelGranularity(axis=1), + PerBlockGranularity(axis=0, block_size=2), + ] + for model_dtype in [ + torch.float16, + torch.float32, + torch.bfloat16, + ] + for act_granularity in [ + PerTensorGranularity(), + PerChannelGranularity(axis=0), + PerChannelGranularity(axis=-1), + ] + # Weight-only configs (act_dtype=None) produce identical results regardless of + # act_granularity. 
Only include 1 combination (with PerTensorGranularity) for + # weight-only to avoid running redundant identical tests across all + # act_granularity values. + if act_dtype is not None or isinstance(act_granularity, PerTensorGranularity) + ], + ids=lambda p: ( + f"wt:{str(p[0]).split('.')[-1]}--" + f"act:{str(p[1]).split('.')[-1] if p[1] else 'disabled'}--" + f"qs:{p[2].value}--" + f"wg:{p[3].__class__.__name__.replace('Granularity', '')}--" + f"m_dtype:{str(p[4]).split('.')[-1]}--" + f"ag:{p[5].__class__.__name__.replace('Granularity', '')}--" + f"axis:{p[5].axis}" + ), +) +def parametrized_quant_config_mlir( + request: pytest.FixtureRequest, +) -> ParametrizedQuantConfigs: + """Fixture for MLIR backend quantization configs. + + MLIR backend supports multiple model dtypes. + Generates 756 parameter combinations. + Weight-only configs use only PerTensorGranularity for act_granularity. + + Returns: + ParametrizedQuantConfigs with model_dtype varying across + float16/float32/bfloat16 + + """ + weight_dtype, act_dtype, qscheme, w_granularity, model_dtype, act_granularity = request.param + return ParametrizedQuantConfigs.from_quant_params( + weight_dtype, + act_dtype, + qscheme, + w_granularity, + model_dtype, + act_granularity, + ) + + +@pytest.fixture( + params=[ + (qscheme, act_granularity) + for qscheme in list(QuantizationScheme) + for act_granularity in [ + PerTensorGranularity(), + PerChannelGranularity(axis=0), + PerChannelGranularity(axis=1), + PerChannelGranularity(axis=2), + PerChannelGranularity(axis=-1), + PerChannelGranularity(axis=-2), + PerChannelGranularity(axis=-3), + ] + ], + ids=lambda p: ( + f"qs:{p[0].value}--" + f"ag:{p[1].__class__.__name__.replace('Granularity', '')}--" + f"axis:{p[1].axis}" + ), +) +def parametrized_quant_config_perchannel_act_axis_coverage( + request: pytest.FixtureRequest, +) -> ParametrizedQuantConfigs: + """Fixture for per-channel activation quantization axis testing. 
+ + Uses fixed values for weight dtype (int8), activation dtype (uint8), + weight granularity (PerTensor), and model dtype (None) to isolate + per-channel activation axis behavior. + Compatible with both CoreML and CoreAI backends. Intended for use with + GatedMLPModel which has uniform rank-3 activations supporting all + axes in [-3, 3). + + Generates 21 parameter combinations (3 qschemes x 7 act granularities). + + Returns: + ParametrizedQuantConfigs with varied activation granularity axes + + """ + qscheme, act_granularity = request.param + return ParametrizedQuantConfigs.from_quant_params( + torch.int8, + torch.uint8, + qscheme, + PerTensorGranularity(), + None, + act_granularity, + ) diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000..30b83ae --- /dev/null +++ b/tests/models/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/tests/palettization/test_kmeans_fake_palettize.py b/tests/palettization/test_kmeans_fake_palettize.py index f35b081..0e986a1 100644 --- a/tests/palettization/test_kmeans_fake_palettize.py +++ b/tests/palettization/test_kmeans_fake_palettize.py @@ -498,10 +498,22 @@ class TestPerGroupedChannelAxisDefault: @pytest.mark.parametrize( "op_to_optimize", - [F.linear, F.conv1d, F.conv2d, F.conv3d, F.multi_head_attention_forward, None], + [ + F.linear, + F.conv1d, + F.conv2d, + F.conv3d, + F.conv_transpose1d, + F.conv_transpose2d, + F.conv_transpose3d, + F.multi_head_attention_forward, + ], ) def test_axis_resolved_from_op_when_unset(self, op_to_optimize): - """When axis is None, _KMeansFakePalettize resolves it from the op's mixin.""" + """ + When axis is None, _KMeansFakePalettize resolves it to the registry-declared + default for the op. 
+ """ palettizer = _KMeansFakePalettize( n_bits=2, lut_qspec=None, @@ -510,7 +522,11 @@ def test_axis_resolved_from_op_when_unset(self, op_to_optimize): enable_per_channel_scale=False, op_to_optimize=op_to_optimize, ) - assert palettizer.granularity.axis == palettizer.reshape_strategy.default_axis + + registry_entry = _KMeansPalettizerSupportedOpsRegistry.get_registry_entry_for_func( + op_to_optimize + ) + assert palettizer.granularity.axis == registry_entry.default_axis @pytest.mark.parametrize("axis", [0, 1]) def test_explicit_axis_preserved(self, axis): diff --git a/tests/palettization/test_kmeans_palettizer.py b/tests/palettization/test_kmeans_palettizer.py index 10133b9..58ff74c 100644 --- a/tests/palettization/test_kmeans_palettizer.py +++ b/tests/palettization/test_kmeans_palettizer.py @@ -617,18 +617,7 @@ def test_module_state_spec_disables_palettization( "spec_field", [ "module_state_spec", - pytest.param( - "op_state_spec", - # TODO: support op_state_spec on shared-weight aliases so config - # precedence (MODULE_NAME > MODULE_TYPE) wins over forward-pass order. - marks=pytest.mark.xfail( - reason=( - "op_state_spec on shared weights currently uses forward-pass " - "order instead of config precedence (MODULE_NAME > MODULE_TYPE)." - ), - strict=True, - ), - ), + "op_state_spec", ], ) def test_shared_weight_uses_priority( @@ -648,9 +637,8 @@ def test_shared_weight_uses_priority( properly built priority dict, insertion order [shared_linear, layer1, layer2] would win and pick n_bits=4. This case guards that regression. - - ``op_state_spec`` (xfail): currently uses forward-pass order — layer1's - op runs first and registers n_bits=4 for the shared tensor before - layer2's op gets a chance. + - ``op_state_spec``: layer2's spec with n_bits=2 gets picked as the + higher priority spec in accordance with config priority rules. 
""" type_kwargs = {"op_state_spec": None, "module_state_spec": None} type_kwargs[spec_field] = {"weight": PalettizationSpec(n_bits=4)} diff --git a/tests/palettization/test_kmeans_palettizer_mnist.py b/tests/palettization/test_kmeans_palettizer_mnist.py index 18c71a4..b9ba34f 100644 --- a/tests/palettization/test_kmeans_palettizer_mnist.py +++ b/tests/palettization/test_kmeans_palettizer_mnist.py @@ -17,6 +17,7 @@ PalettizationSpec, PerGroupedChannelGranularity, ) +from coreai_opt.palettization.spec.fake_palettize import _FakePalettizeImplBase image_size = 28 batch_size = 128 @@ -27,23 +28,34 @@ @pytest.mark.seed @pytest.mark.slow @pytest.mark.parametrize( - "spec", + "spec,expected_palettized_layers", [ - PalettizationSpec(n_bits=2), - PalettizationSpec(n_bits=4, cluster_dim=2), - PalettizationSpec( - n_bits=4, - cluster_dim=2, - granularity=PerGroupedChannelGranularity(axis=0, group_size=2), + # MNIST model has 6 weight-bearing layers (conv1, conv2, conv_transpose1, + # conv_transpose2, dense1, dense2). For axis=1 with group_size=2, conv1's + # axis-1 (in_channels=1) is not divisible, so palettization is skipped there. + (PalettizationSpec(n_bits=2), 6), + (PalettizationSpec(n_bits=4, cluster_dim=2), 6), + ( + PalettizationSpec( + n_bits=4, + cluster_dim=2, + granularity=PerGroupedChannelGranularity(axis=0, group_size=2), + ), + 6, ), - PalettizationSpec( - n_bits=4, - cluster_dim=2, - granularity=PerGroupedChannelGranularity(axis=1, group_size=2), + ( + PalettizationSpec( + n_bits=4, + cluster_dim=2, + granularity=PerGroupedChannelGranularity(axis=1, group_size=2), + ), + 5, ), ], ) -def test_weight_only_ptq_mnist(mnist_pretrained_model, mnist_dataset, spec): +def test_weight_only_ptq_mnist( + mnist_pretrained_model, mnist_dataset, spec, expected_palettized_layers +): """ Train a simple convnet on the MNIST dataset for different deployment targets and verify its accuracy. 
@@ -74,6 +86,11 @@ def test_weight_only_ptq_mnist(mnist_pretrained_model, mnist_dataset, spec): num_workers=1, ) + palettized_count = utils.count_weight_parametrizations(prepared_model, _FakePalettizeImplBase) + assert palettized_count == expected_palettized_layers, ( + f"Expected {expected_palettized_layers} palettized layers, got {palettized_count}" + ) + post_vanilla_kmeans_accuracy = utils.eval_model(prepared_model, test_loader) # Check that if there is any drop in accuracy, it is within 1% diff --git a/tests/quantization/test_annotation_pattern_registry.py b/tests/quantization/test_annotation_pattern_registry.py index 96142fa..5ce28d6 100644 --- a/tests/quantization/test_annotation_pattern_registry.py +++ b/tests/quantization/test_annotation_pattern_registry.py @@ -1544,8 +1544,8 @@ def forward(self, x): def test_overlapping_quantization_configs_precedence(self, weight_spec, activation_spec): """ Test when both output of preceding node and input of succeeding node - are both annotated. Verify only one fake_quant is inserted and output - config takes precedence. + are both annotated. Verify only one fake_quant is inserted and the + later config (linear2) takes precedence over the earlier config (linear1). """ class TwoLinearModel(nn.Module): @@ -1640,9 +1640,9 @@ def forward(self, x): # Get the actual fake quantize module to inspect its configuration fq_module = getattr(prepared_model, activation_fq_node.target) - # Check that output spec took precedence + # Check that later config (linear2) takes precedence actual_dtype = fq_module.dtype - assert actual_dtype == activation_spec_int8.dtype + assert actual_dtype == activation_spec_int4.dtype def test_nested_module_structure_quantization(self, weight_spec, activation_spec): """ @@ -2547,13 +2547,10 @@ def forward(self, inp): op_name_config={ "sub": OpQuantizerConfig( # Since both inputs to sub are weights, this input spec should - # not affect any inputs. 
Only the "linear_weight" state input - # should be quantized due to op_state_spec below. - op_input_spec={_ALL_TENSORS: default_activation_quantization_spec()}, + # not affect any inputs. We will still expect sub op's second input to be + # quantized due to the "linear_weight" setting below. + op_input_spec={_ALL_TENSORS: None}, op_output_spec=None, - op_state_spec={ - "linear_weight": default_activation_quantization_spec(), - }, ), }, ), @@ -2561,8 +2558,16 @@ def forward(self, inp): InnerModel: ModuleQuantizerConfig( op_input_spec=None, op_output_spec=None, - op_state_spec={ - "inner_linear_weight": default_activation_quantization_spec(), + op_state_spec=None, + module_state_spec={"linear_weight": default_weight_quantization_spec()}, + # Using "*" will allow for MyModel.inner_linear_weight to be quantized when used + # as add's first input. + op_name_config={ + "add": OpQuantizerConfig( + op_input_spec={1: default_activation_quantization_spec()}, + op_output_spec=None, + op_state_spec={_ALL_TENSORS: default_weight_quantization_spec()}, + ), }, ), }, @@ -2576,12 +2581,27 @@ def forward(self, inp): quantizer = Quantizer(model, config) prepared_model = quantizer.prepare(example_inputs) node_dict = {node.name: node for node in prepared_model.graph.nodes} + assert "activation_post_process" in node_dict["add"].all_input_nodes[0].name + weight_quantizer = getattr(prepared_model, node_dict["add"].all_input_nodes[0].name) + act_quantizer = getattr(prepared_model, node_dict["add"].all_input_nodes[1].name) + assert isinstance(weight_quantizer.qparams_calculator.granularity, PerChannelGranularity) + assert isinstance(act_quantizer.qparams_calculator.granularity, PerTensorGranularity) + + assert "activation_post_process" in node_dict["sub"].all_input_nodes[0].name assert "activation_post_process" in node_dict["sub"].all_input_nodes[1].name assert "activation_post_process" in node_dict["linear_1"].all_input_nodes[1].name + + # Sub first input should share the same quantizer as 
linear_1's weight input + assert ( + node_dict["sub"].all_input_nodes[0].name + == node_dict["linear_1"].all_input_nodes[1].name + ) + + # Check that there are no other quantizers in the model assert ( len([node_name for node_name in node_dict if "activation_post_process" in node_name]) - == 3 + == 4 ) @pytest.mark.parametrize( diff --git a/tests/quantization/test_eager_quant.py b/tests/quantization/test_eager_quant.py index cf456bd..3d57bb9 100644 --- a/tests/quantization/test_eager_quant.py +++ b/tests/quantization/test_eager_quant.py @@ -1982,7 +1982,7 @@ def forward(self, inp): execution_mode="eager", ), "Only integer indices or '*'", - id="op_input_check", + id="op_input_check-global_config", ), pytest.param( QuantizerConfig( @@ -1994,7 +1994,7 @@ def forward(self, inp): execution_mode="eager", ), "Only integer indices or '*'", - id="op_input_check", + id="op_input_check-module_type_configs", ), pytest.param( QuantizerConfig( @@ -2006,7 +2006,7 @@ def forward(self, inp): execution_mode="eager", ), "Only integer indices or '*'", - id="op_input_check", + id="op_input_check-module_type_configs-2", ), pytest.param( QuantizerConfig( @@ -2018,7 +2018,7 @@ def forward(self, inp): execution_mode="eager", ), "Only integer indices or '*'", - id="op_input_check", + id="op_input_check-module_name_configs", ), ], ) @@ -2247,13 +2247,6 @@ def forward(self, inp, inp2): }, execution_mode="eager", ), - # TODO: support distinct module-name configs for aliased reused modules in eager mode. - marks=pytest.mark.xfail( - reason=( - "Eager mode does not yet apply distinct module-name " - "configs to two attributes that alias the same module." 
- ) - ), ), ], ) @@ -2876,7 +2869,6 @@ def forward(self, inp, inp2, raise_error=False): assert _get_current_function_mode() is None -@pytest.mark.serial @pytest.mark.parametrize( "config, expected_quantizers", [ diff --git a/tests/quantization/test_factory.py b/tests/quantization/test_factory.py index a68076f..7e8bf8b 100644 --- a/tests/quantization/test_factory.py +++ b/tests/quantization/test_factory.py @@ -17,6 +17,7 @@ PerTensorGranularity, ) from coreai_opt.quantization.spec.qparams_calculator import ( + DynamicQParamsCalculator, GlobalMinMaxQParamsCalculator, MovingAverageQParamsCalculator, QParamsCalculatorBase, @@ -677,16 +678,24 @@ def test_resolution_in_fake_quantizer_weight(self): # The qparams_calculator should be StaticQParamsCalculator assert isinstance(fake_quantizer.qparams_calculator, StaticQParamsCalculator) - def test_resolution_in_fake_quantizer_activation(self): + @pytest.mark.parametrize( + "qparam_calculator_string,qparam_calculator_cls", + [ + ("default", MovingAverageQParamsCalculator), + ("dynamic", DynamicQParamsCalculator), + ], + ) + def test_resolution_in_fake_quantizer_activation( + self, qparam_calculator_string, qparam_calculator_cls + ): """Test marker resolution through full fake quantizer creation for activation""" - spec = QuantizationSpec(qparam_calculator_cls="default") + spec = QuantizationSpec(qparam_calculator_cls=qparam_calculator_string) fake_quantizer = QuantizationComponentFactory.create_fake_quantizer( spec, CompressionTargetTensor.ACTIVATION ) - # The qparams_calculator should be MovingAverageQParamsCalculator - assert isinstance(fake_quantizer.qparams_calculator, MovingAverageQParamsCalculator) + assert isinstance(fake_quantizer.qparams_calculator, qparam_calculator_cls) def test_default_class_not_callable(self): """Test that the marker class raises an error if forward() is called""" @@ -740,3 +749,26 @@ def test_global_minmax_string_resolves_to_class(self): assert isinstance(qparams_weight, 
GlobalMinMaxQParamsCalculator) assert isinstance(qparams_activation, GlobalMinMaxQParamsCalculator) + + def test_resolution_for_dynamic_qparams(self): + """Test that 'dynamic' string resolves to DynamicQParamsCalculator""" + spec = QuantizationSpec(qparam_calculator_cls="dynamic") + assert spec.qparam_calculator_cls == DynamicQParamsCalculator + + qparams_calc = QuantizationComponentFactory.create_qparams_calculator( + spec, CompressionTargetTensor.ACTIVATION + ) + assert isinstance(qparams_calc, DynamicQParamsCalculator) + + @pytest.mark.parametrize( + "target", + [CompressionTargetTensor.WEIGHT, CompressionTargetTensor.LUT], + ) + def test_dynamic_rejected_for_non_activation(self, target): + """Test that 'dynamic' raises ValueError when used for weight/LUT targets""" + spec = QuantizationSpec(qparam_calculator_cls="dynamic") + with pytest.raises( + ValueError, + match="DynamicQParamsCalculator is only supported for activation", + ): + QuantizationComponentFactory.create_qparams_calculator(spec, target) diff --git a/tests/quantization/test_graph_mode_quantizer.py b/tests/quantization/test_graph_mode_quantizer.py index cb41451..81d3d3a 100644 --- a/tests/quantization/test_graph_mode_quantizer.py +++ b/tests/quantization/test_graph_mode_quantizer.py @@ -515,7 +515,6 @@ def forward(self, x): output_false = prepared_model_false(example_input) assert output_false.shape == (1, 10) - @pytest.mark.xfail(reason="fails on torch 2.8.0, passes on torch 2.11.0") def test_prepare_with_symint_mul_partition_collision(self): """Verify prepare() handles a SourcePartition with multiple call_function nodes. @@ -526,6 +525,9 @@ def test_prepare_with_symint_mul_partition_collision(self): on SymInt inputs. """ + H, W, B, embed_dim = 4, 4, 1, 8 + num_iters = 2 # >=2 to force the synthesized muls to collide. 
+ class SymIntMulModel(nn.Module): def __init__(self, num_iters: int, embed_dim: int) -> None: super().__init__() @@ -536,19 +538,20 @@ def forward(self, x: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor for i in range(self.num_iters): h = spatial_shapes[i, 0].item() w = spatial_shapes[i, 1].item() - # Mark h, w as non-negative so torch.export's reshape helper can - # handle the unbacked SymInts in `view`. The runtime assertion - # h * w == x.size(1) is still synthesized as a SymInt mul node. + # The h * w == x.size(1) runtime assertion is synthesized as a + # SymInt mul node by ``insert_deferred_runtime_asserts``; one per + # iteration, all sharing one ``torch_fn`` tag — the partition + # collision this test guards against. The view itself uses static + # H, W instead of the SymInts to avoid the data-dependent reshape + # guards that fail to discharge on torch 2.8. torch._check(h >= 0) torch._check(w >= 0) - x_view = x.view(x.size(0), h, w, x.size(-1)) + torch._check(h * w == x.size(1)) + x_view = x.view(x.size(0), H, W, x.size(-1)) x_view = x_view.flatten(1, 2) x = self.linear(x_view) return x - H, W, B, embed_dim = 4, 4, 1, 8 - num_iters = 2 # >=2 to force the synthesized muls to collide. 
- model = SymIntMulModel(num_iters=num_iters, embed_dim=embed_dim).eval() example_inputs = ( torch.randn(B, H * W, embed_dim), diff --git a/tests/quantization/test_is_state_node.py b/tests/quantization/test_is_state_node.py index 62f77c3..674e28a 100644 --- a/tests/quantization/test_is_state_node.py +++ b/tests/quantization/test_is_state_node.py @@ -14,7 +14,7 @@ import pytest import torch -from coreai_opt.quantization._graph._annotation_utils import _is_state_node +from coreai_opt._utils.fx_utils import is_coreai_compressed_state_node as is_state_node from tests.test_utils.general import COREAI_AVAILABLE @@ -53,19 +53,19 @@ class TestIsStateNode: def test_get_attr_is_state(self): """get_attr nodes (direct parameter access) are state.""" node = _make_node("get_attr") - assert _is_state_node(node) is True + assert is_state_node(node) is True def test_placeholder_is_not_state(self): """Placeholder nodes (model inputs) are not state.""" node = _make_node("placeholder") - assert _is_state_node(node) is False + assert is_state_node(node) is False def test_lut_to_dense_is_state(self): """coreai.lut_to_dense call_function is state (palettized weights).""" indices = _make_node("get_attr") lut = _make_node("get_attr") node = _make_node("call_function", "coreai", "lut_to_dense", args=(indices, lut)) - assert _is_state_node(node) is True + assert is_state_node(node) is True def test_shift_scale_with_lut_input_is_state(self): """constexpr_blockwise_shift_scale fed by lut_to_dense is state.""" @@ -76,7 +76,7 @@ def test_shift_scale_with_lut_input_is_state(self): node = _make_node( "call_function", "coreai", "constexpr_blockwise_shift_scale", args=(lut_node, scale) ) - assert _is_state_node(node) is True + assert is_state_node(node) is True def test_shift_scale_is_state(self): """constexpr_blockwise_shift_scale is always state. 
This op is only @@ -87,7 +87,7 @@ def test_shift_scale_is_state(self): node = _make_node( "call_function", "coreai", "constexpr_blockwise_shift_scale", args=(data, scale) ) - assert _is_state_node(node) is True + assert is_state_node(node) is True def test_aten_op_with_all_state_inputs_is_not_state(self): """An aten call_function whose inputs are all get_attr is NOT state. @@ -97,7 +97,7 @@ def test_aten_op_with_all_state_inputs_is_not_state(self): weight = _make_node("get_attr") bias = _make_node("get_attr") node = _make_node("call_function", "aten", "add", args=(weight, bias)) - assert _is_state_node(node) is False + assert is_state_node(node) is False def _find_coreai_nodes(gm: torch.fx.GraphModule, op_name: str) -> list[torch.fx.Node]: @@ -168,7 +168,7 @@ def test_joint_compression_lut_to_dense_not_quantized( # Each lut_to_dense must be recognized as state for node in lut_nodes: - assert _is_state_node(node) is True, ( + assert is_state_node(node) is True, ( f"lut_to_dense node {node.name} not identified as state" ) @@ -255,7 +255,7 @@ def test_joint_compression_lut_quantized_shift_scale_not_quantized( # Both op types must be recognized as state for node in lut_nodes + shift_scale_nodes: - assert _is_state_node(node) is True, ( + assert is_state_node(node) is True, ( f"{node.target._opname} node {node.name} not identified as state" ) diff --git a/tests/quantization/test_qparams_calculator.py b/tests/quantization/test_qparams_calculator.py index 68e441b..2b5a962 100644 --- a/tests/quantization/test_qparams_calculator.py +++ b/tests/quantization/test_qparams_calculator.py @@ -16,8 +16,10 @@ ) from coreai_opt.quantization.spec.factory import QuantizationComponentFactory from coreai_opt.quantization.spec.qparams_calculator import ( + DynamicQParamsCalculator, GlobalMinMaxQParamsCalculator, MovingAverageQParamsCalculator, + StatelessQParamsCalculatorBase, StaticQParamsCalculator, ) from coreai_opt.quantization.spec.range_calculator import MinMaxRangeCalculator @@ 
-833,6 +835,167 @@ def test_scale_monotonically_nondecreasing(self): prev_scale = scale.clone() +class TestDynamicQParamsCalculator: + """Tests specific to DynamicQParamsCalculator runtime-recompute behavior.""" + + def _make_calculator(self, granularity=None, float_range=None): + if granularity is None: + granularity = PerTensorGranularity() + if float_range is None: + float_range = [None, None] + return DynamicQParamsCalculator( + dtype=torch.int8, + qscheme=QuantizationScheme.SYMMETRIC, + granularity=granularity, + target_dtype=torch.int8, + quant_min=-128, + quant_max=127, + range_calculator=MinMaxRangeCalculator(granularity), + float_range=float_range, + ) + + def test_inherits_stateless_base(self): + """DynamicQParamsCalculator must inherit from StatelessQParamsCalculatorBase""" + assert issubclass(DynamicQParamsCalculator, StatelessQParamsCalculatorBase) + + def test_set_export_mode_true_raises(self): + """Export mode is unsupported for stateless calculators — qparams are input-dependent.""" + calc = self._make_calculator() + with pytest.raises( + NotImplementedError, + match="Stateless quantization .* does not support export mode", + ): + calc.set_export_mode(enabled=True) + + @pytest.mark.parametrize( + "float_range", + [[-1.0, 1.0], [-1.0, None], [None, 1.0]], + ) + def test_rejects_non_none_float_range(self, float_range): + """Dynamic requires float_range=[None, None].""" + with pytest.raises(ValueError, match=r"requires float_range=\[None, None\]"): + self._make_calculator(float_range=float_range) + + def test_get_qparams_not_defined(self): + """get_qparams is structurally absent on stateless calculators — there + is no cached value to return. 
Attribute access raises AttributeError.""" + calc = self._make_calculator() + calc(torch.randn(4, 8)) # populate debug attributes + with pytest.raises(AttributeError): + calc.get_qparams() + + def test_qparams_recompute_per_forward(self): + """Different inputs must yield recomputed qparams matching each input's range + directly — locks both per-forward recompute (no caching) AND correct math on + each forward.""" + calc = self._make_calculator() + + x1 = torch.tensor([-1.0, 0.0, 1.0]) + scale1, zp1, minval1 = calc(x1) + torch.testing.assert_close(scale1, torch.tensor([1.0 / 127.5])) + assert zp1 == 0 + torch.testing.assert_close(minval1, torch.tensor([-1.0])) + + x2 = torch.tensor([-10.0, 0.0, 10.0]) + scale2, zp2, minval2 = calc(x2) + torch.testing.assert_close(scale2, torch.tensor([10.0 / 127.5])) + assert zp2 == 0 + torch.testing.assert_close(minval2, torch.tensor([-10.0])) + + @pytest.mark.parametrize( + "qscheme,dtype,x", + [ + pytest.param(QuantizationScheme.SYMMETRIC, torch.int8, [-1.0, 0.0, 3.0], id="int8_sym"), + pytest.param( + QuantizationScheme.ASYMMETRIC, torch.int8, [-1.0, 0.0, 7.0], id="int8_asym" + ), + pytest.param( + QuantizationScheme.SYMMETRIC_WITH_CLIPPING, + torch.int8, + [-3.0, 0.0, 3.0], + id="int8_sym_clip", + ), + pytest.param( + QuantizationScheme.SYMMETRIC, torch.uint8, [-3.0, 0.0, 1.0], id="uint8_sym" + ), + pytest.param( + QuantizationScheme.ASYMMETRIC, torch.uint8, [-1.0, 0.0, 5.0], id="uint8_asym" + ), + pytest.param( + QuantizationScheme.SYMMETRIC_WITH_CLIPPING, + torch.int4, + [-1.5, 0.0, 1.5], + id="int4_sym_clip", + ), + pytest.param( + QuantizationScheme.SYMMETRIC, + torch.float8_e4m3fn, + [-3.0, 0.0, 3.0], + id="fp8_e4m3", + ), + pytest.param( + QuantizationScheme.SYMMETRIC, torch.float8_e5m2, [-3.0, 0.0, 3.0], id="fp8_e5m2" + ), + pytest.param( + QuantizationScheme.SYMMETRIC, + torch.float4_e2m1fn_x2, + [-3.0, 0.0, 3.0], + id="fp4", + ), + ], + ) + def test_matches_static_for_single_input(self, qscheme, dtype, x): + """For a 
single forward, dynamic must match static across the + dtype/qscheme matrix — equivalence with static's known-good math is + the strongest correctness check for dynamic.""" + granularity = PerTensorGranularity() + spec = QuantizationSpec(dtype=dtype, qscheme=qscheme, granularity=granularity) + common_kwargs = dict( + dtype=dtype, + qscheme=qscheme, + granularity=granularity, + target_dtype=spec.target_dtype, + quant_min=spec.quant_min, + quant_max=spec.quant_max, + range_calculator=MinMaxRangeCalculator(granularity), + float_range=[None, None], + scale_dtype=spec.scale_dtype, + ) + static_calc = StaticQParamsCalculator(**common_kwargs) + dynamic_calc = DynamicQParamsCalculator(**common_kwargs) + x_tensor = torch.tensor(x, dtype=torch.float32) + + scale_s, zp_s, minval_s = static_calc(x_tensor) + scale_d, zp_d, minval_d = dynamic_calc(x_tensor) + + assert torch.equal(scale_s, scale_d) + # FP4/FP8 return None for zero_point and minval. + if zp_s is None: + assert zp_d is None + else: + assert torch.equal(zp_s, zp_d) + if minval_s is None: + assert minval_d is None + else: + assert torch.equal(minval_s, minval_d) + + def test_variable_shape_scale_supported(self): + """LLM use case: per-token (per-channel along seq axis) dynamic + quantization where seq_len varies across forwards. 
Scale shape must + change correspondingly + """ + granularity = PerChannelGranularity(axis=1) # per-token on (B, seq, H) + calc = self._make_calculator(granularity=granularity) + + x_short = torch.randn(2, 8, 16) # seq_len = 8 + scale_short, _, _ = calc(x_short) + assert scale_short.shape == (1, 8, 1) + + x_long = torch.randn(2, 32, 16) # seq_len = 32 + scale_long, _, _ = calc(x_long) + assert scale_long.shape == (1, 32, 1) + + @pytest.mark.parametrize( "precision_dtype", [ diff --git a/tests/quantization/test_quantization.py b/tests/quantization/test_quantization.py index d5ce8bf..8e3b917 100644 --- a/tests/quantization/test_quantization.py +++ b/tests/quantization/test_quantization.py @@ -8,25 +8,34 @@ Each test class covers a distinct scenario or feature area of the quantizer. Tests are parametrized via the ``execution_mode`` fixture so every scenario -runs for both graph and eager mode in a single test definition. Where a mode -is known to be broken, ``pytest.xfail`` is used inline with an explanation. +runs for both graph and eager mode in a single test definition. 
""" import pytest import torch import torch.nn as nn +from coreai_opt import ExportBackend from coreai_opt.quantization import ( ModuleQuantizerConfig, + QuantizationSpec, Quantizer, QuantizerConfig, ) from coreai_opt.quantization._graph.quantizer import GraphQuantizer +from coreai_opt.quantization.config.quantization_config import QATSchedule from coreai_opt.quantization.spec import ( + PerTensorGranularity, + QuantizationScheme, default_activation_quantization_spec, default_weight_quantization_spec, ) from coreai_opt.quantization.spec.fake_quantize import FakeQuantizeImplBase +from coreai_opt.quantization.spec.qparams_calculator import ( + DynamicQParamsCalculator, + MovingAverageQParamsCalculator, +) +from tests.models.simple import SimpleLinearModel # --------------------------------------------------------------------------- # Fixtures and helpers @@ -375,3 +384,489 @@ def test_later_module_name_config_wins_all_invocations(self, execution_mode, exa assert _count_fake_quant_modules(later_incl) == _count_fake_quant_modules(only_incl) assert _count_fake_quant_modules(later_excl) == _count_fake_quant_modules(only_excl) assert _count_fake_quant_modules(later_excl) < _count_fake_quant_modules(later_incl) + + +class TestDynamicActivationQuantization: + """Lifecycle and finalize-rejection tests for dynamic activation quantization. + + Uses a 2-Linear model (l1=dynamic, l2=moving-average) and walks setup → + calibration → fake-quant inference. Verifies dynamic qparams change per + inference while moving-average qparams stay frozen post-calibration. + """ + + @staticmethod + def _get_activation_fq(prepared_model, execution_mode, layer_prefix, calculator_cls): + """Find the unique input-activation FakeQuantize for a layer. + + Eager identifies by submodule name (``_quantize_input``). Graph + identifies by ``qparams_calculator`` class — each layer in this test + has a distinct activation spec, so the calculator class is unique. 
+ """ + if execution_mode == "eager": + for name, mod in prepared_model.named_modules(): + if name.startswith(layer_prefix) and name.endswith("quantize_input"): + return mod + raise AssertionError(f"No input activation FQ for prefix {layer_prefix!r}") + + matches = [ + m + for m in prepared_model.modules() + if isinstance(m, FakeQuantizeImplBase) + and isinstance(m.qparams_calculator, calculator_cls) + ] + assert len(matches) == 1, ( + f"Expected exactly 1 FakeQuantize with {calculator_cls.__name__}, got {len(matches)}" + ) + return matches[0] + + def _make_mixed_dynamic_static_config( + self, execution_mode: str, qat_schedule: QATSchedule | None = None + ) -> QuantizerConfig: + """l1: dynamic activation; l2: moving-average activation; both: static weight.""" + weight_spec = default_weight_quantization_spec() + dynamic_act_spec = QuantizationSpec( + dtype=torch.int8, + qscheme=QuantizationScheme.SYMMETRIC, + granularity=PerTensorGranularity(), + qparam_calculator_cls="dynamic", + ) + moving_avg_act_spec = default_activation_quantization_spec() + return QuantizerConfig( + global_config=None, + module_name_configs={ + "l1": ModuleQuantizerConfig( + op_state_spec={"weight": weight_spec}, + op_input_spec={"*": dynamic_act_spec}, + op_output_spec=None, + qat_schedule=qat_schedule, + ), + "l2": ModuleQuantizerConfig( + op_state_spec={"weight": weight_spec}, + op_input_spec={"*": moving_avg_act_spec}, + op_output_spec=None, + qat_schedule=qat_schedule, + ), + }, + ).set_execution_mode(execution_mode) + + def test_dynamic_qparams_lifecycle(self, execution_mode): + # 1. Initialize model + mixed-spec config and prepare. + config = self._make_mixed_dynamic_static_config(execution_mode) + quantizer = Quantizer(SimpleLinearModel(), config) + prepared_model = quantizer.prepare((torch.randn(4, 64),)) + + # 2. Verify calculator types are wired correctly per layer. 
+ dynamic_fq = self._get_activation_fq( + prepared_model, execution_mode, "l1", DynamicQParamsCalculator + ) + moving_avg_fq = self._get_activation_fq( + prepared_model, execution_mode, "l2", MovingAverageQParamsCalculator + ) + assert isinstance(dynamic_fq.qparams_calculator, DynamicQParamsCalculator) + assert isinstance(moving_avg_fq.qparams_calculator, MovingAverageQParamsCalculator) + + # 3. Calibration: feed several batches inside calibration_mode(). + torch.manual_seed(0) + with quantizer.calibration_mode(): + for _ in range(5): + prepared_model(torch.randn(4, 64)) + + dyn_scale_before_forward = dynamic_fq.qparams_calculator.scale.clone() + moving_avg_scale_before_forward = moving_avg_fq.qparams_calculator.scale.clone() + + # 4. Fake-quant mode (default after calibration_mode exits): run an input + # with deterministically larger magnitude than calibration so dynamic's + # recomputed scale provably differs from its calibrated value. + prepared_model(torch.randn(4, 64) * 10.0) + dyn_scale_after_forward = dynamic_fq.qparams_calculator.scale.clone() + moving_avg_scale_after_forward = moving_avg_fq.qparams_calculator.scale.clone() + + # Dynamic recomputes scale per inference; moving-average is frozen post-calibration. + assert not torch.equal(dyn_scale_before_forward, dyn_scale_after_forward) + assert torch.equal(moving_avg_scale_before_forward, moving_avg_scale_after_forward) + + @pytest.mark.parametrize( + "backend,is_supported", + [ + (ExportBackend.CoreAI, False), + (ExportBackend.CoreML, False), + (ExportBackend._TORCH, True), + ], + ) + def test_finalize_rejects_dynamic_for_non_torch_backends( + self, execution_mode, backend, is_supported + ): + """``finalize`` must reject CoreAI/CoreML for dynamic FakeQuantize + modules. 
``_TORCH`` is allowed since it returns the prepared model as-is.""" + config = self._make_mixed_dynamic_static_config(execution_mode) + quantizer = Quantizer(SimpleLinearModel(), config) + prepared_model = quantizer.prepare((torch.randn(4, 64),)) + + if is_supported: + finalized = quantizer.finalize(prepared_model, backend=backend) + assert finalized is not None + else: + with pytest.raises(NotImplementedError, match="dynamic quantization"): + quantizer.finalize(prepared_model, backend=backend) + + def test_qat_schedule_does_not_disable_dynamic_observer(self, execution_mode): + """QAT schedule's ``disable_observer`` transition must skip dynamic FQs.""" + config = self._make_mixed_dynamic_static_config( + execution_mode, + qat_schedule=QATSchedule(enable_observer=0, enable_fake_quant=1, disable_observer=2), + ) + quantizer = Quantizer(SimpleLinearModel(), config) + prepared_model = quantizer.prepare((torch.randn(4, 64),)) + + dynamic_fq = self._get_activation_fq( + prepared_model, execution_mode, "l1", DynamicQParamsCalculator + ) + moving_avg_fq = self._get_activation_fq( + prepared_model, execution_mode, "l2", MovingAverageQParamsCalculator + ) + + # Step past disable_observer=2 inside training_mode (which is what + # actually invokes the schedule via _maybe_apply_qat_schedule). 
+ with quantizer.training_mode(): + for _ in range(5): + quantizer.step() + + assert moving_avg_fq.observer_enabled.item() == 0 + assert dynamic_fq.observer_enabled.item() == 1 + + +class TestSharedWeightQuantization: + class _LeafA(nn.Module): + def __init__(self): + super().__init__() + self.my_weight = nn.Parameter(torch.randn(2, 2)) + + def forward(self, x): + return torch.nn.functional.linear(x, self.my_weight) + + class _LeafB(nn.Module): + def __init__(self): + super().__init__() + self.other_weight = nn.Parameter(torch.randn(2, 2)) + + def forward(self, x): + return torch.nn.functional.linear(x, self.other_weight) + + class _SharedWeightModel(nn.Module): + """Two leaves whose state tensors alias the same parameter. + + ``linear2.other_weight is linear1.my_weight`` after construction. + """ + + def __init__(self): + super().__init__() + self.linear1 = TestSharedWeightQuantization._LeafA() + self.linear2 = TestSharedWeightQuantization._LeafB() + self.linear2.other_weight = self.linear1.my_weight + + def forward(self, x): + return self.linear2(self.linear1(x)) + + # In the below model, the add op consumes a state tensor which is referenced by multiple local + # names: "my_weight" through leaf_a and "other_weight" through leaf_b. + class _AddModelSharedStateInput(torch.nn.Module): + """Model with add op consuming a state tensor which is referenced by multiple local + names: "my_weight" through leaf_a and "other_weight" through leaf_b. 
+ """ + + def __init__(self): + super().__init__() + self.leaf_a = TestSharedWeightQuantization._LeafA() + self.leaf_b = TestSharedWeightQuantization._LeafB() + self.leaf_b.other_weight = self.leaf_a.my_weight + + def forward(self, inp): + x = self.leaf_a.my_weight + inp + return x + + @staticmethod + def _w4_spec() -> QuantizationSpec: + return QuantizationSpec(dtype=torch.int4) + + @pytest.mark.parametrize( + "config, expected_dtype", + [ + ( + QuantizerConfig( + global_config=None, + module_type_configs={ + _LeafA: ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec={"my_weight": default_weight_quantization_spec()}, + ), + _LeafB: ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec={"other_weight": _w4_spec()}, + ), + }, + ), + torch.int4, + ), + ( + QuantizerConfig( + global_config=None, + module_type_configs={ + _LeafB: ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec={"other_weight": _w4_spec()}, + ), + _LeafA: ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec={"my_weight": default_weight_quantization_spec()}, + ), + }, + ), + default_weight_quantization_spec().dtype, + ), + ( + QuantizerConfig( + global_config=None, + module_name_configs={ + "linear1": ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec={"my_weight": default_weight_quantization_spec()}, + ), + }, + module_type_configs={ + _LeafB: ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec={"other_weight": _w4_spec()}, + ), + }, + ), + default_weight_quantization_spec().dtype, + ), + ( + QuantizerConfig( + global_config=None, + module_name_configs={ + "linear2": ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec={"other_weight": _w4_spec()}, + ), + }, + module_type_configs={ + _LeafA: ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + 
op_state_spec={"my_weight": default_weight_quantization_spec()}, + ), + }, + ), + torch.int4, + ), + ( + QuantizerConfig( + global_config=None, + module_name_configs={ + "linear1": ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec={"my_weight": default_weight_quantization_spec()}, + ), + }, + module_type_configs={ + _LeafB: ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec=None, + module_state_spec={"other_weight": _w4_spec()}, + ), + }, + ), + torch.int4, + ), + ( + QuantizerConfig( + global_config=None, + module_type_configs={ + _LeafA: ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec=None, + module_state_spec={"my_weight": default_weight_quantization_spec()}, + ), + _LeafB: ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec=None, + module_state_spec={"other_weight": _w4_spec()}, + ), + }, + ), + torch.int4, + ), + ( + QuantizerConfig( + global_config=None, + module_type_configs={ + _LeafB: ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec=None, + module_state_spec={"other_weight": _w4_spec()}, + ), + _LeafA: ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec=None, + module_state_spec={"my_weight": default_weight_quantization_spec()}, + ), + }, + ), + default_weight_quantization_spec().dtype, + ), + ( + QuantizerConfig( + global_config=None, + module_name_configs={ + "linear1": ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec=None, + module_state_spec={"my_weight": default_weight_quantization_spec()}, + ), + }, + module_type_configs={ + _LeafB: ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec=None, + module_state_spec={"other_weight": _w4_spec()}, + ), + }, + ), + default_weight_quantization_spec().dtype, + ), + ], + ) + def test_shared_weight_quantization(self, config, expected_dtype, 
execution_mode): + """ + Test shared weight quantization for different variations and orderings of configurations + applied to the same shared weight. + """ + model = self._SharedWeightModel() + inp = (torch.randn(1, 2),) + + config = config.set_execution_mode(execution_mode) + quantizer = Quantizer(model, config) + prepared_model = quantizer.prepare(inp) + + if execution_mode == "graph": + node_dict = {node.name: node for node in prepared_model.graph.nodes} + leaf_a_linear_weight = getattr( + prepared_model, node_dict["linear"].all_input_nodes[1].name + ) + assert leaf_a_linear_weight.qparams_calculator.dtype == expected_dtype + leaf_b_linear_weight = getattr( + prepared_model, node_dict["linear_1"].all_input_nodes[1].name + ) + assert leaf_a_linear_weight is leaf_b_linear_weight + else: + assert ( + prepared_model.linear1.parametrizations["my_weight"][0].qparams_calculator.dtype + == expected_dtype + ) + assert ( + prepared_model.linear1.parametrizations["my_weight"][0] + is prepared_model.linear2.parametrizations["other_weight"][0] + ) + + @pytest.mark.parametrize( + "config", + [ + QuantizerConfig( + global_config=None, + module_name_configs={ + "": ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec={"my_weight": default_weight_quantization_spec()}, + ), + }, + ), + QuantizerConfig( + global_config=None, + module_name_configs={ + "": ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec={"other_weight": default_weight_quantization_spec()}, + ), + }, + ), + ], + ) + @pytest.mark.parametrize("execution_mode", ["graph", "eager"]) + def test_independent_state_tensor_usage(self, config, execution_mode): + """ + Test that modules using a state tensor can configure the tensor using a local name which is + used by the module owning the tensor. 
+ """ + + model = self._AddModelSharedStateInput() + inp = (torch.randn(1, 2),) + + config = config.set_execution_mode(execution_mode) + quantizer = Quantizer(model, config) + prepared_model = quantizer.prepare(inp) + + quantizers = [ + module + for module in prepared_model.modules() + if isinstance(module, FakeQuantizeImplBase) + ] + assert len(quantizers) == 1 + + if execution_mode == "graph": + node_dict = {node.name: node for node in prepared_model.graph.nodes} + assert "activation_post_process" in node_dict["add"].all_input_nodes[0].name + else: + assert ( + prepared_model.leaf_a.parametrizations["my_weight"][0] + is prepared_model.leaf_b.parametrizations["other_weight"][0] + ) + + def test_op_state_spec_last_key_wins_for_aliased_state(self, execution_mode): + """Last key in op_state_spec wins when a state tensor matches multiple alias keys.""" + model = self._SharedWeightModel() + inp = (torch.randn(1, 2),) + + config = QuantizerConfig( + global_config=None, + module_name_configs={ + "": ModuleQuantizerConfig( + op_input_spec=None, + op_output_spec=None, + op_state_spec={ + "my_weight": default_weight_quantization_spec(), + "other_weight": self._w4_spec(), + }, + ), + }, + ).set_execution_mode(execution_mode) + + quantizer = Quantizer(model, config) + prepared_model = quantizer.prepare(inp) + + if execution_mode == "graph": + node_dict = {node.name: node for node in prepared_model.graph.nodes} + weight = getattr(prepared_model, node_dict["linear"].all_input_nodes[1].name) + assert weight.qparams_calculator.dtype == torch.int4 + else: + assert ( + prepared_model.linear1.parametrizations["my_weight"][0].qparams_calculator.dtype + == torch.int4 + ) diff --git a/tests/quantization/test_state_spec_resolver.py b/tests/quantization/test_state_spec_resolver.py new file mode 100644 index 0000000..42939cb --- /dev/null +++ b/tests/quantization/test_state_spec_resolver.py @@ -0,0 +1,441 @@ +# Copyright 2026 Apple Inc. 
+# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + + +"""Resolver-level unit tests for ``StateSpecResolver``. + +These tests exercise the resolver's public contract directly, without standing +up the quantizer or torch-function-mode machinery. Test fixtures construct a +small ``nn.Module`` and minimal ``ModuleCompressionComponents`` shapes by hand. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from coreai_opt._utils.insertion.torch_function.state_spec_resolver import StateSpecResolver +from coreai_opt._utils.insertion.torch_function.types import ModuleCompressionComponents +from coreai_opt._utils.torch_utils import NamedModule +from coreai_opt.config.spec import CompressionSimulatorBase + + +class _StubSimulator(CompressionSimulatorBase): + """Minimal stand-in for a ``CompressionSimulatorBase`` used in tests. + + Carries a ``label`` so tests can identify which constructor produced this + instance, and records the ``op_to_optimize`` argument the resolver passes. 
+ """ + + def __init__(self, op_to_optimize=None, label=""): + super().__init__() + self.label = label + self.op_to_optimize = op_to_optimize + + def forward(self, x): + return x + + +def _stub_constructor(label): + """Return a callable mimicking a ``PartialConstructor`` that yields ``_StubSimulator``.""" + + def _make(op_to_optimize): + return _StubSimulator(op_to_optimize=op_to_optimize, label=label) + + return _make + + +class _LeafA(nn.Module): + def __init__(self): + super().__init__() + self.my_weight = nn.Parameter(torch.randn(2, 2)) + + def forward(self, x): + return F.linear(x, self.my_weight) + + +class _LeafB(nn.Module): + def __init__(self): + super().__init__() + self.other_weight = nn.Parameter(torch.randn(2, 2)) + + def forward(self, x): + return F.linear(x, self.other_weight) + + +class _SharedWeightModel(nn.Module): + """Two leaves whose state tensors alias the same parameter. + + After construction, ``linear2.other_weight is linear1.my_weight``. + """ + + def __init__(self): + super().__init__() + self.linear1 = _LeafA() + self.linear2 = _LeafB() + self.linear2.other_weight = self.linear1.my_weight + + def forward(self, x): + return self.linear2(self.linear1(x)) + + +class _ModelWithBuffer(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(2, 2)) + self.register_buffer("scale", torch.ones(2)) + + def forward(self, x): + return F.linear(x, self.weight) * self.scale + + +def _make_resolver(model, module_components_dict, module_priority_dict): + return StateSpecResolver( + model=model, + module_components_dict=module_components_dict, + module_priority_dict=module_priority_dict, + ) + + +# --------------------------------------------------------------------------- +# 1. 
is_state_tensor membership +# --------------------------------------------------------------------------- + + +def test_is_state_tensor_recognizes_parameter(): + model = _LeafA() + resolver = _make_resolver(model, {}, {"": 0}) + assert resolver.is_state_tensor(model.my_weight) + + +def test_is_state_tensor_recognizes_buffer(): + model = _ModelWithBuffer() + resolver = _make_resolver(model, {}, {"": 0}) + assert resolver.is_state_tensor(model.scale) + + +def test_is_state_tensor_rejects_non_tensor_without_raising(): + model = _LeafA() + resolver = _make_resolver(model, {}, {"": 0}) + assert resolver.is_state_tensor(7) is False + assert resolver.is_state_tensor(None) is False + assert resolver.is_state_tensor([1, 2, 3]) is False + + +def test_is_state_tensor_rejects_unregistered_tensor(): + model = _LeafA() + resolver = _make_resolver(model, {}, {"": 0}) + foreign = torch.randn(2, 2) + assert resolver.is_state_tensor(foreign) is False + + +# --------------------------------------------------------------------------- +# 2. 
get_all_local_names +# --------------------------------------------------------------------------- + + +def test_get_all_local_names_single_owner(): + model = _LeafA() + resolver = _make_resolver(model, {}, {"": 0}) + assert resolver.get_all_local_names(model.my_weight) == ["my_weight"] + + +def test_get_all_local_names_shared_state_returns_each_owners_local_name(): + model = _SharedWeightModel() + resolver = _make_resolver(model, {}, {"linear1": 0, "linear2": 1, "": 2}) + names = resolver.get_all_local_names(model.linear1.my_weight) + assert sorted(names) == ["my_weight", "other_weight"] + + +def test_get_all_local_names_returns_empty_for_unknown_tensor(): + model = _LeafA() + resolver = _make_resolver(model, {}, {"": 0}) + foreign = torch.randn(2, 2) + assert resolver.get_all_local_names(foreign) == [] + + +class _IntraModuleAliasLeaf(nn.Module): + """Single module that aliases the same parameter under two attribute names.""" + + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(2, 2)) + self.alias = self.weight + + def forward(self, x): + return F.linear(x, self.weight) + + +class _ModelWithIntraModuleAlias(nn.Module): + def __init__(self): + super().__init__() + self.leaf = _IntraModuleAliasLeaf() + + def forward(self, x): + return self.leaf(x) + + +def test_build_inventory_records_intra_module_aliases(): + """A single module aliasing one tensor under two attribute names must record both.""" + model = _ModelWithIntraModuleAlias() + resolver = _make_resolver(model, {}, {"leaf": 0, "": 1}) + + nm_leaf = NamedModule("leaf", model.leaf) + entry = resolver._state_inventory[model.leaf.weight] + + assert entry.owners == [nm_leaf] + assert set(entry.local_names) == {"weight", "alias"} + + assert set(resolver.get_all_local_names(model.leaf.weight)) == {"weight", "alias"} + + +# --------------------------------------------------------------------------- +# 3. 
resolve with op_state_spec +# --------------------------------------------------------------------------- + + +def test_resolve_op_state_match_caches_optimizer_at_current_priority(): + model = _LeafA() + named = NamedModule("", model) + components_dict = {"my_weight": _stub_constructor("op-match")} + resolver = _make_resolver(model, {}, {"": 0}) + + resolver.resolve(F.linear, model.my_weight, named, components_dict) + + optimizer = resolver.get_optimizer(model.my_weight) + assert isinstance(optimizer, _StubSimulator) + assert optimizer.label == "op-match" + assert optimizer.op_to_optimize is F.linear + cached_opt, cached_priority = resolver._optimizer_cache[model.my_weight] + assert cached_opt is optimizer + assert cached_priority == 0 + + +def test_resolve_op_state_no_match_caches_none_at_current_priority(): + model = _LeafA() + named = NamedModule("", model) + components_dict = {"some_other_name": _stub_constructor("not-match")} + resolver = _make_resolver(model, {}, {"": 5}) + + resolver.resolve(F.linear, model.my_weight, named, components_dict) + + assert resolver.get_optimizer(model.my_weight) is None + cached_opt, cached_priority = resolver._optimizer_cache[model.my_weight] + assert cached_opt is None + assert cached_priority == 5 + + +def test_resolve_op_state_explicit_none_caches_none_at_current_priority(): + model = _LeafA() + named = NamedModule("", model) + components_dict = {"my_weight": None} + resolver = _make_resolver(model, {}, {"": 3}) + + resolver.resolve(F.linear, model.my_weight, named, components_dict) + + assert resolver.get_optimizer(model.my_weight) is None + cached_opt, cached_priority = resolver._optimizer_cache[model.my_weight] + assert cached_opt is None + assert cached_priority == 3 + + +# --------------------------------------------------------------------------- +# 4. 
resolve with module_state_spec +# --------------------------------------------------------------------------- + + +def test_resolve_module_state_match_cached_at_module_state_priority(): + model = _LeafA() + named = NamedModule("", model) + module_components_dict = { + named: ModuleCompressionComponents( + module_state_components={"my_weight": _stub_constructor("module-match")}, + ) + } + resolver = _make_resolver(model, module_components_dict, {"": 0}) + + resolver.resolve(F.linear, model.my_weight, named, components_dict={}) + + optimizer = resolver.get_optimizer(model.my_weight) + assert isinstance(optimizer, _StubSimulator) + assert optimizer.label == "module-match" + cached_opt, cached_priority = resolver._optimizer_cache[model.my_weight] + assert cached_opt is optimizer + assert cached_priority == StateSpecResolver._MODULE_STATE_PRIORITY + + +def test_resolve_module_state_not_overwritten_by_higher_priority_op_state(): + """A module_state match cached at -1 must never be overwritten by op_state. + + The skip check uses strict ``>``: ``current_priority > cached_priority``. + Since every op-state priority is >= 0 and module-state priority is -1, the + skip fires for any op-state visit after the module-state cache is set. 
+ """ + model = _SharedWeightModel() + nm_linear1 = NamedModule("linear1", model.linear1) + nm_linear2 = NamedModule("linear2", model.linear2) + module_components_dict = { + nm_linear1: ModuleCompressionComponents( + module_state_components={"my_weight": _stub_constructor("module-spec")}, + ), + } + resolver = _make_resolver(model, module_components_dict, {"linear1": 0, "linear2": 5, "": 10}) + + resolver.resolve(F.linear, model.linear1.my_weight, nm_linear1, components_dict={}) + cached_after_first = resolver.get_optimizer(model.linear1.my_weight) + assert isinstance(cached_after_first, _StubSimulator) + assert cached_after_first.label == "module-spec" + + resolver.resolve( + F.linear, + model.linear2.other_weight, + nm_linear2, + components_dict={"other_weight": _stub_constructor("op-spec-from-linear2")}, + ) + cached_after_second = resolver.get_optimizer(model.linear1.my_weight) + assert cached_after_second is cached_after_first + _, cached_priority = resolver._optimizer_cache[model.linear1.my_weight] + assert cached_priority == StateSpecResolver._MODULE_STATE_PRIORITY + + +# --------------------------------------------------------------------------- +# 5. 
Priority cache ordering for shared states +# --------------------------------------------------------------------------- + + +def test_lower_priority_first_then_higher_priority_skipped(): + """Lower numeric priority = higher precedence; once cached, a higher-numeric + visit must be skipped (skip check fires when ``current > cached``).""" + model = _SharedWeightModel() + nm_linear1 = NamedModule("linear1", model.linear1) + nm_linear2 = NamedModule("linear2", model.linear2) + resolver = _make_resolver(model, {}, {"linear1": 0, "linear2": 5, "": 10}) + + resolver.resolve( + F.linear, + model.linear1.my_weight, + nm_linear1, + components_dict={"my_weight": _stub_constructor("from-linear1")}, + ) + first = resolver.get_optimizer(model.linear1.my_weight) + assert first.label == "from-linear1" + + resolver.resolve( + F.linear, + model.linear2.other_weight, + nm_linear2, + components_dict={"other_weight": _stub_constructor("from-linear2")}, + ) + second = resolver.get_optimizer(model.linear2.other_weight) + assert second is first + assert second.label == "from-linear1" + _, cached_priority = resolver._optimizer_cache[model.linear1.my_weight] + assert cached_priority == 0 + + +def test_higher_priority_first_then_lower_priority_overwrites(): + """A subsequent visit with lower numeric priority (higher precedence) + overwrites the cache because the skip check ``current > cached`` is False.""" + model = _SharedWeightModel() + nm_linear1 = NamedModule("linear1", model.linear1) + nm_linear2 = NamedModule("linear2", model.linear2) + resolver = _make_resolver(model, {}, {"linear1": 0, "linear2": 5, "": 10}) + + resolver.resolve( + F.linear, + model.linear2.other_weight, + nm_linear2, + components_dict={"other_weight": _stub_constructor("from-linear2")}, + ) + first = resolver.get_optimizer(model.linear2.other_weight) + assert first.label == "from-linear2" + + resolver.resolve( + F.linear, + model.linear1.my_weight, + nm_linear1, + components_dict={"my_weight": 
_stub_constructor("from-linear1")}, + ) + second = resolver.get_optimizer(model.linear1.my_weight) + assert second.label == "from-linear1" + _, cached_priority = resolver._optimizer_cache[model.linear1.my_weight] + assert cached_priority == 0 + + +def test_equal_priorities_last_writer_wins(): + """Strict ``>`` skip check means equal priorities do NOT trigger skip — the + later writer overwrites the earlier one.""" + model = _SharedWeightModel() + nm_linear1 = NamedModule("linear1", model.linear1) + nm_linear2 = NamedModule("linear2", model.linear2) + resolver = _make_resolver(model, {}, {"linear1": 0, "linear2": 0, "": 10}) + + resolver.resolve( + F.linear, + model.linear1.my_weight, + nm_linear1, + components_dict={"my_weight": _stub_constructor("first-writer")}, + ) + first = resolver.get_optimizer(model.linear1.my_weight) + assert first.label == "first-writer" + + resolver.resolve( + F.linear, + model.linear2.other_weight, + nm_linear2, + components_dict={"other_weight": _stub_constructor("second-writer")}, + ) + second = resolver.get_optimizer(model.linear2.other_weight) + assert second is not first + assert second.label == "second-writer" + _, cached_priority = resolver._optimizer_cache[model.linear1.my_weight] + assert cached_priority == 0 + + +# --------------------------------------------------------------------------- +# 6. 
get_optimizer +# --------------------------------------------------------------------------- + + +def test_get_optimizer_returns_cached_value(): + model = _LeafA() + named = NamedModule("", model) + resolver = _make_resolver(model, {}, {"": 0}) + + resolver.resolve( + F.linear, + model.my_weight, + named, + components_dict={"my_weight": _stub_constructor("cached")}, + ) + + optimizer = resolver.get_optimizer(model.my_weight) + assert isinstance(optimizer, _StubSimulator) + assert optimizer.label == "cached" + + +def test_get_optimizer_returns_none_for_unresolved_tensor(): + model = _LeafA() + resolver = _make_resolver(model, {}, {"": 0}) + assert resolver.get_optimizer(model.my_weight) is None + + +# --------------------------------------------------------------------------- +# 7. Module-state walk caching (skip redundant walks after first visit) +# --------------------------------------------------------------------------- + + +def test_optimizer_cache_set_after_first_visit_no_match(): + """Test that the ``_optimizer_cache`` contains the state_tensor after the first call even when + no match. 
+ """ + model = _SharedWeightModel() + nm_linear1 = NamedModule("linear1", model.linear1) + resolver = _make_resolver(model, {}, {"linear1": 0, "linear2": 5, "": 10}) + + assert model.linear1.my_weight not in resolver._optimizer_cache + + resolver.resolve(F.linear, model.linear1.my_weight, nm_linear1, components_dict={}) + + assert model.linear1.my_weight in resolver._optimizer_cache diff --git a/tests/test_inspection.py b/tests/test_inspection.py index 05bde15..483f3e4 100644 --- a/tests/test_inspection.py +++ b/tests/test_inspection.py @@ -11,6 +11,9 @@ import torch import torch.nn as nn +from coreai_opt._utils.insertion.torch_function.module_boundary_tracker import ( + TensorIdVersion, +) from coreai_opt._utils.torch_utils import export_model as _export_model from coreai_opt.base_model_compressor import _BaseModelCompressor from coreai_opt.inspection import ( @@ -18,18 +21,15 @@ ModelSummary, ModuleInfo, ) +from coreai_opt.inspection._eager_mode import _EagerOpDiscoveryMode +from coreai_opt.inspection.types import BoundaryEdge, InputEdge, OpInfo +from coreai_opt.palettization import KMeansPalettizer from coreai_opt.quantization import Quantizer from coreai_opt.quantization.config.quantization_config import ExecutionMode execution_modes = pytest.mark.parametrize( "execution_mode", - [ - ExecutionMode.GRAPH, - pytest.param( - ExecutionMode.EAGER, - marks=pytest.mark.xfail(reason="Eager inspection not yet implemented"), - ), - ], + [ExecutionMode.GRAPH, ExecutionMode.EAGER], ) @@ -142,16 +142,24 @@ def test_simple_conv_model(self, execution_mode: ExecutionMode) -> None: # Root is a ModuleSummary assert isinstance(inspector.summary.model, ModuleInfo) - # Op discovery + # Op discovery — names differ by mode ops = inspector.summary.model.all_ops() op_names = [op.op_name for op in ops] - assert "conv2d" in op_names - assert "linear" in op_names + if execution_mode == ExecutionMode.GRAPH: + assert "conv2d" in op_names + assert "linear" in op_names + conv_op_name = 
"conv2d" + linear_op_name = "linear" + else: + assert "conv.conv2d" in op_names + assert "fc.linear" in op_names + conv_op_name = "conv.conv2d" + linear_op_name = "fc.linear" # Op types - conv_op = next(op for op in ops if op.op_name == "conv2d") + conv_op = next(op for op in ops if op.op_name == conv_op_name) assert conv_op.op_type == "conv2d" - linear_op = next(op for op in ops if op.op_name == "linear") + linear_op = next(op for op in ops if op.op_name == linear_op_name) assert linear_op.op_type == "linear" # Module stack @@ -168,9 +176,9 @@ def test_simple_conv_model(self, execution_mode: ExecutionMode) -> None: assert inspector.get_matched_ops_for_module_type("NonexistentModule") == () # Query: by name (exact) - conv_by_name = inspector.get_matched_ops_for_op_name("conv2d") + conv_by_name = inspector.get_matched_ops_for_op_name(conv_op_name) assert len(conv_by_name) == 1 - assert conv_by_name[0].op_name == "conv2d" + assert conv_by_name[0].op_name == conv_op_name # Query: by name (regex) all_by_regex = inspector.get_matched_ops_for_op_name(".*") @@ -226,18 +234,24 @@ def test_nested_model(self, execution_mode: ExecutionMode) -> None: compressor=Quantizer, ) - # Op discovery and hierarchy + # Op discovery and hierarchy — names differ by mode op_names = [op.op_name for op in inspector.summary.model.all_ops()] - assert "conv2d" in op_names - assert "conv2d_1" in op_names - assert "linear" in op_names - assert "linear_1" in op_names - - # Graph order - assert op_names.index("conv2d") < op_names.index("linear") + if execution_mode == ExecutionMode.GRAPH: + conv_first, conv_second = "conv2d", "conv2d_1" + linear_first, linear_second = "linear", "linear_1" + else: + conv_first, conv_second = "encoder.conv1.conv2d", "encoder.conv2.conv2d" + linear_first, linear_second = "decoder.fc1.linear", "decoder.fc2.linear" + assert conv_first in op_names + assert conv_second in op_names + assert linear_first in op_names + assert linear_second in op_names + + # Execution order: 
convs before linears + assert op_names.index(conv_first) < op_names.index(linear_first) # Nested module FQNs - conv_op = next(op for op in inspector.summary.model.all_ops() if op.op_name == "conv2d") + conv_op = next(op for op in inspector.summary.model.all_ops() if op.op_name == conv_first) fqns = [m.module_name for m in conv_op.module_stack] assert "encoder" in fqns assert "encoder.conv1" in fqns @@ -250,24 +264,24 @@ def test_nested_model(self, execution_mode: ExecutionMode) -> None: # Query: by module name encoder_ops = inspector.get_matched_ops_for_module_name("encoder") encoder_op_names = [op.op_name for op in encoder_ops] - assert "conv2d" in encoder_op_names - assert "conv2d_1" in encoder_op_names + assert conv_first in encoder_op_names + assert conv_second in encoder_op_names # Query: by module name (leaf) leaf_ops = inspector.get_matched_ops_for_module_name("encoder.conv1") assert len(leaf_ops) == 1 - assert leaf_ops[0].op_name == "conv2d" + assert leaf_ops[0].op_name == conv_first # Query: by module name (regex) encoder_regex_ops = inspector.get_matched_ops_for_module_name(r"encoder\..*") encoder_regex_op_names = [op.op_name for op in encoder_regex_ops] - assert "conv2d" in encoder_regex_op_names - assert "conv2d_1" in encoder_regex_op_names + assert conv_first in encoder_regex_op_names + assert conv_second in encoder_regex_op_names # Query: by name (regex matching multiple ops) - conv_ops_by_name = inspector.get_matched_ops_for_op_name(r"conv2d.*") + conv_ops_by_name = inspector.get_matched_ops_for_op_name(r".*conv2d.*") assert len(conv_ops_by_name) == 2 - linear_ops_by_name = inspector.get_matched_ops_for_op_name(r"linear.*") + linear_ops_by_name = inspector.get_matched_ops_for_op_name(r".*linear.*") assert len(linear_ops_by_name) == 2 # Formatting @@ -287,7 +301,11 @@ def test_arithmetic_model(self, execution_mode: ExecutionMode) -> None: compressor=Quantizer, ) op_names = [op.op_name for op in inspector.summary.model.all_ops()] - assert "linear" in 
op_names + # Linear should be present (module-qualified in eager) + if execution_mode == ExecutionMode.GRAPH: + assert "linear" in op_names + else: + assert "linear.linear" in op_names add_ops = [n for n in op_names if "add" in n] mul_ops = [n for n in op_names if "mul" in n] assert len(add_ops) >= 2, f"Expected at least 2 add ops, got {add_ops}" @@ -339,20 +357,22 @@ def test_op_connectivity_arithmetic_model(self, execution_mode: ExecutionMode) - ops = inspector.summary.model.all_ops() ops_by_name = {op.op_name: op for op in ops} - # linear has a placeholder in inputs - linear_op = ops_by_name["linear"] - assert any(inp.op_name not in ops_by_name for inp in linear_op.inputs), ( - "linear should have a placeholder/parameter input" - ) + linear_name = "linear" if execution_mode == ExecutionMode.GRAPH else "linear.linear" + linear_op = ops_by_name[linear_name] + # linear's outputs should include an add op - assert any("add" in out.op_name for out in linear_op.outputs) + assert any( + "add" in out.op_name for consumers in linear_op.outputs.values() for out in consumers + ) # add ops have correct inputs - add_op = ops_by_name["add"] - assert len(add_op.inputs) >= 2 + add_name = "add" + add_op = ops_by_name[add_name] + assert len(add_op.inputs) == 2 - # mul has two inputs (both add-related ops) - mul_op = ops_by_name["mul"] + # mul has inputs (both add-related ops) + mul_name = "mul" + mul_op = ops_by_name[mul_name] assert len(mul_op.inputs) == 2 def test_module_io_nested_model(self, execution_mode: ExecutionMode) -> None: @@ -366,13 +386,13 @@ def test_module_io_nested_model(self, execution_mode: ExecutionMode) -> None: encoder = root.child_modules["encoder"] decoder = root.child_modules["decoder"] - # Encoder: first conv is an input, last conv is an output + # Encoder: has input and output ops assert len(encoder.input_ops) >= 1 - encoder_input_names = {op.op_name for op in encoder.input_ops} - assert "conv2d" in encoder_input_names or any("conv" in n for n in 
encoder_input_names) + encoder_input_names = {e.op.op_name for edges in encoder.input_ops.values() for e in edges} + assert any("conv" in n for n in encoder_input_names) assert len(encoder.output_ops) >= 1 - # Decoder: first linear is an input, last linear is an output + # Decoder: has input and output ops assert len(decoder.input_ops) >= 1 assert len(decoder.output_ops) >= 1 @@ -402,7 +422,10 @@ def test_tree_structure_nested_model(self, execution_mode: ExecutionMode) -> Non # Ops should be nested inside leaf modules, not at root conv1 = encoder.child_modules["encoder.conv1"] conv1_op_names = [op.op_name for op in conv1.ops] - assert "conv2d" in conv1_op_names + if execution_mode == ExecutionMode.GRAPH: + assert "conv2d" in conv1_op_names + else: + assert "encoder.conv1.conv2d" in conv1_op_names # Decoder should have children for fc1 and fc2 decoder = root.child_modules["decoder"] @@ -523,8 +546,12 @@ def test_all_ops(self, execution_mode: ExecutionMode) -> None: encoder = root.get_submodule("encoder") encoder_ops = encoder.all_ops() encoder_op_names = [op.op_name for op in encoder_ops] - assert "conv2d" in encoder_op_names - assert "conv2d_1" in encoder_op_names + if execution_mode == ExecutionMode.GRAPH: + assert "conv2d" in encoder_op_names + assert "conv2d_1" in encoder_op_names + else: + assert "encoder.conv1.conv2d" in encoder_op_names + assert "encoder.conv2.conv2d" in encoder_op_names assert not any("linear" in n for n in encoder_op_names) # Leaf module all_ops should equal its direct ops @@ -542,6 +569,548 @@ def test_empty_summary_after_compressor_filter(self, execution_mode: ExecutionMo # The root should be non-empty for a real model with quantizable ops assert inspector.summary.model.child_modules or inspector.summary.model.ops + def test_source_frames(self, execution_mode: ExecutionMode) -> None: + """Verify source frames are captured from forward() methods in both modes.""" + inspector = ModelInspector( + _SimpleConvModel(), + (torch.randn(1, 3, 8, 
8),), + execution_mode=execution_mode, + compressor=Quantizer, + ) + conv_op = next(op for op in inspector.summary.model.all_ops() if op.op_type == "conv2d") + assert len(conv_op.source_frames) >= 1 + assert all(f.function_name == "forward" for f in conv_op.source_frames) + assert all(f.filename != "" for f in conv_op.source_frames) + + def test_connectivity_through_non_captured_ops(self, execution_mode: ExecutionMode) -> None: + """Verify filtered ops still provide connectivity edges between tree ops.""" + + class _ReluBetweenLinears(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = nn.Linear(10, 10) + self.linear2 = nn.Linear(10, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear1(x) + x = torch.relu(x) + x = self.linear2(x) + return x + + inspector = ModelInspector( + _ReluBetweenLinears(), + (torch.randn(1, 10),), + execution_mode=execution_mode, + compressor=Quantizer, + ) + + # Only linear ops should appear in the tree; relu is filtered out + tree_ops = inspector.summary.model.all_ops() + tree_op_types = {op.op_type for op in tree_ops if op.op_type} + assert "linear" in tree_op_types + assert "relu" not in tree_op_types + + # Identify linear2 by its module_stack (names mirror named_modules in both modes) + linear_ops = [op for op in tree_ops if op.op_type == "linear"] + linear2_op = next( + op + for op in linear_ops + if any(ctx.module_name.endswith("linear2") for ctx in op.module_stack) + ) + + # linear2's input should chain through relu back to a linear op + relu_input = next( + (inp for inp in linear2_op.inputs if inp.op_type == "relu"), + None, + ) + assert relu_input is not None, ( + f"Expected relu in linear2.inputs, got {[i.op_name for i in linear2_op.inputs]}" + ) + linear1_upstream = next( + (inp for inp in relu_input.inputs if inp.op_type == "linear"), + None, + ) + assert linear1_upstream is not None + + def test_boundary_ops_with_non_tree_ops(self, execution_mode: ExecutionMode) -> None: + 
"""Verify module boundary ops are correct when non-tree ops sit between tree ops.""" + + class _Inner(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = nn.Linear(10, 10) + self.linear2 = nn.Linear(10, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear1(x) + x = torch.relu(x) + x = self.linear2(x) + return x + + class _Outer(nn.Module): + def __init__(self) -> None: + super().__init__() + self.inner = _Inner() + self.final = nn.Linear(10, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.inner(x) + x = self.final(x) + return x + + inspector = ModelInspector( + _Outer(), + (torch.randn(1, 10),), + execution_mode=execution_mode, + compressor=Quantizer, + ) + root = inspector.summary.model + inner = root.child_modules["inner"] + + inner_tree_ops = inner.all_ops() + inner_op_types = {op.op_type for op in inner_tree_ops if op.op_type} + assert "linear" in inner_op_types + assert "relu" not in inner_op_types + assert len(inner_tree_ops) == 2 # linear1 and linear2, relu filtered out + + linears_by_module = {} + for op in inner_tree_ops: + if op.op_type != "linear": + continue + for ctx in op.module_stack: + if ctx.module_name == "inner.linear1": + linears_by_module["linear1"] = op + elif ctx.module_name == "inner.linear2": + linears_by_module["linear2"] = op + + assert set(linears_by_module) == {"linear1", "linear2"} + + # linear1 is an input boundary (data comes from outside inner). + assert any( + e.op == linears_by_module["linear1"] + for edges in inner.input_ops.values() + for e in edges + ) + # linear2 is an output boundary (data goes to outside inner). + assert any(e.op == linears_by_module["linear2"] for e in inner.output_ops.values()) + # linear2 is NOT an input boundary (its data flows from within inner via relu). 
+ assert all( + e.op != linears_by_module["linear2"] + for edges in inner.input_ops.values() + for e in edges + ) + + def test_module_level_input_ops(self, execution_mode: ExecutionMode) -> None: + """Verify model-level inputs appear as placeholder-like OpInfos at the root boundary. + + Graph mode emits ``placeholder`` nodes; eager mode emits synthetic ``input_i`` + ops. Both should behave identically: empty module_stack, not is_state, no + further inputs, present in the consuming op's ``inputs`` tuple, but absent + from any module's tree ops or boundary lists. + """ + inspector = ModelInspector( + _SimpleConvModel(), + (torch.randn(1, 3, 8, 8),), + execution_mode=execution_mode, + ) + root = inspector.summary.model + + # The first real op (conv) consumes the user input, so it must be in root.input_ops. + conv_op = next(op for op in root.all_ops() if op.op_type == "conv2d") + assert any(e.op == conv_op for edges in root.input_ops.values() for e in edges) + + # conv's inputs should include at least one placeholder-like OpInfo. + placeholders = [inp for inp in conv_op.inputs if not inp.module_stack and not inp.is_state] + assert len(placeholders) >= 1 + for ph in placeholders: + assert ph.inputs == () + + # Placeholder OpInfos must not appear in any module's tree or boundary lists. + all_tree_ops = root.all_ops() + for module in root.modules(): + for ph in placeholders: + assert ph not in all_tree_ops + assert all(e.op != ph for edges in module.input_ops.values() for e in edges) + assert all(e.op != ph for e in module.output_ops.values()) + + def test_module_level_output_ops(self, execution_mode: ExecutionMode) -> None: + """Verify model-level outputs appear as output-like OpInfos at the root boundary. + + Graph mode emits a single ``output`` node; eager mode emits one ``output_i`` + per output tensor. 
Both should behave identically: empty module_stack, not + is_state, no further outputs, present in the producing op's ``outputs`` tuple, + but absent from any module's tree ops or boundary lists. + """ + inspector = ModelInspector( + _SimpleConvModel(), + (torch.randn(1, 3, 8, 8),), + execution_mode=execution_mode, + ) + root = inspector.summary.model + + # The last real op (fc linear) produces the model output, so it's in root.output_ops. + linear_op = next(op for op in root.all_ops() if op.op_type == "linear") + assert any(e.op == linear_op for e in root.output_ops.values()) + + output_consumers = [ + out + for consumers in linear_op.outputs.values() + for out in consumers + if not out.module_stack and not out.is_state + ] + assert len(output_consumers) >= 1 + for out in output_consumers: + assert not any(out.outputs.values()) + + all_tree_ops = root.all_ops() + for module in root.modules(): + for out in output_consumers: + assert out not in all_tree_ops + assert all(e.op != out for edges in module.input_ops.values() for e in edges) + assert all(e.op != out for e in module.output_ops.values()) + + def test_state_ops_for_parameters(self, execution_mode: ExecutionMode) -> None: + """Verify parameters consumed by ops appear as is_state=True OpInfos. + + Graph mode emits ``get_attr`` nodes; eager mode emits synthetic state + OpInfos on first reference. Both should share identical semantic + behavior: is_state=True, empty module_stack, no inputs, and excluded + from tree/boundary lists. 
+ """ + inspector = ModelInspector( + _SimpleConvModel(), + (torch.randn(1, 3, 8, 8),), + execution_mode=execution_mode, + ) + root = inspector.summary.model + linear_op = next(op for op in root.all_ops() if op.op_type == "linear") + + state_inputs = [inp for inp in linear_op.inputs if inp.is_state] + state_names = {inp.op_name for inp in state_inputs} + assert any("weight" in n for n in state_names), state_names + assert any("bias" in n for n in state_names), state_names + + for s in state_inputs: + assert s.module_stack == () + assert s.inputs == () + + all_tree_ops = root.all_ops() + for module in root.modules(): + for s in state_inputs: + assert s not in all_tree_ops + assert all(e.op != s for edges in module.input_ops.values() for e in edges) + assert all(e.op != s for e in module.output_ops.values()) + + # State ops: _display_name drops any dotted prefix (e.g., "fc.weight" → "weight") + # to match the suffix-only matching supported by state configs today. + for s in state_inputs: + assert s._display_name == s.op_name.rsplit(".", 1)[-1] + assert "." not in s._display_name + assert s._display_name != s.op_name + + # Non-state ops: _display_name equals op_name. 
+ assert linear_op._display_name == linear_op.op_name + + def test_state_ops_for_buffers(self, execution_mode: ExecutionMode) -> None: + """Verify registered buffers consumed by ops appear as is_state=True OpInfos.""" + + class _WithBuffer(nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("scale", torch.tensor(2.0)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * self.scale + + inspector = ModelInspector( + _WithBuffer(), + (torch.randn(1, 10),), + execution_mode=execution_mode, + ) + mul_op = next(op for op in inspector.summary.model.all_ops() if op.op_type == "mul") + state_inputs = [inp for inp in mul_op.inputs if inp.is_state] + assert len(state_inputs) == 1 + assert "scale" in state_inputs[0].op_name + + def test_shared_state_is_not_duplicated(self, execution_mode: ExecutionMode) -> None: + """Verify a parameter referenced by multiple ops yields a single shared state OpInfo. + + A parameter used more than once in ``forward`` should resolve to the same + ``_OpInfo`` instance each time, not distinct duplicates that merely compare + equal via ``op_name``. 
+ """ + + class _SharedWeightModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = nn.Parameter(torch.randn(10, 10)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = torch.matmul(x, self.weight) + return torch.matmul(y, self.weight) + + inspector = ModelInspector( + _SharedWeightModel(), + (torch.randn(1, 10),), + execution_mode=execution_mode, + ) + + state_inputs = [ + next(inp for inp in op.inputs if inp.is_state) + for op in inspector.summary.model.all_ops() + if any(inp.is_state for inp in op.inputs) + ] + assert len(state_inputs) == 2, "expected two ops to reference the weight" + assert state_inputs[0].op is state_inputs[1].op, ( + "both references should resolve to the same state OpInfo instance" + ) + + def test_boundary_ops_topological_order(self, execution_mode: ExecutionMode) -> None: + """Verify input_ops and output_ops of a module are in topological order. + + Uses a model where the root has two parallel child modules feeding into a + combine op, so naive DFS ordering would differ from topological order. + """ + + class _ParallelModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.branch_a = nn.Linear(4, 4) + self.branch_b = nn.Linear(4, 4) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = self.branch_a(x) + v = self.branch_b(x) + return u + v + + inspector = ModelInspector( + _ParallelModel(), + (torch.randn(1, 4),), + execution_mode=execution_mode, + ) + root = inspector.summary.model + branch_a = root.child_modules["branch_a"] + branch_b = root.child_modules["branch_b"] + + # Each branch has exactly one input spec index (key) with one edge (the linear). + assert len(branch_a.input_ops) == 1 + assert len(branch_a.output_ops) == 1 + assert len(branch_b.input_ops) == 1 + assert len(branch_b.output_ops) == 1 + + # Both branches consume x from outside root. 
In graph mode each branch gets its + # own spec_idx in root.input_ops; in eager mode both share spec_idx 0 (x appears + # once in root's forward args and fans out to both branches). + all_root_edges = [e for edges in root.input_ops.values() for e in edges] + root_op_names = [e.op.op_name for e in all_root_edges] + assert any( + "branch_a" in n or "branch_a" in str(e.op.module_stack) + for e, n in zip(all_root_edges, root_op_names, strict=True) + ) + assert any( + "branch_b" in n or "branch_b" in str(e.op.module_stack) + for e, n in zip(all_root_edges, root_op_names, strict=True) + ) + + # branch_a's ops appear before branch_b's ops in the all_ops list (execution order). + all_ops = inspector.summary.model.all_ops() + all_op_names = [op.op_name for op in all_ops] + branch_a_edge = branch_a.input_ops[0][0] + branch_b_edge = branch_b.input_ops[0][0] + assert all_op_names.index(branch_a_edge.op.op_name) < all_op_names.index( + branch_b_edge.op.op_name + ) + + # Root input_ops preserves topological order. Find each edge's position as + # (spec_idx, list_idx) and compare — works whether branches share a spec_idx + # (eager: both at 0) or each has its own (graph: 0 and 1). + def _edge_pos(ops: dict, edge: BoundaryEdge) -> tuple[int, int]: + for spec_idx, edges in sorted(ops.items()): + for list_idx, e in enumerate(edges): + if e == edge: + return (spec_idx, list_idx) + raise AssertionError(f"edge not found: {edge}") + + assert _edge_pos(root.input_ops, branch_a_edge) < _edge_pos(root.input_ops, branch_b_edge) + + def test_op_inputs_index_with_state_before_activation( + self, execution_mode: ExecutionMode + ) -> None: + """Verify op input dict key reflects full arg position when a state precedes an activation. + + In both graph and eager mode, when a model parameter appears at an earlier argument + position than an activation (e.g., torch.mm(self.weight, x)), the state occupies + arg index 0 and the activation occupies arg index 1. 
The op inputs dict must show + key 1 for the activation, not key 0, because the key is the full arg position + (= op_input_spec index), not the position within non-state inputs only. + """ + + class _WeightFirstModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = nn.Parameter(torch.randn(4, 4)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.mm(self.weight, x) + + inspector = ModelInspector( + _WeightFirstModel(), + (torch.randn(4, 4),), + execution_mode=execution_mode, + ) + # Find the op: state at arg 0, activation at arg 1. + mm_op = next( + op + for op in inspector.summary.model.all_ops() + if len(op.inputs) >= 2 and op.inputs[0].is_state and not op.inputs[1].is_state + ) + + assert mm_op.inputs[0].is_state + assert not mm_op.inputs[1].is_state + + # Formatted output must show the activation at dict key 1, not 0. + formatted = inspector.format_summary(colorize=False) + assert "op inputs: {1:" in formatted + + def test_raw_attribute_untracked_in_eager_state_in_graph( + self, execution_mode: ExecutionMode + ) -> None: + """Verify a raw tensor attribute appears as untracked in eager, as a state in graph. + + ``self.mask = torch.ones(8)`` is not registered as a parameter or buffer. + In eager mode the inspector cannot trace its origin, so it appears as + ``untracked_N`` at arg index 1 of the consuming op. In graph mode + ``torch.export`` captures it as a ``get_attr`` node, so it shows up as a + state input. 
+ """ + + class _ModelWithRawAttribute(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(8, 8) + self.mask = torch.ones(8) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) * self.mask + + inspector = ModelInspector( + _ModelWithRawAttribute(), + (torch.randn(2, 8),), + execution_mode=execution_mode, + ) + ops = inspector.summary.model.all_ops() + mul_op = next(op for op in ops if op.op_type == "mul") + mul_non_state_inputs = [inp for inp in mul_op.inputs if not inp.is_state] + mul_state_inputs = [inp for inp in mul_op.inputs if inp.is_state] + + if execution_mode == ExecutionMode.EAGER: + # mask is not registered — shows as untracked at arg index 1 + assert len(mul_non_state_inputs) == 2 + assert any(inp.op_name.startswith("untracked_") for inp in mul_non_state_inputs) + assert not any("mask" in inp.op_name for inp in mul_state_inputs) + else: + # graph mode: FX captures self.mask as get_attr → state + assert len(mul_non_state_inputs) == 1 + assert any("mask" in inp.op_name for inp in mul_state_inputs) + + def test_global_tensor_untracked_in_eager_state_in_graph( + self, execution_mode: ExecutionMode + ) -> None: + """Verify a global tensor appears as untracked in eager, as a lifted state in graph. + + A module-level Python global (``_BIAS = torch.zeros(8)``) has no registered + name in the model. Eager mode cannot trace it and marks it ``untracked_N`` + at arg index 1. Graph mode lifts it as a ``lifted_tensor_N`` placeholder + which the inspector treats as a state. 
+ """ + _bias = torch.zeros(8) + + class _ModelWithGlobalTensor(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(8, 8) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + _bias + + inspector = ModelInspector( + _ModelWithGlobalTensor(), + (torch.randn(2, 8),), + execution_mode=execution_mode, + ) + ops = inspector.summary.model.all_ops() + add_op = next(op for op in ops if op.op_type == "add") + add_non_state_inputs = [inp for inp in add_op.inputs if not inp.is_state] + add_state_inputs = [inp for inp in add_op.inputs if inp.is_state] + + if execution_mode == ExecutionMode.EAGER: + # global tensor is untracked — appears at arg index 1 + assert len(add_non_state_inputs) == 2 + assert any(inp.op_name.startswith("untracked_") for inp in add_non_state_inputs) + else: + # graph mode: torch.export lifts the global as a lifted_tensor placeholder + assert len(add_non_state_inputs) == 1 + assert len(add_state_inputs) >= 1 + + def test_shared_module_op_count_and_boundaries(self, execution_mode: ExecutionMode) -> None: + """Verify shared module handling for eager and graph modes. + + Eager mode only captures ops from the *first* traversal of a shared module + (subsequent calls are blocked by ``traversed_modules``), but re-registers + the second traversal's output tensor so downstream modules resolve it. + Graph mode sees both calls as separate nodes. 
+ """ + + class _ModelWithSharedModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.shared = nn.Linear(8, 8) + self.tail = nn.Linear(8, 8) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.shared(x) + x = self.shared(x) + return self.tail(x) + + inspector = ModelInspector( + _ModelWithSharedModule(), + (torch.randn(2, 8),), + execution_mode=execution_mode, + ) + root = inspector.summary.model + shared = root.child_modules.get("shared") + tail = root.child_modules.get("tail") + assert shared is not None + assert tail is not None + + if execution_mode == ExecutionMode.EAGER: + # Only the first traversal is recorded; second is skipped. + assert len(shared.ops) == 1 + shared_linear = shared.ops[0] + assert shared_linear.op_name == "shared.linear" + + # shared.linear's output goes to tail.linear (via the re-registered + # second-traversal output tensor pointing back to the first traversal). + assert shared_linear.outputs[0] == ( + next(op for op in tail.ops if op.op_name == "tail.linear"), + ) + + # tail.linear's first input comes from shared.linear. + tail_linear = next(op for op in tail.ops if op.op_name == "tail.linear") + assert tail_linear.inputs[0].op.op_name == "shared.linear" + else: + # graph mode unfolds both calls as distinct nodes. + # Graph mode unfolds both calls as distinct nodes named linear and linear_1. + assert len(shared.ops) == 2 + linear_1 = next(op for op in shared.ops if op.op_name == "linear_1") + + # linear feeds into linear_1 (second call reuses the shared weights). + assert linear_1.inputs[0].op.op_name == "linear" + + # linear_1's output goes to tail's linear_2. 
+ tail_linear = next(op for op in tail.ops if op.op_name == "linear_2") + assert tail_linear.inputs[0].op.op_name == "linear_1" + class TestModelInspectorValidation: """Tests for ModelInspector input validation.""" @@ -551,7 +1120,7 @@ def test_rejects_non_module(self) -> None: with pytest.raises(TypeError, match="Expected a torch.fx.GraphModule or torch.nn.Module"): ModelInspector("not a module", (torch.randn(1),), execution_mode="graph") - @execution_modes + @pytest.mark.parametrize("execution_mode", [ExecutionMode.GRAPH, ExecutionMode.EAGER]) def test_example_input_none(self, execution_mode: ExecutionMode) -> None: """Verify ValueError for example_inputs of None when model not a GraphModule and execution_mode is not ExecutionMode.GRAPH.""" @@ -568,12 +1137,6 @@ def test_eager_with_graph_module_raises_type_error(self) -> None: with pytest.raises(TypeError, match="Expected a torch.nn.Module for Eager execution_mode"): ModelInspector(gm, (torch.randn(1, 10),), execution_mode="eager") - def test_eager_raises_not_implemented(self) -> None: - """Verify NotImplementedError for eager mode (not yet supported).""" - model = nn.Linear(10, 5) - with pytest.raises(NotImplementedError, match="not yet implemented"): - ModelInspector(model, (torch.randn(1, 10),), execution_mode="eager") - def test_invalid_execution_mode_raises(self) -> None: """Verify ValueError for unrecognized execution mode.""" model = nn.Linear(10, 5) @@ -594,3 +1157,325 @@ class _FakeCompressor(_BaseModelCompressor): execution_mode="graph", compressor=_FakeCompressor, ) + + def test_palettizer_in_graph_mode_raises(self) -> None: + """Verify ValueError when using KMeansPalettizer with graph mode.""" + model = nn.Linear(10, 5) + with pytest.raises(ValueError, match="not supported in graph mode"): + ModelInspector( + model, + (torch.randn(1, 10),), + execution_mode="graph", + compressor=KMeansPalettizer, + ) + + +class TestEagerModeSpecific: + """Tests specific to eager mode behavior.""" + + def 
test_dynamic_control_flow(self) -> None: + """Verify eager mode handles dynamic control flow (if/else).""" + + class _DynamicModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.branch_a = nn.Linear(10, 10) + self.branch_b = nn.Linear(10, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.sum() > 0: + return self.branch_a(x) + else: + return self.branch_b(x) + + model = _DynamicModel() + # Positive input takes branch_a + inspector = ModelInspector( + model, (torch.ones(1, 10),), execution_mode="eager", compressor=Quantizer + ) + op_names = [op.op_name for op in inspector.summary.model.all_ops()] + assert "branch_a.linear" in op_names + assert "branch_b.linear" not in op_names + + # Negative input takes branch_b + inspector = ModelInspector( + model, (-1.0 * torch.ones(1, 10),), execution_mode="eager", compressor=Quantizer + ) + op_names = [op.op_name for op in inspector.summary.model.all_ops()] + assert "branch_b.linear" in op_names + assert "branch_a.linear" not in op_names + + def test_shared_module_only_captured_once(self) -> None: + """Verify shared module instances only produce ops on first traversal.""" + + class _SharedModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.shared = nn.Linear(10, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.shared(x) + x = self.shared(x) + return x + + inspector = ModelInspector( + _SharedModel(), + (torch.randn(1, 10),), + execution_mode="eager", + compressor=Quantizer, + ) + ops = inspector.summary.model.all_ops() + linear_ops = [op for op in ops if op.op_type == "linear"] + assert len(linear_ops) == 1 + + def test_passthrough_submodule_has_empty_boundary(self) -> None: + """Verify a submodule that returns its input unchanged has empty output_ops. + + When a module's output tensor is the same as its input (not produced by any + op in the subtree), the corresponding ``_module_output_producers`` entry has + ``output_idx=None``. 
The guard in ``_populate_boundary_ops_eager`` should + skip it, leaving ``output_ops`` empty. + """ + + class _SideEffectModule(nn.Module): + """Has an internal op but returns its input unchanged.""" + + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(4, 4) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _ = self.linear(x) # runs an op internally, but output is discarded + return x # returns the original input — not produced by any subtree op + + class _Wrapper(nn.Module): + def __init__(self) -> None: + super().__init__() + self.inner = _SideEffectModule() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.inner(x) + + inspector = ModelInspector( + _Wrapper(), + (torch.randn(1, 4),), + execution_mode="eager", + ) + inner = inspector.summary.model.child_modules.get("inner") + assert inner is not None + assert inner.output_ops == {} + + def test_palettizer_compressor(self) -> None: + """Verify KMeansPalettizer filters to only palettization-supported ops.""" + inspector = ModelInspector( + _SimpleConvModel(), + (torch.randn(1, 3, 8, 8),), + execution_mode="eager", + compressor=KMeansPalettizer, + ) + ops = inspector.summary.model.all_ops() + op_types = {op.op_type for op in ops} + # KMeansPalettizer supports conv and linear but not add/mul/relu + assert "conv2d" in op_types + assert "linear" in op_types + + def test_input_output_op_names(self) -> None: + """Verify eager mode names module-level input/output ops as ``input_i``/``output_i``.""" + inspector = ModelInspector( + _SimpleConvModel(), + (torch.randn(1, 3, 8, 8),), + execution_mode="eager", + ) + root = inspector.summary.model + conv_op = next(op for op in root.all_ops() if op.op_type == "conv2d") + linear_op = next(op for op in root.all_ops() if op.op_type == "linear") + + placeholder_names = { + inp.op_name for inp in conv_op.inputs if not inp.module_stack and not inp.is_state + } + assert "input_0" in placeholder_names + + output_names = { + 
out.op_name + for consumers in linear_op.outputs.values() + for out in consumers + if not out.module_stack and not out.is_state + } + assert "output_0" in output_names + + def test_input_ops_one_per_input_tensor(self) -> None: + """Verify each forward-argument tensor gets its own ``input_i`` OpInfo.""" + + class _TwoInputs(nn.Module): + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return a + b + + inspector = ModelInspector( + _TwoInputs(), + (torch.randn(1, 10), torch.randn(1, 10)), + execution_mode="eager", + ) + add_op = next(op for op in inspector.summary.model.all_ops() if op.op_type == "add") + placeholder_names = { + inp.op_name for inp in add_op.inputs if not inp.module_stack and not inp.is_state + } + assert placeholder_names == {"input_0", "input_1"} + + def test_source_frames_include_nested_forward_calls(self) -> None: + """Verify eager mode captures source frames from all forward() methods on the call stack. + + Graph mode exercises source-frame extraction in + :py:meth:`TestModelInspector.test_source_frames`; this test is eager-specific + because it asserts the multi-frame stack ordering unique to runtime interception. + """ + inspector = ModelInspector( + _NestedModel(), + (torch.randn(1, 3, 8, 8),), + execution_mode="eager", + ) + # conv1 lives inside _Encoder.forward() which is called by _NestedModel.forward(): + # the source stack should contain at least two forward frames, outermost first. 
+ conv_op = next( + op for op in inspector.summary.model.all_ops() if op.op_name == "encoder.conv1.conv2d" + ) + assert len(conv_op.source_frames) >= 2 + assert all(f.function_name == "forward" for f in conv_op.source_frames) + + def test_inplace_op_connectivity(self) -> None: + """Verify connectivity is correct when in-place ops mutate tensors.""" + + class _InPlaceModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(10, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear(x) + x.relu_() + x = x + 1 + return x + + inspector = ModelInspector( + _InPlaceModel(), + (torch.randn(1, 10),), + execution_mode="eager", + ) + ops = inspector.summary.model.all_ops() + ops_by_name = {op.op_name: op for op in ops} + + # relu_ is in-place: it consumes linear's output tensor (same id, different version) + relu_op = ops_by_name["relu_"] + assert len(relu_op.inputs) >= 1 + assert any("linear" in inp.op_name for inp in relu_op.inputs) + + # add consumes relu_'s output (the mutated tensor) + add_op = ops_by_name["add"] + assert len(add_op.inputs) >= 1 + assert any("relu_" in inp.op_name for inp in add_op.inputs) + + def test_weakref_removes_producer_on_dealloc(self) -> None: + """Verify ``_tensor_producers`` entries are cleaned up when tensors die. + + ``_record_outputs`` registers a ``weakref.finalize`` callback that + removes the producer entry when the tensor is deallocated. In CPython + this fires deterministically when the tensor's refcount hits zero + (no cyclic references involved here), so we can assert directly + without relying on ``gc.collect`` heuristics. 
+ """ + mode = _EagerOpDiscoveryMode(nn.Linear(3, 3)) + tensor = torch.randn(3) + op = OpInfo( + op_name="fake", + op_type="fake", + module_stack=(), + source_frames=(), + inputs=(), + outputs={}, + is_state=False, + ) + mode._record_outputs(tensor, op) + key = TensorIdVersion(id(tensor), tensor._version) + assert mode._tensor_producers[key] == InputEdge(op=op, output_idx=0) + del tensor + assert key not in mode._tensor_producers + + def test_input_ops_disambiguates_multi_output_external_producer(self) -> None: + """Verify input_ops correctly separates consumers of distinct outputs from the same op. + + When a multi-output external producer (e.g. chunk) feeds two of its outputs into + a module as separate forward args, input_ops[0] should only list ops that consume + output slot 0 and input_ops[1] should only list ops that consume output slot 1. + The bug: keying external_to_consumers by op_name alone merges both consumer sets + under the same key, so both spec positions get the wrong combined list. 
+ """ + + class _Inner(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear_a = nn.Linear(4, 4) + self.linear_b = nn.Linear(4, 4) + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return self.linear_a(a) + self.linear_b(b) + + class _ChunkModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.inner = _Inner() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + a, b = torch.chunk(x, 2, dim=-1) + return self.inner(a, b) + + inspector = ModelInspector( + _ChunkModel(), + (torch.randn(1, 8),), + execution_mode="eager", + ) + inner = inspector.summary.model.child_modules["inner"] + + assert set(inner.input_ops.keys()) == {0, 1} + ops_at_0 = {e.op.op_name for e in inner.input_ops[0]} + ops_at_1 = {e.op.op_name for e in inner.input_ops[1]} + assert ops_at_0 == {"inner.linear_a.linear"} + assert ops_at_1 == {"inner.linear_b.linear"} + + def test_state_detection_after_inplace_mutation_via_view(self) -> None: + """Verify state ops are still detected after their _version advances through a view. + + When a buffer is mutated in-place via a view (e.g., ``self.counter.view(-1).add_(1.0)``), + the mutation increments ``self.counter._version`` without changing ``id(self.counter)``. + The inspector must still recognise ``counter`` as a state input to any op that later + consumes it. + + Before the fix, ``_states_to_names`` was keyed by + ``TensorIdVersion(id(state), state._version)`` captured at ``__init__``. After the + in-place mutation the version had advanced, so the lookup missed and counter was silently + dropped from the consuming op's ``inputs``. Graph mode was unaffected because it identifies + states via ``node.op == "get_attr"`` rather than tensor identity. The fix keys + ``_states_to_names`` by ``id(state)`` only, which is stable for the model's lifetime. 
+ """ + + class _CounterModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(4, 4) + self.register_buffer("counter", torch.zeros(1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Mutate counter through a view — advances counter._version without + # changing id(counter). + self.counter.view(-1).add_(1.0) + x = self.linear(x) + return x + self.counter + + inspector = ModelInspector( + _CounterModel(), + (torch.randn(1, 4),), + execution_mode="eager", + ) + add_op = next(op for op in inspector.summary.model.all_ops() if op.op_type == "add") + state_names = {inp.op_name for inp in add_op.inputs if inp.is_state} + assert any("counter" in name for name in state_names), ( + f"counter not found in state inputs of add op; states found: {state_names}" + ) diff --git a/tests/test_joint_compression.py b/tests/test_joint_compression.py index 045290d..6bf73e5 100644 --- a/tests/test_joint_compression.py +++ b/tests/test_joint_compression.py @@ -13,8 +13,9 @@ import tests.utils as utils from coreai_opt import ExportBackend from coreai_opt.palettization.kmeans import KMeansPalettizer +from coreai_opt.palettization.spec.fake_palettize import _FakePalettizeImplBase from coreai_opt.quantization import Quantizer -from tests.conftest import ParametrizedP4A8CompressionConfigs +from tests.fixtures.compression import ParametrizedP4A8CompressionConfigs from tests.test_utils.general import COREAI_AVAILABLE batch_size = 128 @@ -46,7 +47,16 @@ def test_p4a8_compression_mnist_accuracy( # Palettize weights palettizer = KMeansPalettizer(model, config.palett_config) - palettizer.prepare((example_input,)) + prepared_palettized = palettizer.prepare((example_input,)) + + # MNIST model has 6 weight-bearing layers (conv1, conv2, conv_transpose1, + # conv_transpose2, dense1, dense2). With 4-bit per-tensor palettization, all + # 6 layers are palettized. 
+ palettized_count = utils.count_weight_parametrizations( + prepared_palettized, _FakePalettizeImplBase + ) + assert palettized_count == 6, f"Expected 6 palettized layers, got {palettized_count}" + palettized = palettizer.finalize(backend=ExportBackend.CoreAI) # Quantize activations on the palettized model diff --git a/tests/test_nox_utils.py b/tests/test_nox_utils.py new file mode 100644 index 0000000..a1b1728 --- /dev/null +++ b/tests/test_nox_utils.py @@ -0,0 +1,93 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""Tests for ci/nox/utils.py.""" + +from contextlib import contextmanager +from pathlib import Path +from unittest.mock import patch + +import pytest +from packaging.specifiers import SpecifierSet + +from ci.nox.utils import ( + _get_minimum_python_minor_version, + get_supported_python_versions, +) + + +@contextmanager +def mock_pyproject(tmp_path: Path, content: str, patch_target: str = "ci.nox.utils.REPO_ROOT"): + """Context manager for creating mock pyproject.toml and patching REPO_ROOT. 
+ + Args: + tmp_path: Temporary directory path + content: Content to write to pyproject.toml + patch_target: Module path to patch (default: ci.nox.utils.REPO_ROOT) + """ + (tmp_path / "pyproject.toml").write_text(content) + with patch(patch_target, tmp_path): + yield + + +class TestGetMinimumPythonMinorVersion: + """Tests for _get_minimum_python_minor_version.""" + + @pytest.mark.parametrize( + ("specifier_str", "expected"), + [ + (">=3.10", 10), + (">=3.11", 11), + (">=3.10, <3.13", 10), + (">=3.9, <3.12", 9), + (">3.10", 10), + (">3.10, <=3.12", 10), + ], + ) + def test_extracts_minimum_version(self, specifier_str: str, expected: int) -> None: + """Test that the minimum minor version is correctly extracted.""" + specifier = SpecifierSet(specifier_str) + assert _get_minimum_python_minor_version(specifier) == expected + + def test_raises_error_when_no_lower_bound(self) -> None: + """Test that ValueError is raised when no lower bound is specified.""" + specifier = SpecifierSet("<3.13") + with pytest.raises(ValueError, match="No lower bound found in specifier"): + _get_minimum_python_minor_version(specifier) + + +class TestGetSupportedPythonVersions: + """Tests for get_supported_python_versions.""" + + def test_returns_list_of_version_strings(self) -> None: + """Test that the function returns a list of version strings.""" + versions = get_supported_python_versions() + assert isinstance(versions, list) + assert len(versions) > 0 + for version in versions: + assert isinstance(version, str) + assert version.startswith("3.") + + @pytest.mark.parametrize( + ("requires_python", "expected_versions"), + [ + (">=3.10, <3.13", ["3.10", "3.11", "3.12"]), + (">=3.11, <3.14", ["3.11", "3.12", "3.13"]), + (">=3.9, <3.11", ["3.9", "3.10"]), + (">=3.10, <3.11", ["3.10"]), + ], + ) + def test_returns_correct_versions_for_specifier( + self, tmp_path: Path, requires_python: str, expected_versions: list[str] + ) -> None: + """Test that correct versions are returned for a given 
specifier.""" + content = f""" +[project] +name = "test-project" +requires-python = "{requires_python}" +""" + with mock_pyproject(tmp_path, content): + versions = get_supported_python_versions() + assert versions == expected_versions diff --git a/tests/test_utils/general.py b/tests/test_utils/general.py index 436a256..b6ba93d 100644 --- a/tests/test_utils/general.py +++ b/tests/test_utils/general.py @@ -6,6 +6,7 @@ """General test utilities.""" import importlib.util +from decimal import ROUND_FLOOR, Decimal import torch @@ -23,13 +24,37 @@ def __init__( psnr_thresh: float, prefix: str = "", ) -> None: + # Floor (not round) the displayed values so they match the raw-value + # comparison in verify_snr_psnr. Rounding could nudge a sub-threshold + # value up past the threshold (e.g. 34.9999 -> "35.00"), producing a + # self-contradictory "PSNR 35.00 below threshold 35.0" message. + snr_display = floor_to_decimals(str(snr), 2) + psnr_display = floor_to_decimals(str(psnr), 2) if snr <= snr_thresh: - msg = f"{prefix}SNR {snr:.2f} below threshold {snr_thresh} (PSNR: {psnr:.2f})" + msg = f"{prefix}SNR {snr_display} below threshold {snr_thresh} (PSNR: {psnr_display})" else: - msg = f"{prefix}PSNR {psnr:.2f} below threshold {psnr_thresh} (SNR: {snr:.2f})" + msg = f"{prefix}PSNR {psnr_display} below threshold {psnr_thresh} (SNR: {snr_display})" super().__init__(msg) +def floor_to_decimals(value: str, decimals: int) -> Decimal: + """Floor a numeric value to a fixed number of decimal places. + + Unlike ``round`` or ``f"{x:.2f}"`` (which round to nearest), this rounds + toward negative infinity, so the result never exceeds ``value``. + + Args: + value (str): Numeric value to floor, passed as a string for exact + decimal parsing (avoids binary float representation error). + decimals (int): Number of decimal places to keep. + + Returns: + Decimal: ``value`` floored to ``decimals`` decimal places. 
+ """ + step = Decimal("1").scaleb(-decimals) + return Decimal(value).quantize(step, rounding=ROUND_FLOOR) + + def compute_snr_psnr( data: torch.Tensor, reference: torch.Tensor, diff --git a/tests/test_utils/test_config_utils.py b/tests/test_utils/test_config_utils.py new file mode 100644 index 0000000..df9e577 --- /dev/null +++ b/tests/test_utils/test_config_utils.py @@ -0,0 +1,59 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +import logging + +from coreai_opt._utils.config_utils import get_last_matching_spec + + +def test_no_match(): + assert get_last_matching_spec(["a"], {"b": 1}) == (None, False) + + +def test_single_match(): + assert get_last_matching_spec(["a"], {"a": 42, "b": 99}) == (42, True) + + +def test_wildcard_fallback(): + assert get_last_matching_spec(["a"], {"*": 7}) == (7, True) + + +def test_specific_match_beats_wildcard(): + assert get_last_matching_spec(["a"], {"a": 1, "*": 99}) == (1, True) + + +def test_explicit_none_value_is_found(): + value, found = get_last_matching_spec(["a"], {"a": None}) + assert found is True + assert value is None + + +def test_last_key_wins_on_multiple_matches(): + # spec_dict order determines precedence, not identifier order + assert get_last_matching_spec(["a", "b"], {"a": 1, "b": 2}) == (2, True) + + +def test_last_key_wins_respects_spec_dict_order_not_identifier_order(): + # "b" comes after "a" in spec_dict, so it wins even though "a" is listed first in identifiers + assert get_last_matching_spec(["b", "a"], {"a": 1, "b": 2}) == (2, True) + + +def test_warning_emitted_on_multiple_matches(caplog): + with caplog.at_level(logging.WARNING, logger="coreai_opt._utils.config_utils"): + get_last_matching_spec(["a", "b"], {"a": 1, "b": 2}) + assert len(caplog.records) == 1 + msg = caplog.records[0].message + # identifiers and matched keys both appear in the message + assert 
"a" in msg and "b" in msg + + +def test_no_warning_on_single_match(caplog): + with caplog.at_level(logging.WARNING, logger="coreai_opt._utils.config_utils"): + get_last_matching_spec(["a"], {"a": 1, "b": 2}) + assert len(caplog.records) == 0 + + +def test_integer_identifiers(): + assert get_last_matching_spec([0, 1], {0: "x", 1: "y", 2: "z"}) == ("y", True) diff --git a/tests/test_utils/test_general.py b/tests/test_utils/test_general.py new file mode 100644 index 0000000..de7ac66 --- /dev/null +++ b/tests/test_utils/test_general.py @@ -0,0 +1,25 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-Clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +from decimal import Decimal + +from tests.test_utils.general import SNRBelowThresholdError, floor_to_decimals + + +def test_floor_to_decimals(): + assert floor_to_decimals("3.149", 2) == Decimal("3.14") + assert floor_to_decimals("3.140", 2) == Decimal("3.14") + assert floor_to_decimals("-3.141", 2) == Decimal("-3.15") + # A value just below a threshold must floor down, never round up. + assert floor_to_decimals("34.9999", 2) == Decimal("34.99") + + +def test_snr_below_threshold_error_floors_displayed_value(): + # PSNR is below threshold; the displayed value must read "34.99", not the + # rounded-up "35.00" that would contradict "below threshold 35.0". 
+ error = SNRBelowThresholdError(snr=100.0, psnr=34.9999, snr_thresh=80.0, psnr_thresh=35.0) + message = str(error) + assert "PSNR 34.99 below threshold 35.0" in message + assert "35.00" not in message diff --git a/tests/test_utils/test_registry_utils.py b/tests/test_utils/test_registry_utils.py index cce846f..7c0747f 100644 --- a/tests/test_utils/test_registry_utils.py +++ b/tests/test_utils/test_registry_utils.py @@ -34,6 +34,21 @@ class ClassB: assert MyClassRegistry.list_registry_keys() == {"class_a", "class_b"} assert MyClassRegistry.list_registry_values() == {ClassA, ClassB} + # resolve() returns the registered class for both string keys and class types. + assert MyClassRegistry.resolve("class_a") is ClassA + assert MyClassRegistry.resolve(ClassB) is ClassB + + # Unknown string key surfaces a ValueError that lists registered keys. + with pytest.raises(ValueError, match="class_c"): + MyClassRegistry.resolve("class_c") + + # An unregistered class also raises ValueError. + class UnregisteredClass: + pass + + with pytest.raises(ValueError, match="UnregisteredClass"): + MyClassRegistry.resolve(UnregisteredClass) + # Test overwriting registry name @MyClassRegistry.register("class_a") class NewClassA: diff --git a/tests/test_utils/test_torch_utils.py b/tests/test_utils/test_torch_utils.py index fd8eefd..5afa6a6 100644 --- a/tests/test_utils/test_torch_utils.py +++ b/tests/test_utils/test_torch_utils.py @@ -9,10 +9,10 @@ import torch from torchao.quantization.pt2e import allow_exported_model_train_eval +from coreai_opt._utils.fx_utils import normalize_module_fqn from coreai_opt._utils.torch_utils import ( move_model_to_eval, move_model_to_train, - normalize_module_fqn, ) diff --git a/tests/utils.py b/tests/utils.py index d18c78a..49f97fd 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,8 +6,20 @@ from pathlib import Path import torch +import torch.nn as nn import torch.nn.functional as F import yaml +from torch.nn.utils.parametrize import is_parametrized + + 
def count_weight_parametrizations(model: nn.Module, parametrization_cls: type) -> int:
    """Return how many modules of ``model`` carry a matching ``weight`` parametrization.

    Args:
        model (nn.Module): Module whose ``modules()`` iterator is scanned.
        parametrization_cls (type): Class matched with ``isinstance`` against every
            parametrization registered on a module's ``weight``.

    Returns:
        int: Number of modules having at least one ``weight`` parametrization that
        is an instance of ``parametrization_cls``.
    """
    matches = 0
    for submodule in model.modules():
        # Guard first: ``parametrizations["weight"]`` is only looked up on modules
        # whose ``weight`` is actually parametrized.
        if not is_parametrized(submodule, "weight"):
            continue
        for parametrization in submodule.parametrizations["weight"]:
            if isinstance(parametrization, parametrization_cls):
                # A module counts once, however many of its parametrizations match.
                matches += 1
                break
    return matches