autojac incompatible with AMP for models with normalization layers #699

@ValerianRey

If we take the AMP usage example (https://torchjd.org/stable/examples/amp/) and add a single BatchNorm1d or LayerNorm to the model, we get: RuntimeError: expected scalar type Half but found Float

Full stack trace
src/torchjd/autojac/_mtl_backward.py:128: in mtl_backward
    backward_transform(grad_tensors_dict)
src/torchjd/autojac/_transform/_base.py:79: in __call__
    return self.outer(intermediate)
           ^^^^^^^^^^^^^^^^^^^^^^^^
src/torchjd/autojac/_transform/_base.py:78: in __call__
    intermediate = self.inner(input)
                   ^^^^^^^^^^^^^^^^^
src/torchjd/autojac/_transform/_differentiate.py:43: in __call__
    differentiated_tuple = self._differentiate(tensor_outputs)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
src/torchjd/autojac/_transform/_jac.py:91: in _differentiate
    jacs_chunks.append(_get_jacs_chunk(jac_outputs_chunk, get_vjp_last))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
src/torchjd/autojac/_transform/_jac.py:119: in _get_jacs_chunk
    return torch.vmap(get_vjp, chunk_size=chunk_size)(jac_outputs_chunk)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.14/site-packages/torch/_functorch/apis.py:220: in wrapped
    return vmap_impl(
.venv/lib/python3.14/site-packages/torch/_functorch/vmap.py:305: in vmap_impl
    return _chunked_vmap(
.venv/lib/python3.14/site-packages/torch/_functorch/vmap.py:445: in _chunked_vmap
    _flat_vmap(
.venv/lib/python3.14/site-packages/torch/_functorch/vmap.py:507: in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
src/torchjd/autojac/_transform/_differentiate.py:68: in _get_vjp
    optional_grads = torch.autograd.grad(
.venv/lib/python3.14/site-packages/torch/autograd/__init__.py:530: in grad
    result = _engine_run_backward(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

t_outputs = (tensor([[0.0000, 0.0000, 0.0526],
        [0.3484, 0.1794, 0.1005],
        [0.3232, 0.4607, 0.0777],
        [0.1349...    [0.3950, 0.0359, 0.0000],
        [0.3462, 0.3726, 0.1592]], dtype=torch.float16,
       grad_fn=<ReluBackward0>),)
args = ((BatchedTensor(lvl=1, bdim=0, value=
    tensor([[[-8.0640e+03, -8.4720e+03, -8.5920e+03],
             [ 1.1050e+03,...-0.1621,  0.0535, -0.2953, -0.2285, -0.1630,  0.1995,  0.1854,
         -0.1402, -0.0114]], requires_grad=True)), True)
kwargs = {'accumulate_grad': False}, attach_logging_hooks = False

    def _engine_run_backward(
        t_outputs: Sequence[torch.Tensor | GradientEdge],
        *args: Any,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, ...]:
        attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
        if attach_logging_hooks:
            unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    
        # Need to save the context so compiler config will be visible in device threads
        torch._C._stash_obj_in_tls("context", contextvars.copy_context())
    
        try:
>           return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
                t_outputs, *args, **kwargs
            )  # Calls into the C++ engine to run the backward pass
E           RuntimeError: expected scalar type Half but found Float

A simple reproducer is the following:

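import torch
from torch import nn
from torchjd.autojac import jac

# A Linear layer followed by a single normalization layer is enough to trigger the bug.
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

# Forward pass under autocast (AMP).
with torch.autocast("cpu", dtype=torch.float16):
    output = model(torch.randn(8, 4))

# Raises RuntimeError: expected scalar type Half but found Float
_ = jac(output, list(model.parameters()))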
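
For anyone investigating, the following pure-PyTorch sketch (no torchjd involved) prints the dtypes that coexist in such a model under autocast. It only illustrates where low-precision and float32 tensors meet; it is not a diagnosis of the root cause. It uses bfloat16 on CPU as an arbitrary choice, and the variable names are illustrative.

```python
import torch
from torch import nn

linear, norm = nn.Linear(4, 4), nn.BatchNorm1d(4)

# Under autocast, the Linear layer computes in the low-precision dtype.
with torch.autocast("cpu", dtype=torch.bfloat16):
    hidden = linear(torch.randn(8, 4))

print("linear output:", hidden.dtype)

# Autocast does not modify the module's own tensors, so the
# normalization layer's parameters and buffers keep their dtypes.
for name, tensor in list(norm.named_parameters()) + list(norm.named_buffers()):
    print(name, tensor.dtype)
```

The normalization layer therefore receives a low-precision activation while its weight, bias and running statistics remain float32.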

Note that the bug happens regardless of the device (cpu or cuda), the autocast dtype (float16 or bfloat16), and the autojac function used (jac, backward, mtl_backward). It does not happen when parallel_chunk_size=1, so it is really linked to vmap. I can confirm that it happens with all variants of BatchNorm, LayerNorm, InstanceNorm and GroupNorm. I could not trigger it with any other layer, but I suspect that it can also affect other layers (typically layers using buffers).
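
To illustrate why avoiding vmap sidesteps the problem, here is a minimal pure-PyTorch sketch that builds the Jacobian with one ordinary backward pass per output element. This is an assumption about what un-batched differentiation looks like conceptually, not torchjd's actual implementation; it uses bfloat16 on CPU as an arbitrary choice, and all names are illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
params = list(model.parameters())

with torch.autocast("cpu", dtype=torch.bfloat16):
    output = model(torch.randn(8, 4))

flat_output = output.reshape(-1)
rows = []
for i in range(flat_output.numel()):
    # One regular backward pass per output element; no vmap involved.
    grads = torch.autograd.grad(
        flat_output[i], params, retain_graph=True, allow_unused=True
    )
    # Replace missing gradients with zeros so every row has the same layout.
    rows.append(
        [torch.zeros_like(p) if g is None else g for g, p in zip(grads, params)]
    )

# One Jacobian per parameter, of shape output.shape + parameter.shape.
jacobians = [
    torch.stack([row[k] for row in rows]).reshape(*output.shape, *p.shape)
    for k, p in enumerate(params)
]

for p, j in zip(params, jacobians):
    print(tuple(p.shape), "->", tuple(j.shape))
```

This costs one backward pass per output element, so it is much slower than a batched computation and is only useful as a reference or a temporary workaround.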

This bug is quite critical because it means that autojac is not compatible with AMP for most real-world models.

Claude came up with a fix (that works), but I don't know how clean it is so I'll work a bit on it before opening a PR.
