feat(jax): Hessian #4649
Signed-off-by: Jinzhe Zeng <[email protected]>
PR Overview
This PR adds Hessian functionality to the JAX-based EnergyModel and its associated output transformation. Key changes include:
- New tests in the JAX test suite to validate Hessian accuracy and model self-consistency.
- Updates in transform_output and base_model modules to compute and propagate Hessian information.
- Enhancements to the EnergyModel to enable Hessian computation via a new enable_hessian method.
Reviewed Changes
| File | Description |
|---|---|
| source/tests/jax/test_make_hessian_model.py | Adds tests to compare auto-differentiated Hessian with finite differences. |
| source/tests/jax/test_dp_hessian_model.py | Introduces tests ensuring consistency of DP Hessian model outputs. |
| deepmd/dpmodel/model/transform_output.py | Updates output communication to handle Hessian arrays. |
| deepmd/jax/model/base_model.py | Integrates Hessian computation using jax.hessian with vmap. |
| deepmd/dpmodel/model/ener_model.py | Enhances EnergyModel to support Hessian via deep copy of output definitions. |
Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.
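The table above describes the core mechanism as jax.hessian combined with vmap. A minimal sketch of that pattern follows, assuming a made-up per-frame energy function (toy_energy is hypothetical and is not the model's energy; the PR's actual code in base_model.py is not reproduced here).

```python
import jax
import jax.numpy as jnp


def toy_energy(coord: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical per-frame energy: sum of squared pair distances.

    coord is a flat array of shape (natoms * 3,).
    """
    r = coord.reshape(-1, 3)
    diff = r[:, None, :] - r[None, :, :]
    # Each unordered pair appears twice in diff, hence the factor 0.5.
    return 0.5 * jnp.sum(diff**2)


# jax.hessian differentiates a single frame; jax.vmap maps it over frames.
batched_hessian = jax.vmap(jax.hessian(toy_energy))

coords = jnp.arange(12.0).reshape(2, 6)  # 2 frames, 2 atoms per frame
hess = batched_hessian(coords)  # shape (nf, natoms * 3, natoms * 3)
```

Because the Hessian has (natoms * 3)^2 entries per frame, both memory and compute grow quadratically with system size, which is the cost the reviewers raise later in the thread.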
Comments suppressed due to low confidence (2)
source/tests/jax/test_make_hessian_model.py:56
- [nitpick] Consider using a distinct variable name for the result of f(x + i0) to avoid confusion with the initial computation of y0 (line 39), which could improve code clarity.
y0 = f(x + i0)
deepmd/dpmodel/model/transform_output.py:85
- [nitpick] Consider renaming 'mapping_' to a more descriptive name to clarify its usage in the Hessian transformation logic.
mapping_ = mapping
📝 Walkthrough
The pull request introduces Hessian computation support to the deepmd energy model.
Actionable comments posted: 2
🧹 Nitpick comments (4)
deepmd/jax/model/base_model.py (1)
91-102: Consider potential performance implications when computing Hessians for each sample.
This block introduces the Hessian computation via jax.vmap(jax.hessian(...)). While it is correct, Hessian calculations can be computationally expensive, particularly for large systems or frequent evaluations. A caching mechanism or on-demand evaluation approach might help mitigate performance overhead if this becomes a bottleneck.
deepmd/dpmodel/model/ener_model.py (1)
38-42: Consider documenting Hessian enablement or adding a disable method.
The method enable_hessian sets Hessian support permanently. If future workflows might need toggling, consider adding a corresponding disable_hessian or clarifying in docstrings that enablement is irreversible.
source/tests/jax/test_dp_hessian_model.py (1)
1-107: Consider adding class and method docstrings.
To improve code documentation, consider adding docstrings to both the test case class and test methods that explain their purpose and what aspects of the Hessian functionality they're testing.
source/tests/jax/test_make_hessian_model.py (1)
1-172: Consider adding class and method docstrings.
To improve code documentation, consider adding docstrings to both test classes and methods that explain their purpose and the specific aspects of Hessian functionality they're testing.
🧰 Tools
🪛 Ruff (0.8.2)
113-113: Function definition does not bind loop variable ii (B023)
114-114: Function definition does not bind loop variable ii (B023)
115-115: Function definition does not bind loop variable ii (B023)
116-116: Function definition does not bind loop variable ii (B023)
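The enable_hessian nitpick above concerns a toggle backed by a deep-copied output definition. The sketch below illustrates that pattern under stated assumptions: OutputDef and ToyEnergyModel are hypothetical stand-ins, not the deepmd classes, and disable_hessian is the reviewer's suggested addition rather than something the PR is known to contain.

```python
from copy import deepcopy
from dataclasses import dataclass


@dataclass
class OutputDef:
    """Hypothetical stand-in for an output variable definition."""

    name: str
    r_hessian: bool = False


class ToyEnergyModel:
    """Hypothetical model showing a Hessian toggle via deepcopy."""

    def __init__(self) -> None:
        self._fitting_def = {"energy": OutputDef("energy")}
        self._enable_hessian = False
        # Deep copy so that editing the Hessian definition cannot
        # alter the default definition.
        self.hess_fitting_def = deepcopy(self._fitting_def)

    def enable_hessian(self) -> None:
        self.hess_fitting_def["energy"].r_hessian = True
        self._enable_hessian = True

    def disable_hessian(self) -> None:
        # Reviewer-suggested counterpart; an assumption, not PR code.
        self._enable_hessian = False

    def atomic_output_def(self) -> dict:
        if self._enable_hessian:
            return self.hess_fitting_def
        return self._fitting_def
```

With a shallow copy, setting r_hessian on the Hessian definition would also flip it on the default one, since both dictionaries would hold the same OutputDef object.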
🔇 Additional comments (17)
deepmd/jax/model/base_model.py (1)
11-11: No concerns with the new import.
Importing get_hessian_name is straightforward and aligns well with its usage below.
deepmd/dpmodel/model/transform_output.py (2)
14-14: New import for Hessian naming is consistent.
Including get_hessian_name is consistent with the extended Hessian logic added below.
85-85: Clarify the intention behind reassigning mapping to mapping_.
Currently, mapping_ = mapping performs a direct reference assignment. If you intend an independent copy, you may need to create a clone instead. Otherwise, additional transformations on mapping_ could unintentionally affect mapping.
deepmd/dpmodel/model/ener_model.py (4)
2-4: Imports for deepcopy look good.
Using deepcopy prevents referencing the same data structure in multiple places, which is crucial for safely toggling Hessian definitions.
12-14: No issues with new import statements.
Adding FittingOutputDef helps define the Hessian-friendly output structure.
35-37: Fields for Hessian enablement are well introduced.
Defining _enable_hessian and hess_fitting_def provides a clear toggle mechanism within the model.
43-46: Overriding atomic_output_def is logical.
Returning hess_fitting_def when _enable_hessian is true aligns with the toggling approach. This keeps Hessian logic separate from the default definition.
source/tests/jax/test_dp_hessian_model.py (6)
28-50: LGTM: Comprehensive test setup for single frame without neighbor list.
This test setup looks well-structured with appropriate parameters for testing Hessian functionality in a small system.
52-55: LGTM: Test class structure is good.
The test class properly inherits from both unittest.TestCase and TestCaseSingleFrameWithoutNlist.
56-72: Model serialization and Hessian enablement pattern looks good.
The test creates two instances of the same model (one deserialized from the other) and enables Hessian calculation on both to verify consistency.
73-94: Well-structured tests for output consistency.
The test properly verifies that all energy outputs match between the two model instances using appropriate tolerances.
95-99: Good test for Hessian matrix consistency.
The test correctly verifies that the Hessian matrix (energy_derv_r_derv_r) matches between the two model instances.
100-106: Good additional test for atomic virial.
The test ensures consistency when atomic virial calculations are enabled, which is an important edge case to check.
source/tests/jax/test_make_hessian_model.py (4)
36-61: Well-implemented finite difference Hessian calculation.
The finite_hessian function uses a four-point central finite difference formula to compute the Hessian matrix. This is a good numerical approximation for verification purposes.
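To make the comment above concrete, here is a minimal sketch of a four-point central difference Hessian for a scalar function of a 1-D array. It is written from the general formula, not copied from the test file: the name finite_hessian_sketch, its signature, and the default step are assumptions.

```python
import numpy as np


def finite_hessian_sketch(f, x, delta=1e-4):
    """Approximate the Hessian of a scalar function f at a 1-D point x.

    Uses H[i, j] ~ (f(x+di+dj) - f(x+di-dj) - f(x-di+dj) + f(x-di-dj))
    / (4 * delta**2), whose truncation error is O(delta**2).
    """
    x = np.asarray(x, dtype=float)
    n = x.size
    hess = np.empty((n, n))
    for i in range(n):
        di = np.zeros(n)
        di[i] = delta
        for j in range(n):
            dj = np.zeros(n)
            dj[j] = delta
            hess[i, j] = (
                f(x + di + dj) - f(x + di - dj) - f(x - di + dj) + f(x - di - dj)
            ) / (4.0 * delta**2)
    return hess


def example_fn(x):
    # Hypothetical test function with a known analytical Hessian.
    return x[0] ** 2 * x[1] + np.sin(x[1])


approx = finite_hessian_sketch(example_fn, [1.0, 2.0])
exact = np.array([[4.0, 2.0], [2.0, -np.sin(2.0)]])
```

The step size trades truncation error (which shrinks with delta) against floating-point cancellation (which grows as delta shrinks), so a tolerance such as places=6 has to be chosen together with delta.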
64-133: Well-designed test for comparing analytical and numerical Hessians.
The test thoroughly compares the analytical Hessian computed by the model with the numerical Hessian from finite differences, which is an excellent validation approach.
🧰 Tools
🪛 Ruff (0.8.2)
113-113: Function definition does not bind loop variable ii (B023)
114-114: Function definition does not bind loop variable ii (B023)
115-115: Function definition does not bind loop variable ii (B023)
116-116: Function definition does not bind loop variable ii (B023)
135-163: Good test setup for Hessian model.
The test class properly sets up both Hessian-enabled and regular models using the same parameters, which is essential for accurate comparison.
164-171: Good validation of output definitions.
The tests correctly verify that the Hessian-related output definitions are enabled for the Hessian model and disabled for the regular value model.
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Actionable comments posted: 0
🧹 Nitpick comments (2)
source/tests/jax/test_make_hessian_model.py (1)
172-178: Remove duplicated assertions.
These assertions are duplicated from lines 165-171. Remove the duplicated code to maintain cleaner tests.
    - self.assertTrue(self.model_hess.atomic_output_def()["energy"].r_hessian)
    - self.assertFalse(self.model_valu.atomic_output_def()["energy"].r_hessian)
    - self.assertTrue(self.model_hess.model_output_def()["energy"].r_hessian)
    - self.assertEqual(
    -     self.model_hess.model_output_def()["energy_derv_r_derv_r"].category,
    -     OutputVariableCategory.DERV_R_DERV_R,
    - )
source/tests/jax/test_dp_hessian_model.py (1)
107-117: Remove duplicated test code.
Lines 107-117 duplicate the assertions and function calls from lines 96-106. Remove the duplication to maintain clean, maintainable test code.
    - np.testing.assert_allclose(
    -     to_numpy_array(ret0["energy_derv_r_derv_r"]),
    -     to_numpy_array(ret1["energy_derv_r_derv_r"]),
    -     atol=self.atol,
    - )
    - ret0 = md0.call(*args, do_atomic_virial=True)
    - ret1 = md1.call(*args, do_atomic_virial=True)
    - np.testing.assert_allclose(
    -     to_numpy_array(ret0["energy_derv_c"]),
    -     to_numpy_array(ret1["energy_derv_c"]),
    -     atol=self.atol,
    - )
🔇 Additional comments (1)
source/tests/jax/test_make_hessian_model.py (1)
108-121: Address loop variable binding issue.
The nested function definitions inside the loop don't properly capture the loop variable ii, which could lead to unexpected behavior.
Apply this diff to fix the issue:
      def np_infer(
          xx,
    +     _ii=ii,  # Capture the loop variable explicitly
      ):
          ret = self.model_valu(
              to_jax_array(xx)[None, ...],
    -         atype[ii][None, ...],
    -         box=cell[ii][None, ...],
    -         fparam=fparam[ii][None, ...],
    -         aparam=aparam[ii][None, ...],
    +         atype[_ii][None, ...],
    +         box=cell[_ii][None, ...],
    +         fparam=fparam[_ii][None, ...],
    +         aparam=aparam[_ii][None, ...],
          )
🧰 Tools
🪛 Ruff (0.8.2)
116-116: Function definition does not bind loop variable ii (B023)
117-117: Function definition does not bind loop variable ii (B023)
118-118: Function definition does not bind loop variable ii (B023)
119-119: Function definition does not bind loop variable ii (B023)
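The B023 warning above is about Python's late-binding closures. The following standalone sketch (function names are hypothetical and unrelated to the test file) shows the difference between a closure that looks up the loop variable when called and one that captures it through a default argument, as the suggested diff does.

```python
def make_callbacks_late(n):
    """Closures look up ii at call time, after the loop has finished."""
    callbacks = []
    for ii in range(n):
        def cb():
            return ii  # resolved when cb() runs, not when it is defined
        callbacks.append(cb)
    return callbacks


def make_callbacks_bound(n):
    """A default argument freezes the current value of ii per iteration."""
    callbacks = []
    for ii in range(n):
        def cb(_ii=ii):
            return _ii
        callbacks.append(cb)
    return callbacks


late = [cb() for cb in make_callbacks_late(3)]
bound = [cb() for cb in make_callbacks_bound(3)]
```

Here late evaluates to [2, 2, 2] and bound to [0, 1, 2]. When a closure is only called inside the iteration that defines it, both forms return the same values, so the warning matters mainly if the function escapes the loop.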
Signed-off-by: Jinzhe Zeng <[email protected]>
Actionable comments posted: 0
🔭 Outside diff range comments (1)
source/tests/jax/test_make_hessian_model.py (1)
172-178: ⚠️ Potential issue
Remove duplicated test assertions.
Lines 172-178 are exact duplicates of the assertions already present in lines 165-171. These duplicate assertions should be removed to avoid redundancy and potential maintenance issues.
      self.assertEqual(
          self.model_hess.model_output_def()["energy_derv_r_derv_r"].category,
          OutputVariableCategory.DERV_R_DERV_R,
      )
    - self.assertTrue(self.model_hess.atomic_output_def()["energy"].r_hessian)
    - self.assertFalse(self.model_valu.atomic_output_def()["energy"].r_hessian)
    - self.assertTrue(self.model_hess.model_output_def()["energy"].r_hessian)
    - self.assertEqual(
    -     self.model_hess.model_output_def()["energy_derv_r_derv_r"].category,
    -     OutputVariableCategory.DERV_R_DERV_R,
    - )
🧹 Nitpick comments (1)
source/tests/jax/test_make_hessian_model.py (1)
64-133: Consider parametrizing the test for different configurations.
The test method in HessianTest currently tests with fixed parameters like places=6 and delta=1e-3. To make the test more robust, consider parametrizing it to run with different test configurations.
You could use pytest's parametrize decorator if you switch to pytest, or create multiple test methods with different parameters:
    def test_with_higher_precision(self) -> None:
        # Run the same test with higher precision requirements
        self._run_test(places=8, delta=1e-4)

    def test_with_larger_system(self) -> None:
        # Run with more atoms to test scaling
        self._run_test(extra_atoms=3)

    def _run_test(self, places=6, delta=1e-3, extra_atoms=0) -> None:
        # Current test implementation with parametrized values
        # ...
🧰 Tools
🪛 Ruff (0.8.2)
116-116: Function definition does not bind loop variable ii (B023)
117-117: Function definition does not bind loop variable ii (B023)
118-118: Function definition does not bind loop variable ii (B023)
119-119: Function definition does not bind loop variable ii (B023)
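A third option that stays within unittest is subTest, which runs one test body over several configurations and reports each failure separately. The sketch below is only an illustration: ToyHessianTest, second_derivative, and the parameter pairs are hypothetical and do not come from the PR.

```python
import unittest


def second_derivative(f, x, delta):
    """Central second difference of a scalar function of one variable."""
    return (f(x + delta) - 2.0 * f(x) + f(x - delta)) / delta**2


class ToyHessianTest(unittest.TestCase):
    """Hypothetical test that sweeps (places, delta) pairs with subTest."""

    def test_cubic(self) -> None:
        # For x**3 the central second difference has no truncation error,
        # so only floating-point round-off limits the accuracy.
        for places, delta in [(6, 1e-3), (8, 1e-2)]:
            with self.subTest(places=places, delta=delta):
                approx = second_derivative(lambda x: x**3, 2.0, delta)
                self.assertAlmostEqual(approx, 12.0, places=places)
```

Running `python -m unittest` on a module containing this class executes both configurations inside a single test method.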
🔇 Additional comments (3)
source/tests/jax/test_make_hessian_model.py (3)
36-61: LGTM: The finite difference Hessian implementation is correct and well-structured.
The implementation uses a second-order central difference scheme to numerically approximate the Hessian matrix, which is appropriate for testing the analytical Hessian computation from JAX.
108-121: Fix loop variable binding issues.
The static analysis correctly points out that the functions defined inside the loop don't bind the loop variable ii. This could lead to unexpected behavior if the functions were used outside the loop scope.
Apply this fix to explicitly capture the loop variable:
      for ii in range(nf):
          def np_infer(
              xx,
    +         _ii=ii,  # Capture the loop variable explicitly
          ):
              ret = self.model_valu(
                  to_jax_array(xx)[None, ...],
    -             atype[ii][None, ...],
    -             box=cell[ii][None, ...],
    -             fparam=fparam[ii][None, ...],
    -             aparam=aparam[ii][None, ...],
    +             atype[_ii][None, ...],
    +             box=cell[_ii][None, ...],
    +             fparam=fparam[_ii][None, ...],
    +             aparam=aparam[_ii][None, ...],
              )
🧰 Tools
🪛 Ruff (0.8.2)
116-116: Function definition does not bind loop variable ii (B023)
117-117: Function definition does not bind loop variable ii (B023)
118-118: Function definition does not bind loop variable ii (B023)
119-119: Function definition does not bind loop variable ii (B023)
135-163: LGTM: Good test class setup with separate models for comparison.
The setup method correctly initializes two models: one with Hessian enabled and one without, which is a good approach for verifying that the Hessian functionality works as expected without affecting the base model behavior.
Codecov Report
Attention: Patch coverage is …
Additional details and impacted files
@@           Coverage Diff           @@
##            devel    #4649   +/-   ##
=======================================
  Coverage   84.77%   84.78%
=======================================
  Files         688      688
  Lines       66097    66139   +42
  Branches     3539     3538    -1
=======================================
+ Hits        56036    56078   +42
- Misses       8919     8920    +1
+ Partials     1142     1141    -1
Signed-off-by: Jinzhe Zeng <[email protected]>
Actionable comments posted: 0
🧹 Nitpick comments (1)
.github/workflows/test_cuda.yml (1)
66-67: Review the Addition of the New Memory Allocator Environment Variable
The newly added environment variable XLA_PYTHON_CLIENT_ALLOCATOR: platform is intended to configure the memory allocation strategy for JAX during CUDA tests. Please ensure that the value "platform" is properly documented in the JAX documentation and that it harmonizes with the current settings (e.g., XLA_PYTHON_CLIENT_PREALLOCATE: false) to avoid unintended memory allocation issues.
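For readers reproducing the CUDA test environment outside the workflow, the two variables named above can be passed to a child process as in the sketch below. The helper name jax_cuda_test_env is hypothetical; the variable names and values are the ones quoted in the comment. As far as the JAX GPU memory documentation describes them, "platform" makes JAX allocate on demand and release memory that is no longer needed, at some cost in speed, and PREALLOCATE=false turns off the up-front GPU memory reservation.

```python
import os


def jax_cuda_test_env(base=None):
    """Return a copy of an environment with the JAX allocator settings.

    base defaults to the current process environment; the input is not
    modified.
    """
    env = dict(os.environ if base is None else base)
    env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    env["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
    return env


# Typical use (an assumption, not taken from the workflow):
# subprocess.run(["python", "-m", "pytest", "source/tests"], env=jax_cuda_test_env())
```

These variables are read when the JAX backend starts, so they need to be set before the process that imports jax is launched.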
Signed-off-by: Jinzhe Zeng <[email protected]>
Actionable comments posted: 1
♻️ Duplicate comments (1)
deepmd/dpmodel/model/transform_output.py (1)
116-195: 🛠️ Refactor suggestion
Consider modularizing the complex Hessian scattering logic
The Hessian calculation logic involves multiple reshaping, tiling, and scatter sum operations that are repeated with slight variations. This complex code would benefit from modularization into smaller, well-documented helper functions.
For example, the two scatter sum operations (lines 146-151 and 169-174) follow a similar pattern and could be extracted into a function.
    + def _scatter_hessian_component(xp, hessian, mapping, mldims, hess_input, shape_info):
    +     """Helper function to scatter hessian components.
    +
    +     Parameters
    +     ----------
    +     xp : module
    +         Array namespace
    +     hessian : ndarray
    +         Target array for the scattered hessian component
    +     mapping : ndarray
    +         Mapping array
    +     mldims : list
    +         Dimensions of the mapping array
    +     hess_input : ndarray
    +         Input hessian data to scatter
    +     shape_info : tuple
    +         Tuple containing shape information for tiling (nall/nloc, vdef.shape, etc.)
    +
    +     Returns
    +     -------
    +     ndarray
    +         Scattered hessian component
    +     """
    +     target_dim, vdef_shape = shape_info[0], shape_info[1:]
    +     mapping_hess = xp.reshape(
    +         mapping, (mldims + [1] * (len(vdef_shape) + 3))
    +     )
    +     mapping_hess = xp.tile(
    +         mapping_hess,
    +         [1] * len(mldims) + [target_dim, *vdef_shape, 3, 3],
    +     )
    +     return xp_scatter_sum(
    +         hessian,
    +         1,
    +         mapping_hess,
    +         hess_input,
    +     )
Also, consider adding more comments explaining the dimensions and the purpose of each transformation step.
🧹 Nitpick comments (5)
deepmd/dpmodel/model/transform_output.py (3)
88-88: Consider adding a comment explaining why the mapping copy is needed
Creating a copy of mapping as mapping_ seems unnecessary since it's not being modified. If there's a specific reason for this, please add a comment explaining why.
119-121: Add comments explaining the Hessian dimensions
The dimensions of the Hessian tensor are complex, and it would be helpful to add more detailed comments about what each dimension represents. For example, explain what def_ndim is and how it relates to the output shape.
123-133: Add explanation for the complex dimension permutation
The permutation operation is particularly complex and difficult to understand. Consider adding an explanation of what this permutation is achieving in terms of the physical meaning of the dimensions.
    + # Permute from [nf, *def, nall1, 3, nall2, 3] to [nf, nall1, nall2, *def, 3(1), 3(2)]
    + # This reordering allows us to handle atom indices first, then shape dimensions
deepmd/dpmodel/array_api.py (2)
79-94: Add type annotations for input and dim parameters
For consistency with other functions in this file, consider adding full type annotations to the input and dim parameters.
    -def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray:
    +def xp_scatter_sum(input: np.ndarray, dim: int, index: np.ndarray, src: np.ndarray) -> np.ndarray:
79-94: Document the expected shape relationship between input, index, and src parameters
Consider enhancing the docstring to explain the expected shape relationship between the input, index, and src parameters to help users understand how to use this function correctly.
      def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray:
    -     """Reduces all values from the src tensor to the indices specified in the index tensor."""
    +     """Reduces all values from the src tensor to the indices specified in the index tensor.
    +
    +     Parameters
    +     ----------
    +     input : np.ndarray
    +         The tensor to scatter into
    +     dim : int
    +         The dimension along which to index
    +     index : np.ndarray
    +         The indices tensor that specifies where to scatter the values
    +     src : np.ndarray
    +         The source tensor containing values to be scattered
    +
    +     Returns
    +     -------
    +     np.ndarray
    +         The input tensor with values from src scattered according to index
    +
    +     Notes
    +     -----
    +     The `index` tensor should have the same shape as `src`, and both should
    +     be compatible with the shape of `input` along the specified dimension.
    +     Currently, only JAX arrays are supported.
    +     """
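The shape relationship requested in the docstring comment above is easier to see on a small case. The sketch below is a NumPy illustration of scatter-sum semantics, assuming the behavior matches the torch-style scatter_add convention (out[i][index[i][j]] += src[i][j] for dim=1). It is not the project's xp_scatter_sum, which the proposed docstring says supports only JAX arrays; scatter_sum_sketch is a hypothetical name.

```python
import numpy as np


def scatter_sum_sketch(inp, dim, index, src):
    """Add src into a copy of inp at positions given by index along dim.

    index and src are assumed to have the same shape.
    """
    out = np.array(inp, dtype=float, copy=True)
    # Open-mesh coordinates for every element of index, one array per axis.
    coords = list(np.indices(index.shape, sparse=True))
    # Along dim, replace the element's own coordinate with its target index.
    coords[dim] = index
    # np.add.at accumulates correctly when target indices repeat.
    np.add.at(out, tuple(coords), src)
    return out


index = np.array([[0, 0, 2], [1, 1, 1]])
src = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
result = scatter_sum_sketch(np.zeros((2, 3)), 1, index, src)
```

In the first row the values 1 and 2 both land in column 0 and are summed, while in the second row all three values accumulate in column 1. Repeated target indices of this kind are what let contributions from several entries be folded onto one slot through a mapping array.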
🔇 Additional comments (1)
deepmd/dpmodel/model/transform_output.py (1)
106-111: Great use of abstraction with xp_scatter_sum
Replacing the JAX-specific implementation with an abstracted function improves the code maintainability and consistency.