Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ jobs:
CUDA_VISIBLE_DEVICES: 0
# See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
XLA_PYTHON_CLIENT_PREALLOCATE: false
XLA_PYTHON_CLIENT_ALLOCATOR: platform
- name: Convert models
run: source/tests/infer/convert-models.sh
- name: Download libtorch
Expand Down
19 changes: 19 additions & 0 deletions deepmd/dpmodel/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""Utilities for the array API."""

import array_api_compat
import numpy as np
from packaging.version import (
Version,
)
Expand Down Expand Up @@ -73,3 +74,21 @@
out = xp.take(arr, indices)
out = xp.reshape(out, shape)
return xp_swapaxes(out, axis, -1)


def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray:
"""Reduces all values from the src tensor to the indices specified in the index tensor."""
# jax only
if array_api_compat.is_jax_array(input):
from deepmd.jax.common import (
scatter_sum,
)

return scatter_sum(
input,
dim,
index,
src,
)
else:
raise NotImplementedError("Only JAX arrays are supported.")

Check warning on line 94 in deepmd/dpmodel/array_api.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/array_api.py#L94

Added line #L94 was not covered by tests
19 changes: 19 additions & 0 deletions deepmd/dpmodel/model/ener_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from copy import (
deepcopy,
)

from deepmd.dpmodel.atomic_model import (
DPEnergyAtomicModel,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
)

from .dp_model import (
DPModelCommon,
Expand All @@ -25,3 +32,15 @@ def __init__(
) -> None:
DPModelCommon.__init__(self)
DPEnergyModel_.__init__(self, *args, **kwargs)
self._enable_hessian = False
self.hess_fitting_def = None

def enable_hessian(self):
self.hess_fitting_def = deepcopy(self.atomic_output_def())
self.hess_fitting_def["energy"].r_hessian = True
self._enable_hessian = True

def atomic_output_def(self) -> FittingOutputDef:
if self._enable_hessian:
return self.hess_fitting_def
return super().atomic_output_def()
105 changes: 91 additions & 14 deletions deepmd/dpmodel/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import array_api_compat
import numpy as np

from deepmd.dpmodel.array_api import (
xp_scatter_sum,
)
from deepmd.dpmodel.common import (
GLOBAL_ENER_FLOAT_PRECISION,
)
Expand All @@ -11,6 +14,7 @@
ModelOutputDef,
OutputVariableDef,
get_deriv_name,
get_hessian_name,
get_reduce_name,
)

Expand Down Expand Up @@ -81,6 +85,7 @@

"""
xp = array_api_compat.get_namespace(mapping)
mapping_ = mapping
new_ret = {}
for kk in model_output_def.keys_outp():
vv = model_ret[kk]
Expand All @@ -98,24 +103,96 @@
mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype)
# jax only
if array_api_compat.is_jax_array(force):
from deepmd.jax.common import (
scatter_sum,
)

force = scatter_sum(
force,
1,
mapping,
model_ret[kk_derv_r],
)
else:
raise NotImplementedError("Only JAX arrays are supported.")
force = xp_scatter_sum(
force,
1,
mapping,
model_ret[kk_derv_r],
)
new_ret[kk_derv_r] = force
else:
# name holders
new_ret[kk_derv_r] = None
if vdef.r_hessian:
kk_hess = get_hessian_name(kk)
if model_ret[kk_hess] is not None:
# [nf, *def, nall, 3, nall, 3]
hess_ = model_ret[kk_hess]
def_ndim = len(vdef.shape)
# [nf, nall1, nall2, *def, 3(1), 3(2)]
hess_1 = xp.permute_dims(
hess_,
(
0,
def_ndim + 1,
def_ndim + 3,
*range(1, def_ndim + 1),
def_ndim + 2,
def_ndim + 4,
),
)
nall = hess_1.shape[1]
# (1) -> [nf, nloc1, nall2, *def, 3(1), 3(2)]
hessian1 = xp.zeros(
[*vldims, nall, *vdef.shape, 3, 3], dtype=vv.dtype
)
mapping_hess = xp.reshape(
mapping_, (mldims + [1] * (len(vdef.shape) + 3))
)
mapping_hess = xp.tile(
mapping_hess,
[1] * len(mldims) + [nall, *vdef.shape, 3, 3],
)
hessian1 = xp_scatter_sum(
hessian1,
1,
mapping_hess,
hess_1,
)
# [nf, nall2, nloc1, *def, 3(1), 3(2)]
hessian1 = xp.permute_dims(
hessian1,
(0, 2, 1, *range(3, def_ndim + 5)),
)
nloc = hessian1.shape[2]
# (2) -> [nf, nloc2, nloc1, *def, 3(1), 3(2)]
hessian = xp.zeros(
[*vldims, nloc, *vdef.shape, 3, 3], dtype=vv.dtype
)
mapping_hess = xp.reshape(
mapping_, (mldims + [1] * (len(vdef.shape) + 3))
)
mapping_hess = xp.tile(
mapping_hess,
[1] * len(mldims) + [nloc, *vdef.shape, 3, 3],
)
hessian = xp_scatter_sum(
hessian,
1,
mapping_hess,
hessian1,
)
# -> [nf, *def, nloc1, 3(1), nloc2, 3(2)]
hessian = xp.permute_dims(
hessian,
(
0,
*range(3, def_ndim + 3),
2,
def_ndim + 3,
1,
def_ndim + 4,
),
)
# -> [nf, *def nloc1 * 3, nloc2 * 3]
hessian = xp.reshape(
hessian,
(hessian.shape[0], *vdef.shape, nloc * 3, nloc * 3),
)

new_ret[kk_hess] = hessian
else:
new_ret[kk_hess] = None

Check warning on line 195 in deepmd/dpmodel/model/transform_output.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/transform_output.py#L195

Added line #L195 was not covered by tests
if vdef.c_differentiable:
assert vdef.r_differentiable
if model_ret[kk_derv_c] is not None:
Expand Down
13 changes: 13 additions & 0 deletions deepmd/jax/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)
from deepmd.dpmodel.output_def import (
get_deriv_name,
get_hessian_name,
get_reduce_name,
)
from deepmd.jax.env import (
Expand Down Expand Up @@ -87,6 +88,18 @@ def eval_output(
)

model_predict[kk_derv_r] = extended_force
if vdef.r_hessian:
# [nf, *def, nall, 3, nall, 3]
hessian = jax.vmap(jax.hessian(eval_output, argnums=0))(
extended_coord,
extended_atype,
nlist,
mapping,
fparam,
aparam,
)
kk_hessian = get_hessian_name(kk)
model_predict[kk_hessian] = hessian
if vdef.c_differentiable:
assert vdef.r_differentiable
# avr: [nf, *def, nall, 3, 3]
Expand Down
117 changes: 117 additions & 0 deletions source/tests/jax/test_dp_hessian_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import sys
import unittest

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)

if sys.version_info >= (3, 10):
from deepmd.jax.common import (
to_jax_array,
)
from deepmd.jax.descriptor.se_e2_a import (
DescrptSeA,
)
from deepmd.jax.env import (
jnp,
)
from deepmd.jax.fitting.fitting import (
EnergyFittingNet,
)
from deepmd.jax.model.ener_model import (
EnergyModel,
)

dtype = jnp.float64


@unittest.skipIf(
sys.version_info < (3, 10),
"JAX requires Python 3.10 or later",
)
class TestCaseSingleFrameWithoutNlist:
def setUp(self) -> None:
# nloc == 3, nall == 4
self.nloc = 3
self.nf, self.nt = 1, 2
self.coord = np.array(
[
[0, 0, 0],
[0, 1, 0],
[0, 0, 1],
],
dtype=np.float64,
).reshape([1, self.nloc * 3])
self.atype = np.array([0, 0, 1], dtype=int).reshape([1, self.nloc])
self.cell = 2.0 * np.eye(3).reshape([1, 9])
# sel = [5, 2]
self.sel = [16, 8]
self.sel_mix = [24]
self.natoms = [3, 3, 2, 1]
self.rcut = 2.2
self.rcut_smth = 0.4
self.atol = 1e-12


@unittest.skipIf(
sys.version_info < (3, 10),
"JAX requires Python 3.10 or later",
)
class TestEnergyHessianModel(unittest.TestCase, TestCaseSingleFrameWithoutNlist):
def setUp(self):
TestCaseSingleFrameWithoutNlist.setUp(self)

def test_self_consistency(self):
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
)
ft = EnergyFittingNet(
self.nt,
ds.get_dim_out(),
mixed_types=ds.mixed_types(),
)
type_map = ["foo", "bar"]
md0 = EnergyModel(ds, ft, type_map=type_map)
md1 = EnergyModel.deserialize(md0.serialize())
md0.enable_hessian()
md1.enable_hessian()
args = [to_jax_array(ii) for ii in [self.coord, self.atype, self.cell]]
ret0 = md0.call(*args)
ret1 = md1.call(*args)
np.testing.assert_allclose(
to_numpy_array(ret0["energy"]),
to_numpy_array(ret1["energy"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_redu"]),
to_numpy_array(ret1["energy_redu"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_r"]),
to_numpy_array(ret1["energy_derv_r"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_c_redu"]),
to_numpy_array(ret1["energy_derv_c_redu"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_r_derv_r"]),
to_numpy_array(ret1["energy_derv_r_derv_r"]),
atol=self.atol,
)
ret0 = md0.call(*args, do_atomic_virial=True)
ret1 = md1.call(*args, do_atomic_virial=True)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_c"]),
to_numpy_array(ret1["energy_derv_c"]),
atol=self.atol,
)
Loading