Skip to content

Conversation

@njzjz
Copy link
Member

@njzjz njzjz commented Mar 11, 2025

Summary by CodeRabbit

  • New Features

    • Introduced enhanced Hessian fitting capabilities that extend model outputs to include second-order derivative information.
    • Integrated Hessian computations into the output transformation workflow for more detailed analytical results.
  • Tests

    • Updated test suites to conditionally import modules based on Python version, ensuring compatibility with the JAX library.
    • Adjusted precision level in tests for finite differences to improve accuracy of comparisons.
    • Added a new environment variable for memory allocation handling in test configurations.
    • Introduced a new function for scatter summation specifically for JAX arrays.

Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR Overview

This PR adds Hessian functionality to the JAX-based EnergyModel and its associated output transformation. Key changes include:

  • New tests in the JAX test suite to validate Hessian accuracy and model self-consistency.
  • Updates in transform_output and base_model modules to compute and propagate Hessian information.
  • Enhancements to the EnergyModel to enable Hessian computation via a new enable_hessian method.

Reviewed Changes

File Description
source/tests/jax/test_make_hessian_model.py Adds tests to compare auto-differentiated Hessian with finite differences.
source/tests/jax/test_dp_hessian_model.py Introduces tests ensuring consistency of DP Hessian model outputs.
deepmd/dpmodel/model/transform_output.py Updates output communication to handle Hessian arrays.
deepmd/jax/model/base_model.py Integrates Hessian computation using jax.hessian with vmap.
deepmd/dpmodel/model/ener_model.py Enhances EnergyModel to support Hessian via deep copy of output definitions.

Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.

Comments suppressed due to low confidence (2)

source/tests/jax/test_make_hessian_model.py:56

  • [nitpick] Consider using a distinct variable name for the result of f(x + i0) to avoid confusion with the initial computation of y0 (line 39), which could improve code clarity.
            y0 = f(x + i0)

deepmd/dpmodel/model/transform_output.py:85

  • [nitpick] Consider renaming 'mapping_' to a more descriptive name to clarify its usage in the Hessian transformation logic.
    mapping_ = mapping

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Mar 11, 2025

📝 Walkthrough

Walkthrough

The pull request introduces Hessian computation support to the deepmd energy model. It enhances the EnergyModel class with a new method, enable_hessian, which sets up Hessian fitting capabilities and modifies the behavior of atomic_output_def when Hessians are enabled. Additionally, the output transformation logic is updated to handle Hessian data, and the model evaluation flow incorporates a conditional Hessian calculation using JAX’s hessian (vectorized with jax.vmap). Unit tests are adjusted for compatibility with Python 3.10 and to verify Hessian computations.

Changes

File(s) Change Summary
deepmd/dpmodel/model/ener_model.py, deepmd/dpmodel/model/transform_output.py Added Hessian support: introduced enable_hessian in EnergyModel; modified atomic_output_def to return Hessian-enabled definitions; updated output transformation to handle Hessian processing; imported get_hessian_name for Hessian naming.
deepmd/jax/model/base_model.py Integrated conditional Hessian computation in eval_output using jax.hessian (vectorized with jax.vmap); stored the computed Hessian in the model output dictionary with a key generated by get_hessian_name.
source/tests/jax/test_dp_hessian_model.py, source/tests/jax/test_make_hessian_model.py Introduced conditional imports based on Python version; ensured tests are compatible with the required Python version for JAX.
.github/workflows/test_cuda.yml Added environment variable XLA_PYTHON_CLIENT_ALLOCATOR for memory allocation handling in the test job.

Possibly related PRs

  • feat(jax/array-api): energy fitting #4204: The changes in the main PR, which focus on adding Hessian fitting functionality to the EnergyModel class, are related to the retrieved PR that enhances the communicate_extended_output function to handle Hessian matrices, as both involve modifications to Hessian-related functionality in the energy modeling context.
  • feat(jax): force & virial #4251: The changes in the main PR, which add Hessian fitting functionality to the EnergyModel class, are related to the modifications in the retrieved PR that enhance Hessian handling in the communicate_extended_output function, as both involve direct alterations to Hessian-related logic in their respective classes.
  • fix(jax): fix several serialization and jit issues for DPA-2 #4315: The changes in the main PR, which enhance the EnergyModel class with Hessian fitting functionality, are related to the modifications in the retrieved PR that also involve the EnergyModel class, specifically addressing issues with gradient computation when rebuilding the neighbor list.

Suggested reviewers

  • wanghan-iapcm
  • iProzd

Tip

⚡🧪 Multi-step agentic review comment chat (experimental)
  • We're introducing multi-step agentic chat in review comments. This experimental feature enhances review discussions with the CodeRabbit agentic chat by enabling advanced interactions, including the ability to create pull requests directly from comments.
    - To enable this feature, set early_access to true under in the settings.
✨ Finishing Touches
  • 📝 Generate Docstrings

🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (4)
deepmd/jax/model/base_model.py (1)

91-102: Consider potential performance implications when computing Hessians for each sample.

This block introduces the Hessian computation via jax.vmap(jax.hessian(...)). While it is correct, Hessian calculations can be computationally expensive, particularly for large systems or frequent evaluations. A caching mechanism or on-demand evaluation approach might help mitigate performance overhead if this becomes a bottleneck.

deepmd/dpmodel/model/ener_model.py (1)

38-42: Consider documenting Hessian enablement or adding a disable method.

The method enable_hessian sets Hessian support permanently. If future workflows might need toggling, consider adding a corresponding disable_hessian or clarifying in docstrings that enablement is irreversible.

source/tests/jax/test_dp_hessian_model.py (1)

1-107: Consider adding class and method docstrings.

To improve code documentation, consider adding docstrings to both the test case class and test methods that explain their purpose and what aspects of the Hessian functionality they're testing.

source/tests/jax/test_make_hessian_model.py (1)

1-172: Consider adding class and method docstrings.

To improve code documentation, consider adding docstrings to both test classes and methods that explain their purpose and the specific aspects of Hessian functionality they're testing.

🧰 Tools
🪛 Ruff (0.8.2)

113-113: Function definition does not bind loop variable ii

(B023)


114-114: Function definition does not bind loop variable ii

(B023)


115-115: Function definition does not bind loop variable ii

(B023)


116-116: Function definition does not bind loop variable ii

(B023)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5831505 and 2fdf404.

📒 Files selected for processing (5)
  • deepmd/dpmodel/model/ener_model.py (2 hunks)
  • deepmd/dpmodel/model/transform_output.py (3 hunks)
  • deepmd/jax/model/base_model.py (2 hunks)
  • source/tests/jax/test_dp_hessian_model.py (1 hunks)
  • source/tests/jax/test_make_hessian_model.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
source/tests/jax/test_make_hessian_model.py

113-113: Function definition does not bind loop variable ii

(B023)


114-114: Function definition does not bind loop variable ii

(B023)


115-115: Function definition does not bind loop variable ii

(B023)


116-116: Function definition does not bind loop variable ii

(B023)

⏰ Context from checks skipped due to timeout of 90000ms (20)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Test C++ (true)
🔇 Additional comments (17)
deepmd/jax/model/base_model.py (1)

11-11: No concerns with the new import.

Importing get_hessian_name is straightforward and aligns well with its usage below.

deepmd/dpmodel/model/transform_output.py (2)

14-14: New import for Hessian naming is consistent.

Including get_hessian_name is consistent with the extended Hessian logic added below.


85-85: Clarify the intention behind reassigning mapping to mapping_.

Currently, mapping_ = mapping performs a direct reference assignment. If you intend an independent copy, you may need to create a clone instead. Otherwise, additional transformations on mapping_ could unintentionally affect mapping.

deepmd/dpmodel/model/ener_model.py (4)

2-4: Imports for deepcopy look good.

Using deepcopy prevents referencing the same data structure in multiple places, which is crucial for safely toggling Hessian definitions.


12-14: No issues with new import statements.

Adding FittingOutputDef helps define the Hessian-friendly output structure.


35-37: Fields for Hessian enablement are well introduced.

Defining _enable_hessian and hess_fitting_def provides a clear toggle mechanism within the model.


43-46: Overriding atomic_output_def is logical.

Returning hess_fitting_def when _enable_hessian is true aligns with the toggling approach. This keeps Hessian logic separate from the default definition.

source/tests/jax/test_dp_hessian_model.py (6)

28-50: LGTM: Comprehensive test setup for single frame without neighbor list.

This test setup looks well-structured with appropriate parameters for testing Hessian functionality in a small system.


52-55: LGTM: Test class structure is good.

The test class properly inherits from both unittest.TestCase and TestCaseSingleFrameWithoutNlist.


56-72: Model serialization and Hessian enablement pattern looks good.

The test creates two instances of the same model (one deserialized from the other) and enables Hessian calculation on both to verify consistency.


73-94: Well-structured tests for output consistency.

The test properly verifies that all energy outputs match between the two model instances using appropriate tolerances.


95-99: Good test for Hessian matrix consistency.

The test correctly verifies that the Hessian matrix (energy_derv_r_derv_r) matches between the two model instances.


100-106: Good additional test for atomic virial.

The test ensures consistency when atomic virial calculations are enabled, which is an important edge case to check.

source/tests/jax/test_make_hessian_model.py (4)

36-61: Well-implemented finite difference Hessian calculation.

The finite_hessian function uses a fourth-order accurate finite difference formula to compute the Hessian matrix. This is a good numerical approximation for verification purposes.


64-133: Well-designed test for comparing analytical and numerical Hessians.

The test thoroughly compares the analytical Hessian computed by the model with the numerical Hessian from finite differences, which is an excellent validation approach.

🧰 Tools
🪛 Ruff (0.8.2)

113-113: Function definition does not bind loop variable ii

(B023)


114-114: Function definition does not bind loop variable ii

(B023)


115-115: Function definition does not bind loop variable ii

(B023)


116-116: Function definition does not bind loop variable ii

(B023)


135-163: Good test setup for Hessian model.

The test class properly sets up both Hessian-enabled and regular models using the same parameters, which is essential for accurate comparison.


164-171: Good validation of output definitions.

The tests correctly verify that the Hessian-related output definitions are enabled for the Hessian model and disabled for the regular value model.

njzjz added 2 commits March 11, 2025 20:31
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (2)
source/tests/jax/test_make_hessian_model.py (1)

172-178: Remove duplicated assertions.

These assertions are duplicated from lines 165-171. Remove the duplicated code to maintain cleaner tests.

-    self.assertTrue(self.model_hess.atomic_output_def()["energy"].r_hessian)
-    self.assertFalse(self.model_valu.atomic_output_def()["energy"].r_hessian)
-    self.assertTrue(self.model_hess.model_output_def()["energy"].r_hessian)
-    self.assertEqual(
-        self.model_hess.model_output_def()["energy_derv_r_derv_r"].category,
-        OutputVariableCategory.DERV_R_DERV_R,
-    )
source/tests/jax/test_dp_hessian_model.py (1)

107-117: Remove duplicated test code.

Lines 107-117 duplicate the assertions and function calls from lines 96-106. Remove the duplication to maintain clean, maintainable test code.

-    np.testing.assert_allclose(
-        to_numpy_array(ret0["energy_derv_r_derv_r"]),
-        to_numpy_array(ret1["energy_derv_r_derv_r"]),
-        atol=self.atol,
-    )
-    ret0 = md0.call(*args, do_atomic_virial=True)
-    ret1 = md1.call(*args, do_atomic_virial=True)
-    np.testing.assert_allclose(
-        to_numpy_array(ret0["energy_derv_c"]),
-        to_numpy_array(ret1["energy_derv_c"]),
-        atol=self.atol,
-    )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 36d7512 and 6a77d42.

📒 Files selected for processing (6)
  • source/tests/jax/test_dp_hessian_model.py (1 hunks)
  • source/tests/jax/test_make_hessian_model.py (1 hunks)
  • source/tests/jax/test_dp_hessian_model.py (3 hunks)
  • source/tests/jax/test_make_hessian_model.py (2 hunks)
  • source/tests/jax/test_dp_hessian_model.py (1 hunks)
  • source/tests/jax/test_make_hessian_model.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
  • source/tests/jax/test_dp_hessian_model.py
  • source/tests/jax/test_dp_hessian_model.py
  • source/tests/jax/test_make_hessian_model.py
  • source/tests/jax/test_make_hessian_model.py
🧰 Additional context used
🪛 Ruff (0.8.2)
source/tests/jax/test_make_hessian_model.py

116-116: Function definition does not bind loop variable ii

(B023)


117-117: Function definition does not bind loop variable ii

(B023)


118-118: Function definition does not bind loop variable ii

(B023)


119-119: Function definition does not bind loop variable ii

(B023)

⏰ Context from checks skipped due to timeout of 90000ms (20)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Analyze (python)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Analyze (javascript-typescript)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test C++ (true)
🔇 Additional comments (1)
source/tests/jax/test_make_hessian_model.py (1)

108-121: Address loop variable binding issue.

The nested function definitions inside the loop don't properly capture the loop variable ii, which could lead to unexpected behavior.

Apply this diff to fix the issue:

def np_infer(
    xx,
+   _ii=ii,  # Capture the loop variable explicitly
):
    ret = self.model_valu(
        to_jax_array(xx)[None, ...],
-       atype[ii][None, ...],
-       box=cell[ii][None, ...],
-       fparam=fparam[ii][None, ...],
-       aparam=aparam[ii][None, ...],
+       atype[_ii][None, ...],
+       box=cell[_ii][None, ...],
+       fparam=fparam[_ii][None, ...],
+       aparam=aparam[_ii][None, ...],
    )
🧰 Tools
🪛 Ruff (0.8.2)

116-116: Function definition does not bind loop variable ii

(B023)


117-117: Function definition does not bind loop variable ii

(B023)


118-118: Function definition does not bind loop variable ii

(B023)


119-119: Function definition does not bind loop variable ii

(B023)

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🔭 Outside diff range comments (1)
source/tests/jax/test_make_hessian_model.py (1)

172-178: ⚠️ Potential issue

Remove duplicated test assertions.

Lines 172-178 are exact duplicates of the assertions already present in lines 165-171. These duplicate assertions should be removed to avoid redundancy and potential maintenance issues.

        self.assertEqual(
            self.model_hess.model_output_def()["energy_derv_r_derv_r"].category,
            OutputVariableCategory.DERV_R_DERV_R,
        )
-        self.assertTrue(self.model_hess.atomic_output_def()["energy"].r_hessian)
-        self.assertFalse(self.model_valu.atomic_output_def()["energy"].r_hessian)
-        self.assertTrue(self.model_hess.model_output_def()["energy"].r_hessian)
-        self.assertEqual(
-            self.model_hess.model_output_def()["energy_derv_r_derv_r"].category,
-            OutputVariableCategory.DERV_R_DERV_R,
-        )
🧹 Nitpick comments (1)
source/tests/jax/test_make_hessian_model.py (1)

64-133: Consider parametrizing the test for different configurations.

The test method in HessianTest currently tests with fixed parameters like places=6 and delta=1e-3. To make the test more robust, consider parametrizing it to run with different test configurations.

You could use pytest's parametrize decorator if you switch to pytest, or create multiple test methods with different parameters:

def test_with_higher_precision(self) -> None:
    # Run the same test with higher precision requirements
    self._run_test(places=8, delta=1e-4)

def test_with_larger_system(self) -> None:
    # Run with more atoms to test scaling
    self._run_test(extra_atoms=3)

def _run_test(self, places=6, delta=1e-3, extra_atoms=0) -> None:
    # Current test implementation with parametrized values
    # ...
🧰 Tools
🪛 Ruff (0.8.2)

116-116: Function definition does not bind loop variable ii

(B023)


117-117: Function definition does not bind loop variable ii

(B023)


118-118: Function definition does not bind loop variable ii

(B023)


119-119: Function definition does not bind loop variable ii

(B023)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6a77d42 and 65525fa.

📒 Files selected for processing (4)
  • source/tests/jax/test_make_hessian_model.py (1 hunks)
  • source/tests/jax/test_make_hessian_model.py (2 hunks)
  • source/tests/jax/test_make_hessian_model.py (1 hunks)
  • source/tests/jax/test_make_hessian_model.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • source/tests/jax/test_make_hessian_model.py
  • source/tests/jax/test_make_hessian_model.py
  • source/tests/jax/test_make_hessian_model.py
🧰 Additional context used
🪛 Ruff (0.8.2)
source/tests/jax/test_make_hessian_model.py

116-116: Function definition does not bind loop variable ii

(B023)


117-117: Function definition does not bind loop variable ii

(B023)


118-118: Function definition does not bind loop variable ii

(B023)


119-119: Function definition does not bind loop variable ii

(B023)

🔇 Additional comments (3)
source/tests/jax/test_make_hessian_model.py (3)

36-61: LGTM: The finite difference Hessian implementation is correct and well-structured.

The implementation uses a second-order central difference scheme to numerically approximate the Hessian matrix, which is appropriate for testing the analytical Hessian computation from JAX.


108-121: Fix loop variable binding issues.

The static analysis correctly points out that the functions defined inside the loop don't bind the loop variable ii. This could lead to unexpected behavior if the functions were used outside the loop scope.

Apply this fix to explicitly capture the loop variable:

for ii in range(nf):
    def np_infer(
        xx,
+       _ii=ii,  # Capture the loop variable explicitly
    ):
        ret = self.model_valu(
            to_jax_array(xx)[None, ...],
-           atype[ii][None, ...],
-           box=cell[ii][None, ...],
-           fparam=fparam[ii][None, ...],
-           aparam=aparam[ii][None, ...],
+           atype[_ii][None, ...],
+           box=cell[_ii][None, ...],
+           fparam=fparam[_ii][None, ...],
+           aparam=aparam[_ii][None, ...],
        )
🧰 Tools
🪛 Ruff (0.8.2)

116-116: Function definition does not bind loop variable ii

(B023)


117-117: Function definition does not bind loop variable ii

(B023)


118-118: Function definition does not bind loop variable ii

(B023)


119-119: Function definition does not bind loop variable ii

(B023)


135-163: LGTM: Good test class setup with separate models for comparison.

The setup method correctly initializes two models: one with Hessian enabled and one without, which is a good approach for verifying that the Hessian functionality works as expected without affecting the base model behavior.

@codecov
Copy link

codecov bot commented Mar 11, 2025

Codecov Report

Attention: Patch coverage is 95.65217% with 2 lines in your changes missing coverage. Please review.

Project coverage is 84.78%. Comparing base (5831505) to head (39ea7a8).
Report is 82 commits behind head on devel.

Files with missing lines Patch % Lines
deepmd/dpmodel/array_api.py 83.33% 1 Missing ⚠️
deepmd/dpmodel/model/transform_output.py 95.83% 1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##            devel    #4649   +/-   ##
=======================================
  Coverage   84.77%   84.78%           
=======================================
  Files         688      688           
  Lines       66097    66139   +42     
  Branches     3539     3538    -1     
=======================================
+ Hits        56036    56078   +42     
- Misses       8919     8920    +1     
+ Partials     1142     1141    -1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@njzjz njzjz added the Test CUDA Trigger test CUDA workflow label Mar 11, 2025
@github-actions github-actions bot removed the Test CUDA Trigger test CUDA workflow label Mar 11, 2025
@njzjz njzjz added the Test CUDA Trigger test CUDA workflow label Mar 12, 2025
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
.github/workflows/test_cuda.yml (1)

66-67: Review the Addition of the New Memory Allocator Environment Variable
The newly added environment variable XLA_PYTHON_CLIENT_ALLOCATOR: platform is intended to configure the memory allocation strategy for JAX during CUDA tests. Please ensure that the value "platform" is properly documented in the JAX documentation and that it harmonizes with the current settings (e.g., XLA_PYTHON_CLIENT_PREALLOCATE: false) to avoid unintended memory allocation issues.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 65525fa and df4ef25.

📒 Files selected for processing (1)
  • .github/workflows/test_cuda.yml (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (21)
  • GitHub Check: Test Python and C++ on CUDA
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Analyze (python)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_arm64

@github-actions github-actions bot removed the Test CUDA Trigger test CUDA workflow label Mar 12, 2025
@njzjz njzjz requested a review from wanghan-iapcm March 13, 2025 11:09
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (1)
deepmd/dpmodel/model/transform_output.py (1)

116-195: 🛠️ Refactor suggestion

Consider modularizing the complex Hessian scattering logic

The Hessian calculation logic involves multiple reshaping, tiling, and scatter sum operations that are repeated with slight variations. This complex code would benefit from modularization into smaller, well-documented helper functions.

For example, the two scatter sum operations (lines 146-151 and 169-174) follow a similar pattern and could be extracted into a function.

+ def _scatter_hessian_component(xp, hessian, mapping, mldims, hess_input, shape_info):
+     """Helper function to scatter hessian components.
+     
+     Parameters
+     ----------
+     xp : module
+         Array namespace
+     hessian : ndarray
+         Target array for the scattered hessian component
+     mapping : ndarray
+         Mapping array
+     mldims : list
+         Dimensions of the mapping array
+     hess_input : ndarray
+         Input hessian data to scatter
+     shape_info : tuple
+         Tuple containing shape information for tiling (nall/nloc, vdef.shape, etc.)
+     
+     Returns
+     -------
+     ndarray
+         Scattered hessian component
+     """
+     target_dim, vdef_shape = shape_info[0], shape_info[1:]
+     mapping_hess = xp.reshape(
+         mapping, (mldims + [1] * (len(vdef_shape) + 3))
+     )
+     mapping_hess = xp.tile(
+         mapping_hess,
+         [1] * len(mldims) + [target_dim, *vdef_shape, 3, 3],
+     )
+     return xp_scatter_sum(
+         hessian,
+         1,
+         mapping_hess,
+         hess_input,
+     )

Also, consider adding more comments explaining the dimensions and the purpose of each transformation step.

🧹 Nitpick comments (5)
deepmd/dpmodel/model/transform_output.py (3)

88-88: Consider adding a comment explaining why the mapping copy is needed

Creating a copy of mapping as mapping_ seems unnecessary since it's not being modified. If there's a specific reason for this, please add a comment explaining why.


119-121: Add comments explaining the Hessian dimensions

The dimensions of the Hessian tensor are complex, and it would be helpful to add more detailed comments about what each dimension represents. For example, explain what def_ndim is and how it relates to the output shape.


123-133: Add explanation for the complex dimension permutation

The permutation operation is particularly complex and difficult to understand. Consider adding an explanation of what this permutation is achieving in terms of the physical meaning of the dimensions.

+ # Permute from [nf, *def, nall1, 3, nall2, 3] to [nf, nall1, nall2, *def, 3(1), 3(2)]
+ # This reordering allows us to handle atom indices first, then shape dimensions
deepmd/dpmodel/array_api.py (2)

79-94: Add type annotations for input and dim parameters

For consistency with other functions in this file, consider adding full type annotations to the input and dim parameters.

-def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray:
+def xp_scatter_sum(input: np.ndarray, dim: int, index: np.ndarray, src: np.ndarray) -> np.ndarray:

79-94: Document the expected shape relationship between input, index, and src parameters

Consider enhancing the docstring to explain the expected shape relationship between the input, index, and src parameters to help users understand how to use this function correctly.

 def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray:
-    """Reduces all values from the src tensor to the indices specified in the index tensor."""
+    """Reduces all values from the src tensor to the indices specified in the index tensor.
+    
+    Parameters
+    ----------
+    input : np.ndarray
+        The tensor to scatter into
+    dim : int
+        The dimension along which to index
+    index : np.ndarray
+        The indices tensor that specifies where to scatter the values
+    src : np.ndarray
+        The source tensor containing values to be scattered
+        
+    Returns
+    -------
+    np.ndarray
+        The input tensor with values from src scattered according to index
+        
+    Notes
+    -----
+    The `index` tensor should have the same shape as `src`, and both should
+    be compatible with the shape of `input` along the specified dimension.
+    Currently, only JAX arrays are supported.
+    """
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between df4ef25 and 39ea7a8.

📒 Files selected for processing (2)
  • deepmd/dpmodel/array_api.py (2 hunks)
  • deepmd/dpmodel/model/transform_output.py (4 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (29)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Analyze (python)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Analyze (javascript-typescript)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Analyze (c-cpp)
🔇 Additional comments (1)
deepmd/dpmodel/model/transform_output.py (1)

106-111: Great use of abstraction with xp_scatter_sum

Replacing the JAX-specific implementation with an abstracted function improves the code maintainability and consistency.

@njzjz njzjz enabled auto-merge March 13, 2025 17:30
@njzjz njzjz added this pull request to the merge queue Mar 13, 2025
Merged via the queue into deepmodeling:devel with commit eb9e71d Mar 13, 2025
60 checks passed
@njzjz njzjz deleted the jax-hessian branch March 13, 2025 20:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants