Full stack trace:

```
src/torchjd/autojac/_mtl_backward.py:128: in mtl_backward
    backward_transform(grad_tensors_dict)
src/torchjd/autojac/_transform/_base.py:79: in __call__
    return self.outer(intermediate)
           ^^^^^^^^^^^^^^^^^^^^^^^^
src/torchjd/autojac/_transform/_base.py:78: in __call__
    intermediate = self.inner(input)
                   ^^^^^^^^^^^^^^^^^
src/torchjd/autojac/_transform/_differentiate.py:43: in __call__
    differentiated_tuple = self._differentiate(tensor_outputs)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
src/torchjd/autojac/_transform/_jac.py:91: in _differentiate
    jacs_chunks.append(_get_jacs_chunk(jac_outputs_chunk, get_vjp_last))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
src/torchjd/autojac/_transform/_jac.py:119: in _get_jacs_chunk
    return torch.vmap(get_vjp, chunk_size=chunk_size)(jac_outputs_chunk)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.14/site-packages/torch/_functorch/apis.py:220: in wrapped
    return vmap_impl(
.venv/lib/python3.14/site-packages/torch/_functorch/vmap.py:305: in vmap_impl
    return _chunked_vmap(
.venv/lib/python3.14/site-packages/torch/_functorch/vmap.py:445: in _chunked_vmap
    _flat_vmap(
.venv/lib/python3.14/site-packages/torch/_functorch/vmap.py:507: in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
src/torchjd/autojac/_transform/_differentiate.py:68: in _get_vjp
    optional_grads = torch.autograd.grad(
.venv/lib/python3.14/site-packages/torch/autograd/__init__.py:530: in grad
    result = _engine_run_backward(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

t_outputs = (tensor([[0.0000, 0.0000, 0.0526],
        [0.3484, 0.1794, 0.1005],
        [0.3232, 0.4607, 0.0777],
        [0.1349... [0.3950, 0.0359, 0.0000],
        [0.3462, 0.3726, 0.1592]], dtype=torch.float16,
       grad_fn=<ReluBackward0>),)
args = ((BatchedTensor(lvl=1, bdim=0, value=
    tensor([[[-8.0640e+03, -8.4720e+03, -8.5920e+03],
             [ 1.1050e+03,...-0.1621, 0.0535, -0.2953, -0.2285, -0.1630, 0.1995, 0.1854,
        -0.1402, -0.0114]], requires_grad=True)), True)
kwargs = {'accumulate_grad': False}, attach_logging_hooks = False

    def _engine_run_backward(
        t_outputs: Sequence[torch.Tensor | GradientEdge],
        *args: Any,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, ...]:
        attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
        if attach_logging_hooks:
            unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
        # Need to save the context so compiler config will be visible in device threads
        torch._C._stash_obj_in_tls("context", contextvars.copy_context())
        try:
>           return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
                t_outputs, *args, **kwargs
            )  # Calls into the C++ engine to run the backward pass
E           RuntimeError: expected scalar type Half but found Float
```
If we take the AMP usage example (https://torchjd.org/stable/examples/amp/) and just add a single `BatchNorm1d` or `LayerNorm` to the model, we get `RuntimeError: expected scalar type Half but found Float` (full stack trace above).
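A quick way to see where two precisions meet is to inspect dtypes around a normalization layer under autocast. The sketch below is a torch-only illustration, not the original reproducer; the toy `Linear` → `BatchNorm1d` model, its sizes, and CPU `bfloat16` autocast are assumptions chosen so it runs anywhere.

```python
import torch
from torch import nn

# Hypothetical toy model: a Linear layer feeding a BatchNorm1d (sizes arbitrary).
linear = nn.Linear(10, 5)
norm = nn.BatchNorm1d(5)
x = torch.randn(16, 10)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # linear is on autocast's lower-precision op list, so h is bfloat16.
    h = linear(x)

# Autocast casts the inputs of eligible ops on the fly; it never modifies
# module state, so the norm layer's parameters and buffers keep their dtypes.
param_dtypes = {name: p.dtype for name, p in norm.named_parameters()}
buffer_dtypes = {name: b.dtype for name, b in norm.named_buffers()}

print(h.dtype)        # torch.bfloat16
print(param_dtypes)   # {'weight': torch.float32, 'bias': torch.float32}
print(buffer_dtypes)  # {'running_mean': torch.float32, 'running_var': torch.float32, 'num_batches_tracked': torch.int64}
```

So the activation entering the norm layer is low precision while its affine parameters and running statistics stay in float32. This hints at (but does not prove) where mixed dtypes can enter the vmapped backward.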
A simple reproducer is the following:
Note that the bug happens regardless of the device (CPU or CUDA), of the autocast dtype (`float16` or `bfloat16`), and of the autojac function we use (`jac`, `backward` or `mtl_backward`). The bug does not happen with `parallel_chunk_size=1`, so it's really linked to `vmap`. I can confirm that the bug happens with all variants of `BatchNorm`, `LayerNorm`, `InstanceNorm` and `GroupNorm`. I could not reproduce it with any other layer, but I suspect that it can also affect other layers (typically layers using buffers).

This bug is quite critical because it means that autojac is not compatible with AMP for most real-world models.
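To make the `vmap` connection concrete, here is a torch-only sketch (it does not use torchjd and is not the original reproducer) of the two ways a Jacobian can be assembled from vector-Jacobian products: batching all rows with `torch.vmap`, like the `torch.vmap(get_vjp, chunk_size=chunk_size)` call in `_get_jacs_chunk` in the trace above, or looping over rows with one backward pass each, which is presumably what `parallel_chunk_size=1` amounts to. The toy model and sizes are assumptions; it runs in float32 without autocast, so both paths should agree.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy model containing a normalization layer. It runs in plain
# float32 without autocast, so neither path below should hit the reported bug.
model = nn.Sequential(nn.Linear(4, 3), nn.LayerNorm(3), nn.ReLU())
x = torch.randn(2, 4)
out = model(x).sum(dim=0)  # shape (3,): one Jacobian row per element
params = list(model.parameters())


def vjp(grad_output: torch.Tensor) -> tuple[torch.Tensor, ...]:
    # One vector-Jacobian product: grad_output^T times d(out)/d(params).
    # retain_graph=True lets every call reuse the same graph.
    return torch.autograd.grad(out, params, grad_output, retain_graph=True)


basis = torch.eye(out.numel())  # one one-hot grad_output per Jacobian row

# Parallel path: torch.vmap batches all VJPs into a single backward pass.
jac_vmap = torch.vmap(vjp)(basis)

# Sequential path: one ordinary backward pass per row, no vmap involved.
rows = [vjp(v) for v in basis]
jac_loop = tuple(torch.stack(parts) for parts in zip(*rows))

for j_par, j_seq in zip(jac_vmap, jac_loop):
    print(tuple(j_par.shape), torch.allclose(j_par, j_seq, atol=1e-5))
```

Per the report, wrapping the forward pass in `torch.autocast` with a norm layer in the model is what makes the batched path fail, while the sequential path keeps working.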
Claude came up with a fix that works, but I'm not sure how clean it is, so I'll work on it a bit before opening a PR.