
Conversation

@michael-0brien

Here's my draft of the AbstractDescent as described in #187. I've also included an initial step size for the ClassicalTrustRegion here, but we can move that to another PR if desired.

For making scale-invariance the default implementation, we will have to discuss how to strike a balance between scale-invariance and robustness to bad initialization. The gist is that if the initial hessian estimate is bad, the algorithm loses parameter sensitivity.

The scipy version involves taking the scaling matrix D to be the max diagonal of the hessian so far encountered. They probably have other tricks, but this alone is not robust to bad initialization. I could imagine other cases where one would not want the max; perhaps one would be interested in the average instead. Regardless, we will need some modularity for implementing the scaling operator.

The paper I shared in #187 has some advice for choices of D, such as adding a small constant diagonal offset to prevent becoming too insensitive to parameters. In my work, I'm having success normalizing D so that its trace equals that of the identity. That way the LM param too carries a sense of scale, and I can use it to bias the algorithm towards gradient-like or Newton-like steps in early iterations. D then maintains the relative scale between parameters. I'm not suggesting this should be the default, but it's some food for thought.
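To make that concrete, here is a rough sketch of what I mean by the trace normalization, with the diagonal of `D` treated as a plain array for simplicity (the function name is hypothetical, not something in this PR):

```python
import jax.numpy as jnp


def trace_normalised(d_diag):
    # Hypothetical sketch: rescale the diagonal of D so that trace(D) equals
    # trace(I) = n. The relative scale between parameters is preserved, and
    # the absolute scale is carried by the LM parameter instead.
    n = d_diag.size
    return d_diag * (n / jnp.sum(d_diag))
```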

@michael-0brien (Author)

Two things added to public API so far:

  • ScaledDampedNewtonDescent: The AbstractDescent for direct scale-invariant LM. This keeps track of the scaling operator and allows a user to pass a callable update_scaling_fn(hessian: lx.AbstractLinearOperator, scaling_operator: lx.DiagonalLinearOperator) -> lx.DiagonalLinearOperator.
  • max_diagonal_scaling_update: The "scipy" scaling update of taking the maximum diagonal of the hessian so far encountered. I thought that exposing scaling update functions to the public API could follow the pattern of the functions for computing beta in nonlinear CG.

Not sure if this is fully the way to go, but it should get the ball rolling! I've also written up some documentation and added tests.
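As an illustration of the kind of custom update the `update_scaling_fn` signature above allows (hypothetical, not part of the PR), an average-based variant could look roughly like:

```python
import jax.numpy as jnp
import lineax as lx


def mean_diagonal_scaling_update(
    hessian: lx.AbstractLinearOperator,
    scaling_operator: lx.DiagonalLinearOperator,
) -> lx.DiagonalLinearOperator:
    # Hypothetical user-supplied update: scale every parameter by the mean of
    # the Hessian diagonal instead of tracking a per-parameter maximum.
    diag = jnp.diag(hessian.as_matrix())
    return lx.DiagonalLinearOperator(jnp.full_like(diag, jnp.mean(diag)))
```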

@johannahaffner (Collaborator) commented Nov 23, 2025

Just took a look and I like this! I have a few nits and comments, will post these tomorrow.

@johannahaffner (Collaborator) left a comment:

This looks very good to me! (I thought I had more comments but upon going over this again they resolved themselves.)

The main thing is that I think the function updating the scaling factor should become more flexible, and just be a Callable rather than a custom type. Perhaps Callable[[lx.AbstractLinearOperator, lx.AbstractLinearOperator], lx.AbstractLinearOperator] would work?

As you've suggested above, let's put the initial step size attribute for the searches in a different PR.

hessian: lx.AbstractLinearOperator, scaling_operator: lx.DiagonalLinearOperator
) -> lx.DiagonalLinearOperator:
"""Update the matrix `D` that controls the relative scaling of each
parameter based on the procedure described by More (1977).
Collaborator:

Let's avoid all confusion and just point out that the full reference is given below :)

newton_norm = self.trust_region_norm(newton)
return _IndirectDampedNewtonDescentState(
f_info=f_info, newton=newton, newton_norm=newton_norm, result=result
f_info=f_info,
Collaborator:

This looks like a spurious formatting change - best to remove it, I think.

# Will probably resolve to either Cholesky (for minimisation problems) or
# QR (for least-squares problems).
linear_solver: lx.AbstractLinearSolver = lx.AutoLinearSolver(well_posed=None)
update_scaling_fn: UpdateScalingFn = max_diagonal_scaling_update
Collaborator:

I think we should consider the type of the scaling function we want to support here. How about we make it a Callable that can do anything as long as it takes two Lineax operators and returns another Lineax operator? This will make it easy for anyone wanting to implement a custom thing.

Author:

So are you suggesting this should be changed to Callable[[lx.AbstractLinearOperator, lx.AbstractLinearOperator], lx.AbstractLinearOperator]?

This sounds good to me, but I will just need to add a `scaling_operator = cast(lx.DiagonalLinearOperator, scaling_operator)` to make pyright happy in the update_scaling_fn. Note that descent.init initializes it as:

scaling_operator = lx.DiagonalLinearOperator(tree_full_like(y, -jnp.inf, allow_static=True))
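Roughly, I'd expect the max-diagonal update to then look something like the following under the widened signature (just a sketch, not the actual diff, with the scaling operator's diagonal treated as a plain array):

```python
from typing import cast

import jax.numpy as jnp
import lineax as lx


def max_diagonal_scaling_update(
    hessian: lx.AbstractLinearOperator, scaling_operator: lx.AbstractLinearOperator
) -> lx.AbstractLinearOperator:
    # The descent always passes a diagonal operator here; the cast only exists
    # to satisfy pyright under the widened Callable type.
    scaling_operator = cast(lx.DiagonalLinearOperator, scaling_operator)
    hessian_diag = jnp.diag(hessian.as_matrix())
    # Starting from -inf in `descent.init`, the first call picks up the Hessian
    # diagonal; afterwards we keep the running element-wise maximum.
    new_diag = jnp.maximum(scaling_operator.diagonal, hessian_diag)
    return lx.DiagonalLinearOperator(new_diag)
```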

@johannahaffner (Collaborator)

The gist is that if the initial hessian estimate is bad, the algorithm loses parameter sensitivity.

How does this play out with different flavours of Hessian approximations? A BFGS type Hessian approximation that starts out with an identity should basically implement gradient descent for the first few steps. What happens with a residual Jacobian far from the optimum?

I could imagine other cases where one would not want the max, perhaps one would be interested in the average instead.

In my work, I'm having success normalizing D so that its trace equals that of the identity. That way the LM param too carries a sense of scale, and I can use it to bias the algorithm towards gradient-like or Newton-like steps in early iterations.

Regardless, we will need some modularity for implementing the scaling operator.

I think making the type of the update function more flexible should address all of the above, right? This does not answer the question of what the best default is, but we could start out with the maximum value, matching scipy, and make a change if we see something else that performs much better. In any case, we should write in the documentation that this can be highly problem-specific.

We could also provide several scaling functions, similar to how we have implemented this for the NonlinearCG family, which may be used with different methods: https://docs.kidger.site/optimistix/api/minimise/#optimistix.NonlinearCG (that are also just restricted to implement a specific signature).

@michael-0brien (Author)

Thanks for the comments! I'll work on finalizing this PR. Two questions as I finalize:

  1. What API would we like for the LevenbergMarquardt class? Should the scaled update become the default and if so, how? Didn’t want to touch this part since I’m not sure what is desired for backwards compatibility. Somehow it should be clear that the new approach may break for large systems since it requires materializing the hessian.
  2. Would you prefer I revert my commit adding the initial_step_size to the search method, or would you want to cherry pick my commit for the scaled descent? If the latter, I’ll add another commit addressing your comments.

The gist is that if the initial hessian estimate is bad, the algorithm loses parameter sensitivity.

How does this play out with different flavours of Hessian approximations? A BFGS type Hessian approximation that starts out with an identity should basically implement gradient descent for the first few steps. What happens with a residual Jacobian far from the optimum?

Good question, I need to read up on this. The work I've seen that studies this is always in the context of the usual residual-Jacobian LM. I think this is still true in this case, which is why I'm curious if scipy does anything else under the hood besides what's implemented in the 1977 reference they cite.

I am using a BFGS hessian of course, so I just bias the algorithm in early iterations to be gradient-like. I do this by choosing the initial LM param to be large, and also by using a custom search that has the behavior of LinearTrustRegion before some hessian_burn_in number of steps, and ClassicalTrustRegion after that.

I could imagine other cases where one would not want the max, perhaps one would be interested in the average instead.

In my work, I'm having success normalizing D so that its trace equals that of the identity. That way the LM param too carries a sense of scale, and I can use it to bias the algorithm towards gradient-like or Newton-like steps in early iterations.

Regardless, we will need some modularity for implementing the scaling operator.

I think making the type of the update function more flexible should address all of the above, right? This does not answer the question of what the best default is, but we could start out with the maximum value, matching scipy, and make a change if we see something else that performs much better. In any case, we should write in the documentation that this can be highly problem-specific.

We could also provide several scaling functions, similar to how we have implemented this for the NonlinearCG family, which may be used with different methods: https://docs.kidger.site/optimistix/api/minimise/#optimistix.NonlinearCG (that are also just restricted to implement a specific signature).

Yes, I think so. Going with the scipy version and providing flexibility is a good plan!

@johannahaffner (Collaborator)

Thanks for the comments! I’ll work on finalizing this PR.
Yes, I think so. Going with the scipy version and providing flexibility is a good plan!

Great!

  1. What API would we like for the LevenbergMarquardt class? Should the scaled update become the default and if so, how? Didn’t want to touch this part since I’m not sure what is desired for backwards compatibility. Somehow it should be clear that the new approach may break for large systems since it requires materializing the hessian.

I was thinking about this too. I think for conceptual clarity it makes sense to keep our current LevenbergMarquardt and introduce a new ScaledLevenbergMarquardt that documents the differences. This way, users who currently use LevenbergMarquardt, either with QR or with an iterative backend (we have an LSMR that will be in the next Lineax release and enables fully matrix-free usage) don't have to think about switching / changing their code, and everyone gets a new, extra option.
We can leave a comment somewhere in the code that LM does not support scaling because we would need to normalise the equations to get a diagonal, which interferes with matrix-free computations.
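From the user's perspective, I'd imagine something along these lines (hypothetical usage; I'm assuming ScaledLevenbergMarquardt takes the same rtol/atol arguments as the existing solver):

```python
import jax.numpy as jnp
import optimistix as optx


def residuals(y, args):
    # Toy least-squares residual function.
    return jnp.array([y[0] - 1.0, 10.0 * (y[1] - y[0] ** 2)])


y0 = jnp.array([2.0, 2.0])

# Existing solver: unchanged, usable matrix-free (QR, or an iterative backend).
solver = optx.LevenbergMarquardt(rtol=1e-8, atol=1e-8)
sol = optx.least_squares(residuals, solver, y0)

# New, extra option from this PR: the scale-invariant variant, which
# materialises the Hessian approximation.
scaled = optx.ScaledLevenbergMarquardt(rtol=1e-8, atol=1e-8)
sol_scaled = optx.least_squares(residuals, scaled, y0)
```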

  2. Would you prefer I revert my commit adding the initial_step_size to the search method

Sounds good!

I think this is still true in this case, which is why I’m curious if scipy does anything else under the hood besides what’s implemented in the 1977 reference they cite.

I'm not sure! There is this line here: https://github.com/scipy/scipy/blob/b1296b9b4393e251511fe8fdd3e58c22a1124899/scipy/optimize/__minpack.c#L1921

and another one just like it in LMDIF, both of which do look like they do a QR factorisation before applying the scaling. @ilayn can you comment?

I am using a BFGS hessian of course, so I just bias the algorithm in early iterations to be gradient-like. I do this by choosing the initial LM param to be large, and also by using a custom search that has the behavior of LinearTrustRegion before some hessian_burn_in number of steps, and ClassicalTrustRegion after that.

Amazing! Some true Optimistix power usage right there.

@ilayn commented Nov 26, 2025

'ello. Such a nice write up and discussion.

For the actual underlying reasons of the optimization problem, I have to disappoint you: I did not find the time to actually go through minpack or slsqp yet in terms of scaling of the data and improving the conditioning of the problem. We are almost finished rewriting all the Fortran code in the entire SciPy codebase, and that took up all the time I have for FOSS activities; work is a bit busy lately too (sorry for not keeping in touch sooner).

A couple of comments:

  • Minpack code is old. QR factorizations and rank-k updates definitely need some love to use LAPACK calls instead of this home-baked code (I'm talking about qrfac and rwupd and other stuff). Since our goal was to remove the old F77 code, I just went for a line-by-line equivalent translation instead of risking divergence and causing mayhem downstream by suddenly changing results, etc. But with hindsight on our side, things have moved on in numerical linear algebra in the last 45 years. At some point I have to replace those with proper LAPACK calls.

  • which actually do look like they do a QR factorisation before applying the scaling. @ilayn can you comment?

    I don't know the exact reasoning, but QR factorization does not change much, as Householder reflections are more stable for obtaining $R$, and $Q$ is always unitary. If I have to guess: scaling $R$ does more work as opposed to first scaling and then factorizing.

  • But if I put my very dusty control theory hat on, this type of problem pops up very often and there are many tricks that can be implemented.

    For the problem types of solving for $(A - \lambda I)x = b$, the typical trick is to balance the problem with a diagonal scaling matrix (in practice just a 1D array holding the diagonal) whose entries are exact powers of two, for reduction of numerical noise. This type of scaling is often obtained by LAPACK ?GEBAL routines. So you write

$$ (A - \lambda I)x = b \implies (D^{-1}AD - \lambda I) (D^{-1}x) = (D^{-1}b) \implies (\bar A - \lambda I) \bar x = \bar b $$

$D$ here is a diagonal nonsingular matrix that approximately equalizes the row and column norms of the original $A$. Typically the matrix is allowed to be permuted in the regular ?GEBAL, but if structure is important you can also obtain a scaling without permutations, albeit with less successful balancing. Then after the solution, you scale back and go about your way. Hence it does not alter the rest of the algorithm but makes that linear system solve better conditioned.

This is basically a generalization of the scalar scaling and works well with square systems. With a single scalar you run the risk of underflowing the small entries or blowing up the noise.

It is an art rather than a science, in a way. However, every time you call an eigenvalue routine this is what happens behind the scenes as a first step. For rectangular systems this goes by the weird name of equilibration, and then the scaling array works with left and right scaling matrices $D_r$ and $D_c$.

In MINPACK, they actually try to do a similar scheme for the jacobian (see https://www.osti.gov/biblio/6997568 section 2.5). But they are not balancing; rather, they brute-force scale it down to 1.0 based on column norms (and without powers of 2). Obviously nothing beats improving the problem formulation, but these should also help.

If you wish to test balancing, you can use scipy.linalg.matrix_balance, but it was one of my first FOSS contributions back in the day and it is a slow implementation. Another thing I should circle back to one day. If it is too unbearable, call GEBAL directly via scipy.linalg.lapack.dgebal.
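A quick standalone example of the call, just to show the interface:

```python
import numpy as np
from scipy.linalg import matrix_balance

# A deliberately badly scaled matrix.
A = np.array([[1e-4, 1e4, 0.0],
              [1e-6, 1.0, 1e3],
              [0.0, 1e-2, 1e6]])

# With permute=False the transformation T is a diagonal matrix whose entries
# are exact powers of two; see the scipy docs for the exact convention.
B, T = matrix_balance(A, permute=False)

print(np.linalg.cond(A), np.linalg.cond(B))  # conditioning usually improves
```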

- Touches up and adds documentation for scale-invariant LM
- More generic function signature for update_scaling_fn
@michael-0brien (Author)

Comments should be addressed. Added the new ScaledLevenbergMarquardt, max_diagonal_scaling_update, and ScaledDampedNewtonDescent to the docs.

@johannahaffner (Collaborator) commented Nov 28, 2025

I did not find the time to actually go through minpack or slsqp yet in terms of scaling of the data and improving the conditioning of the problem.

No worries!

We are almost finalizing rewriting all Fortran code in the entire SciPy codebase

and !!! - this is an amazing achievement.

  • which actually do look like they do a QR factorisation before applying the scaling. @ilayn can you comment?

I think I should probably give a bit more context here! Specifically, we're implementing a variant where the maximum value along the diagonal of the Hessian is used to scale the regulariser in Levenberg-Marquardt: $(A + \lambda \sigma I)x = b$, where $\sigma$ is the scaling factor.

The way we've implemented it here is by computing the Hessian approximation as $A = Jac^T Jac$ and then examining its diagonal, taking the maximum value and using that to scale the identity operator. We then proceed to factorise the operator.
And we were wondering whether this matches what Scipy does, since it is our interpretation of the More 1977 paper that Scipy also cites. Specifically, we're going for equation $(3.1)$, having previously implemented $(3.2)$. We normalise the equations to get a diagonal from which the maximum value can be extracted; the maximum value along the diagonal is then used as a scalar to multiply the identity matrix, so we get $(A + \lambda \sigma I)x = b$ as above.

However, in the MINPACK code it looks like a QR factorisation is done before any scaling factors are computed. Is that also your conclusion when looking at the line I linked above? In that case, the scaling does indeed appear to be applied differently than is suggested in the paper.

$$ (A - \lambda I)x = b \implies (D^{-1}AD - \lambda I)(D^{-1}x) = (D^{-1}b) \implies (\bar A - \lambda I)\bar x = \bar b $$

That looks quite different - given that the scaling is applied directly to $A$ here! We also add, rather than subtract, the regulariser.

It is an art rather than a science, in a way. However, every time you call an eigenvalue routine this is what happens behind the scenes as a first step. For rectangular systems this goes by the weird name of equilibration, and then the scaling array works with left and right scaling matrices $D_r$ and $D_c$.

Good to know!

In MINPACK, they actually try to do a similar scheme for the jacobian (see https://www.osti.gov/biblio/6997568 section 2.5). But they are not balancing; rather, they brute-force scale it down to 1.0 based on column norms (and without powers of 2). Obviously nothing beats improving the problem formulation, but these should also help.

Is that in the MINPACK LM? Or elsewhere in MINPACK?

If you wish to test balancing, you can use scipy.linalg.matrix_balance

Also good to know :)

@johannahaffner (Collaborator)

Comments should be addressed. Added the new ScaledLevenbergMarquardt, max_diagonal_scaling_update, and ScaledDampedNewtonDescent to the docs.

Great, will take a look tomorrow!

def max_diagonal_scaling_update(
hessian: lx.AbstractLinearOperator, scaling_operator: lx.AbstractLinearOperator
) -> lx.AbstractLinearOperator:
"""Update the scaling matrix `D` as the maximum diagonal of the
Collaborator:

Wouldn't it be desirable to regularise the Hessian less stringently as the optimisation progresses? If we keep the maximum value around forever I would expect convergence to slow down - the regulariser will be too dominant, and the system will become too gradient-like for fast convergence.

@johannahaffner (Collaborator) Nov 29, 2025:

Ok I have more questions about this one. In the More 1977 paper, equation (6.3) does seem to point to a related concept. Here we take whatever is larger between

  • the value on the diagonal of the Hessian approximation
  • the norm of the respective column in the Jacobian

in an element-wise fashion, but going back just one iteration. In other words, we would not accumulate large values throughout the optimisation. What is your take on this section of the paper?
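To make the two readings concrete (hypothetical helper functions, array-valued for simplicity):

```python
import jax.numpy as jnp


def update_accumulating(d_prev, jac_col_norms):
    # Recursive reading: keep the largest column norm encountered over the
    # whole optimisation so far.
    return jnp.maximum(d_prev, jac_col_norms)


def update_previous_iterate(jac_col_norms_prev, jac_col_norms):
    # Alternative reading: compare only successive iterations, so large values
    # from early on are not carried along forever.
    return jnp.maximum(jac_col_norms_prev, jac_col_norms)
```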

max_diagonal_scaling_update should also have a reference, I think, even if this ends up being a bit repetitive given that we have references elsewhere. As soon as we have more than one of these scaling functions in the public API, the correct reference won't be immediately obvious anymore.

- __init__

[`optimistix.ScaledLevenbergMarquardt`][] supports modularity in updating the scaling matrix `D` via the `update_scaling_fn` argument. The default method is exposed to public API and is documented below.
[`optimistix.ScaledLevenbergMarquardt`][] supports modularity in updating the scaling operator via the `update_scaling_fn` argument. The default method is exposed to public API and is documented below.
@johannahaffner (Collaborator) Nov 29, 2025:

"The default scaling method uses the maximum value along the diagonal of the Hessian approximation." Rest is self-explanatory.

@johannahaffner (Collaborator) left a comment:

Looks good! I have some nits, some documentation requests, and a question about the implementation of the diagonal scaling update - if I read the More paper correctly, they propose something different there, which does not accumulate values throughout the optimisation but only compares successive iterations.

is given by `0.5*residual^2`, then we know that `grad=J^T residual`, and we make
the Gauss--Newton approximation `Hess ~ J^T J`. This reduces the above to
solving the (linear) least-squares problem
the Gauss--Newton approximation `Hess ~ J^T J`. When `D = I` this reduces the
Collaborator:

"When D = I, then we do not need to access the diagonal of J^T J and can reduce the above to...".

Taking `D = I` is the original procedure proposed by Levenberg, whereas
More (1977) describes taking `D` as the maximum diagonal of the hessian
so far encountered. This procedure is used in [`scipy.least_squares`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.least_squares.html)
Collaborator:

See comments above - I think the values are not actually accumulated throughout the whole optimisation procedure. Just leaving this note here so we don't forget to update the docstring :)

Given that we don't actually know what MINPACK does (I'm confused by the QR factorisation taking place before any scaling does, for instance), let's drop the reference to Scipy here and move the More paper reference to the implementation of the max diagonal scaling method only.

Author:

My interpretation of what scipy does, and my interpretation of the More paper, come from this reference: https://arxiv.org/abs/1201.5885

Perhaps I should not have taken this too much at face value. I'm not quite sure what the right way to go is. But if it's true that scipy does indeed use the max diagonal so far encountered, even if this is not what is described in More, it might be the right default to go with? Points taken, though.

Based on this reference do you think we should investigate scipy further or implement what is in More? If the latter, let me know what you think the best approach would be.

Author:

Reading through the More paper again, I think what is described (Eq 6.3) is indeed what we’ve implemented. What makes you think the maximum is with respect to current and previous iterate?

Collaborator:

In the More paper - I interpreted the index $k-1$ as the value of the diagonal element in Jac^T Jac from the previous iteration, but looking at it again I can also see your interpretation. Since we're defining a recursion it actually makes a lot of sense to define it the way you did.

What I find very surprising is that it is sensible to keep the largest value thus far encountered for the entirety of the optimisation procedure. That will accumulate information from the start of the optimisation to the end, and I don't find that intuitive.

I'll take another look at the Scipy code, I'm still a bit hesitant to claim equivalence.

Author:

Agreed, I see your point. This is worth looking into further, I’m also hesitant to claim equivalence. The scipy docs state that the LM implementation is “based on the paper [JJMore]” and has “a lot of smart tricks.” I am not sure if this implies there are tricks beyond what is discussed in the paper!

Author:

Not sure if you already found this, but there is a modernization of MINPACK that looks very useful for us! Here is the documentation + source code for the lmder routine, which is based on the implementation from the More paper: https://fortran-lang.github.io/minpack/proc/lmder.html

Looks like the repo is actively maintained too if we want to get in touch.

ScaledDampedNewtonDescent.__init__.__doc__ = """**Arguments:**
- `linear_solver`: The linear solver used to compute the Newton step.
- `update_scaling_fn`: A function with signature `fn(hessian, scaling_operator)`,
Collaborator:

"update_scaling_fn: A function computing the scaling matrix D. Defaults to [optimistix.max_diagonal_scaling_update][]. ScaledLevenbergMarquardt supports custom scaling updates - any function implementing such an update should... " and then explain the signature.

"""A Levenberg--Marquardt method invariant to parameter rescaling.
This is a variant of [`optimistix.LevenbergMarquardt`][], which uses
the hessian to estimate how to preserve scale-invariance when applying
Collaborator:

hessian -> Hessian approximation

f_info: FunctionInfo.EvalGradHessian | FunctionInfo.ResidualJac,
linear_solver: lx.AbstractLinearSolver,
*,
scaling_update: _ScalingUpdate | None = None,
Collaborator:

Avoid the default argument here, and just keep it a mandatory keyword argument. This forces DampedNewtonDescent to specify that it does not use a scaling_update, which is cleaner and makes our internal code more self-documenting.

@michael-0brien (Author)

I’ll hold off on pushing more changes until we get a handle on the scaling update, but I’ll get these changes pushed soon after!

- More clear documentation
- Remove default argument in damped_newton_step
@michael-0brien (Author)

I've gone ahead and incorporated the last round of feedback on the documentation, as well as some other things I noticed. I wrote the docstrings more carefully, considering that ScaledDampedNewtonDescent will be in the advanced API docs and the other two functions will be on a different page. As per our discussion, as a minimal change I've also worded the docs more carefully so that they do not claim equivalence with MINPACK.

@michael-0brien (Author)

@johannahaffner I was thinking more about our discussion of which scaling update to choose, and how it's a bit strange to keep around information from the whole history of iterations.

We may want to inform our decision by the literature on first-order optimizers that use a diagonal preconditioner estimated from the Hessian approximation. These methods can, I think, be thought of as the limit of LM when the LM parameter is large!

For example, see here: https://arxiv.org/abs/2109.05198. It may be worth thinking about having a scaling update that takes the scaling matrix to be an exponential moving average, which would downweight information from many previous iterations.
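Roughly (a hypothetical sketch, array-valued for simplicity):

```python
import jax.numpy as jnp


def ema_scaling_update(d_prev, hessian_diag, decay=0.9):
    # Hypothetical: exponential moving average of the Hessian diagonal, so
    # information from many iterations ago is geometrically downweighted rather
    # than kept forever (as with a running maximum).
    return decay * d_prev + (1.0 - decay) * hessian_diag
```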
