feat(pt_expt): atomic model #5220
Conversation
Summary of Changes

Hello @wanghan-iapcm, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly advances the DeepMD-kit's PyTorch-experimental (pt_expt) backend.
Code Review
This pull request introduces a significant and well-structured feature: a new experimental PyTorch backend (pt_expt) designed to be more idiomatic with torch.nn.Module and exportable via torch.export. The changes also include a substantial refactoring of core dpmodel components to be backend-agnostic using array_api_compat, which is a great step towards better code structure and maintainability. The addition of comprehensive consistency tests for the new backend is commendable. My review focuses on a couple of areas in the new backend-agnostic logic where code duplication can be reduced to improve clarity and maintainability.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: fb08ffca5b
📝 Walkthrough

Adds a bias/statistics workflow to the DP atomic model: new BaseAtomicModel methods for computing/loading/applying output biases, a new statistics utility module with I/O and computation routines, expanded tensor-to-numpy fallback, PyTorch-exportable atomic model wrappers, and comprehensive unit tests for these flows.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Base as BaseAtomicModel
    participant Stat as stat.compute_output_stats
    participant Wrapper as ForwardWrapper
    participant Model as ModelForward
    participant FS as StatFile (DPPath)
    Base->>Base: change_out_bias(sample_merged, stat_file_path, mode)
    Base->>Stat: compute_output_stats(merged, ntypes, keys, stat_file_path, model_forward=wrapper)
    Stat->>FS: _restore_from_file(stat_file_path) (if provided)
    Stat->>Wrapper: request predictions for samples
    Wrapper->>Model: forward(converted inputs / built nlist)
    Model->>Wrapper: predictions (numpy arrays)
    Wrapper->>Stat: return numpy predictions
    Stat->>Stat: compute bias/std (global & atomic), merge/fill, optional save to file
    Stat->>Base: return out_bias, out_std
    Base->>Base: _store_out_stat(out_bias, out_std, add or set)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 2 failed (1 warning, 1 inconclusive)
No actionable comments were generated in the recent review.
Actionable comments posted: 7
🤖 Fix all issues with AI agents
In `@deepmd/dpmodel/atomic_model/base_atomic_model.py`:
- Around line 259-261: The docstring for `get_compute_stats_distinguish_types` contradicts its return value. Rephrase it (removing the word "not") to state that the method returns True when the fitting net computes statistics that are distinguished between different atom types, so the text matches both the method name and the True return value.
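A possible corrected docstring (a sketch; the `return True` body follows from the description above):

```python
def get_compute_stats_distinguish_types(self) -> bool:
    """Whether the fitting net computes statistics that are
    distinguished between different atom types.

    Returns
    -------
    bool
        True if output statistics are computed per atom type.
    """
    return True
```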
- Around line 347-367: `_store_out_stat` uses `np.copy` on `self.out_bias` and `self.out_std`, which breaks for the `torch.Tensor` buffers used by the pt_expt backend. Replace those `np.copy` calls with safe conversions through the existing `to_numpy_array` helper (e.g., `out_bias_data = to_numpy_array(self.out_bias).copy()` and `out_std_data = to_numpy_array(self.out_std).copy()`) so both numpy arrays and torch tensors are handled correctly before mutating. Keep the rest of the `_store_out_stat` logic intact and continue to assign the final numpy arrays back to `self.out_bias`/`self.out_std` as before.
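A minimal sketch of the suggested change; the surrounding method body and the import location of `to_numpy_array` are assumptions:

```python
# Inside BaseAtomicModel._store_out_stat; everything else is unchanged.
# to_numpy_array accepts both numpy arrays and torch.Tensor buffers, so
# this also works for the pt_expt backend, where np.copy would break.
out_bias_data = to_numpy_array(self.out_bias).copy()
out_std_data = to_numpy_array(self.out_std).copy()
# ... mutate out_bias_data / out_std_data as before ...
self.out_bias = out_bias_data
self.out_std = out_std_data
```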
In `@deepmd/dpmodel/utils/stat.py`:
- Around line 254-263: The comprehension building `model_pred_g` (and similarly `model_pred_a`) can raise `KeyError` because it indexes `global_sampled_idx[kk]` directly. Use `global_sampled_idx.get(kk, [])` instead, so missing keys yield an empty list and the inner list comprehension becomes empty rather than crashing; keep the existing `np.sum(vv[idx], axis=1)` logic for each `idx`.
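A sketch of the guarded lookup; the exact shape of the surrounding comprehension is reconstructed from the description and may differ from the real code:

```python
# Missing keys now yield an empty list instead of raising KeyError.
model_pred_g = {
    kk: [np.sum(vv[idx], axis=1) for idx in global_sampled_idx.get(kk, [])]
    for kk, vv in model_pred.items()
}
# The comprehension for model_pred_a gets the same .get(kk, []) guard.
```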
In `@deepmd/dpmodel/utils/type_embed.py`:
- Around line 210-222: The call to `np.random.default_rng().random(...)` passes `PRECISION_DICT[self.precision]` directly as `dtype`, which fails for unsupported types like `np.float16`. Generate the random array with a supported numpy float (e.g., `np.float32`) and then cast to the target dtype when converting to the array backend: update the block that creates `extend_type_params` so that the `xp.asarray` call matches `first_layer_matrix.dtype` and device (via `xp.asarray(..., dtype=first_layer_matrix.dtype, device=array_api_compat.device(first_layer_matrix))`), ensuring compatibility for precisions such as "float16"/"half".
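A sketch of the fix; `shape_of_params` stands in for the actual shape expression, and the rest follows the suggestion above:

```python
# np.random.Generator.random only supports float32/float64, so generate
# in float32 and let xp.asarray cast to the target dtype; this keeps
# "float16"/"half" precisions working.
raw = np.random.default_rng().random(shape_of_params, dtype=np.float32)
extend_type_params = xp.asarray(
    raw,
    dtype=first_layer_matrix.dtype,
    device=array_api_compat.device(first_layer_matrix),
)
```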
In `@source/tests/consistent/fitting/test_ener.py`:
- Around line 38-44: The `INSTALLED_PT_EXPT` branch may run without `INSTALLED_PT`, so `import torch` inside that block to avoid a `NameError` in the `eval_pt_expt` methods, which call `torch.from_numpy`. Update the block that defines `EnerFittingPTExpt` and `PT_EXPT_DEVICE` so that `torch` is in scope for both `eval_pt_expt` implementations and any `PT_EXPT_DEVICE`-dependent code.
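A sketch of the guarded import:

```python
if INSTALLED_PT_EXPT:
    # This branch can run even when INSTALLED_PT is False, and the
    # eval_pt_expt methods below call torch.from_numpy.
    import torch
```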
In `@source/tests/pt_expt/atomic_model/test_dp_atomic_model.py`:
- Around line 115-159: Add an inline comment above the `torch.export.export(..., strict=False)` call in `test_exportable` clarifying why `strict=False` is used (to handle dynamic shapes and dict-returning models), consistent with the pattern in `test_fitting_invar_fitting.py`. A brief reference to `md0` returning a dict and producing dynamically shaped outputs will help future readers understand the non-strict export choice.
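A sketch of the suggested comment; `args` is a placeholder for the test's example inputs:

```python
# Non-strict export: md0 returns a dict and its outputs have dynamic
# shapes, which strict tracing rejects. Mirrors the pattern used in
# test_fitting_invar_fitting.py.
exported = torch.export.export(md0, args, strict=False)
```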
In `@source/tests/pt_expt/fitting/test_fitting_invar_fitting.py`:
- Around line 168-201: The three tests wrap calls to `ifn0(...)` in `with self.assertRaises(ValueError) as context:` but place `self.assertIn(...)` inside the `with` block, so those checks never run. Move each `self.assertIn(...)` to immediately after its corresponding `with` block and assert on `str(context.exception)`; apply this to the blocks around the first `ifn0` call, the second one when `nfp > 0`, and the third one when `nap > 0`.
🧹 Nitpick comments (13)
deepmd/pt_expt/atomic_model/energy_atomic_model.py (1)
14-21: Docstring claims validation that isn't implemented. The docstring says this class "validates the fitting is an EnergyFittingNet or InvarFitting," but the body is `pass`. If this validation is intentionally deferred, consider updating the docstring to reflect that (e.g., "placeholder for future validation" or "specialization for energy models"). Otherwise, add the fitting-type check in `__init__`.
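If the check is added instead, a sketch of what `__init__` could enforce (the attribute and class names are assumptions based on the comment above):

```python
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    # Enforce what the docstring promises: only energy-style fittings.
    if not isinstance(self.fitting, (EnergyFittingNet, InvarFitting)):
        raise TypeError(
            f"expected EnergyFittingNet or InvarFitting, "
            f"got {type(self.fitting).__name__}"
        )
```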
deepmd/pt_expt/utils/type_embed.py (1)

15-15: Remove unused `noqa` directive. Ruff reports that `F401` isn't enabled, so `# noqa: F401` is unnecessary. The side-effect import comment is sufficient to explain the intent.

Proposed fix

```diff
-from deepmd.pt_expt.utils import network  # noqa: F401
+from deepmd.pt_expt.utils import network  # ensure EmbeddingNet is registered
```

deepmd/pt_expt/fitting/invar_fitting.py (1)
23-27: Potentially redundant `nets` conversion on line 27. Since `__setattr__` routes through `dpmodel_setattr`, which auto-converts `NativeOP` instances via the registry, `self.nets` should already be a pt_expt `NetworkCollection` after `InvarFittingDP.__init__` completes. The explicit `NetworkCollection.deserialize(self.nets.serialize())` on line 27 then performs a redundant serialize→deserialize round-trip.

This is harmless (it acts as a safety net) and the pattern may be intentional, but it is worth noting if you want to trim the overhead.
deepmd/dpmodel/atomic_model/base_atomic_model.py (1)
369-443: `_get_forward_wrapper_func`: the device/no-device branches have significant duplication. The two branches (lines 391–404 vs 405–417) differ only by the `device=device` kwarg. A small helper could reduce this, though the current code is correct and readable.

Example consolidation

```diff
-        device = getattr(ref_array, "device", None)
-        if device is not None:
-            # For torch tensors
-            coord = xp.asarray(coord, device=device)
-            atype = xp.asarray(atype, device=device)
-            if box is not None:
-                # Check if box is all zeros before converting
-                if np.allclose(box, 0.0):
-                    box = None
-                else:
-                    box = xp.asarray(box, device=device)
-            if fparam is not None:
-                fparam = xp.asarray(fparam, device=device)
-            if aparam is not None:
-                aparam = xp.asarray(aparam, device=device)
-        else:
-            # For numpy arrays
-            coord = xp.asarray(coord)
-            atype = xp.asarray(atype)
-            if box is not None:
-                if np.allclose(box, 0.0):
-                    box = None
-                else:
-                    box = xp.asarray(box)
-            if fparam is not None:
-                fparam = xp.asarray(fparam)
-            if aparam is not None:
-                aparam = xp.asarray(aparam)
+        device = getattr(ref_array, "device", None)
+        dev_kw = {"device": device} if device is not None else {}
+
+        def _to_xp(arr):
+            return xp.asarray(arr, **dev_kw)
+
+        coord = _to_xp(coord)
+        atype = _to_xp(atype)
+        if box is not None:
+            if np.allclose(box, 0.0):
+                box = None
+            else:
+                box = _to_xp(box)
+        if fparam is not None:
+            fparam = _to_xp(fparam)
+        if aparam is not None:
+            aparam = _to_xp(aparam)
```

source/tests/pt_expt/fitting/test_fitting_stat.py (1)
47-72: `_brute_fparam_pt` and `_brute_aparam_pt` are identical except for the dict key. These two functions could be a single helper parameterized by key name, reducing duplication. Minor nit for test utility code.

♻️ Optional: consolidate into a single helper

```diff
-def _brute_fparam_pt(data, ndim):
-    adata = [ii["fparam"] for ii in data]
-    all_data = []
-    for ii in adata:
-        tmp = np.reshape(ii, [-1, ndim])
-        if len(all_data) == 0:
-            all_data = np.array(tmp)
-        else:
-            all_data = np.concatenate((all_data, tmp), axis=0)
-    avg = np.average(all_data, axis=0)
-    std = np.std(all_data, axis=0)
-    return avg, std
-
-
-def _brute_aparam_pt(data, ndim):
-    adata = [ii["aparam"] for ii in data]
+def _brute_param_pt(data, ndim, key):
+    adata = [ii[key] for ii in data]
     all_data = []
     for ii in adata:
         tmp = np.reshape(ii, [-1, ndim])
```
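The diff above is cut off in this view; a complete sketch of the consolidated helper, preserving the brute-force logic, could look like this:

```python
def _brute_param_pt(data, ndim, key):
    """Single helper parameterized by the dict key ('fparam' or 'aparam')."""
    all_data = np.concatenate(
        [np.reshape(ii[key], [-1, ndim]) for ii in data], axis=0
    )
    return np.average(all_data, axis=0), np.std(all_data, axis=0)
```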
deepmd/dpmodel/utils/stat.py (2)

131-140: `np.nan_to_num` silently replaces residual NaN with 0. After the `np.where`, any positions where both `atomic_stat` and `global_stat` are NaN remain NaN; `np.nan_to_num` then maps them to 0. If this is intentional (e.g., treating unobserved types as zero bias), it's worth a brief inline comment to make the intent clear. If not, it could mask a data-quality problem.
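A sketch of such a comment; the variable names are placeholders reconstructed from the description:

```python
# Positions that are NaN in both atomic_stat and global_stat survive the
# np.where merge as NaN; nan_to_num then maps them to 0, i.e. unobserved
# types are deliberately treated as zero bias.
merged = np.where(np.isnan(atomic_stat), global_stat, atomic_stat)
merged = np.nan_to_num(merged)
```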
528-540: `missing_types` only accounts for types beyond `max(atype)`, not gaps. If atom types 0 and 2 are present but type 1 is missing, `compute_stats_from_atomic` returns rows for types 0–2 and this padding only appends types beyond 2. Types in the gap would get NaN from `compute_stats_from_atomic` and are later filled by `_fill_stat_with_global`, so the overall pipeline is correct. Just flagging for clarity: a comment here would help future readers.

Also, the `dtype is` comparison on line 531 is fragile; prefer `==`.

Minor: use `==` for dtype comparison

```diff
-        assert bias_atom_e[kk].dtype is std_atom_e[kk].dtype, (
+        assert bias_atom_e[kk].dtype == std_atom_e[kk].dtype, (
             "bias and std should be of the same dtypes"
         )
```

source/tests/pt_expt/atomic_model/test_atomic_model_global_stat.py (2)
47-151: `FooFitting` is duplicated across test files. This class is nearly identical to the `FooFitting` in `test_atomic_model_atomic_stat.py` (the only difference is the addition of the `pix` output). Consider extracting a shared base or parameterized test fixture to reduce copy-paste across the two test modules.

196-199: Unused `f` in `h5py.File` context manager. Ruff flags this (F841). A simple `_` would suppress it without changing behavior:

```diff
-        with h5py.File(h5file, "w") as f:
+        with h5py.File(h5file, "w") as _:
```

This pattern repeats at lines 587 and 697 as well.
source/tests/pt_expt/atomic_model/test_atomic_model_atomic_stat.py (1)
40-127: `FooFitting` duplicates the one in `test_atomic_model_global_stat.py` (minus `pix`). As noted in the other file, extracting common test fixtures would reduce maintenance burden. Not blocking.
source/tests/pt_expt/atomic_model/test_dp_atomic_model.py (1)
161-231: `test_excl_consistency`: the "hacking!" comment (line 189) deserves a brief explanation. The test calls `reinit_atom_exclude`/`reinit_pair_exclude` on `md0` but uses different method names on `md1`. A one-line comment explaining why would help future maintainers.

source/tests/pt_expt/descriptor/test_se_t.py (1)
59-69: Prefix unused unpacked variables with `_`. `gr1` (line 65) and `gr2` (line 85) are unpacked but never used. Ruff flagged these (RUF059). Since se_t returns `None` for `gr`, you can prefix them.

Suggested fix

```diff
-        rd1, gr1, _, _, sw1 = dd1(
+        rd1, _gr1, _, _, sw1 = dd1(
```

```diff
-        rd2, gr2, _, _, sw2 = dd2.call(
+        rd2, _gr2, _, _, sw2 = dd2.call(
```

source/tests/pt_expt/descriptor/test_se_t_tebd.py (1)
92-93: TODO: `gr` is `None`; worth tracking. The comment notes that `gr` is `None` and warrants investigation. Consider opening an issue to track this so it doesn't get lost.

Would you like me to open an issue to track the `gr`-being-`None` investigation?
Check notice (Code scanning / CodeQL): Unused local variable (test). Raised ten times on unpacking lines of the form `nf, nloc, nnei = self.nlist.shape`: in `test_serialize` (3×), `test_change_by_statistic` (2×), `test_dp_consistency` (2×), and `test_exportable` (3×).
Codecov Report

❌ Patch coverage is …

Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #5220      +/-   ##
==========================================
+ Coverage   82.07%   82.12%   +0.05%
==========================================
  Files         732      736       +4
  Lines       73974    74237     +263
  Branches     3615     3616       +1
==========================================
+ Hits        60711    60967     +256
- Misses      12100    12107       +7
  Partials     1163     1163
```

☔ View full report in Codecov by Sentry.
Summary by CodeRabbit

- New Features
- Bug Fixes
- Tests