Infinite Gradient Handling #582

Open: wants to merge 12 commits into base: main

Conversation

Aziz-Shameem

Added support for infinite gradient check as per this

Note that the added method checks for NaN values in the gradients as well as infinite values, since both lead to the same errors during optimization and are therefore related.
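
For illustration, a minimal sketch of why a single check can cover both cases: `np.isfinite` is False for infinite and NaN entries alike. The helper name `has_invalid_entries` is hypothetical and is not the function added in this PR.

```python
import numpy as np


def has_invalid_entries(grad):
    """Return True if the gradient contains any infinite or NaN entry."""
    grad = np.asarray(grad, dtype=float)
    # np.isfinite is False for +inf, -inf, and NaN alike.
    return not np.all(np.isfinite(grad))


print(has_invalid_entries([1.0, 2.0]))     # finite gradient
print(has_invalid_entries([np.inf, 2.0]))  # infinite entry
print(has_invalid_entries([np.nan, 2.0]))  # NaN entry
```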

Let me know if any further changes/polishing is required. All the CI tests seem to pass.

Thanks

Member

@timmens timmens left a comment

Thanks for the PR!

This looks pretty good already. I have added a few review comments to your changes.

Additionally: Can you create a new test file tests/optimagic/optimization/test_invalid_function_value.py, and add two test cases where you check the cases of

  1. Infinite values and
  2. NaN values

These tests should fail without your changes, and pass with your changes; i.e., you will have to check that your code raises the correct error! We have many such tests in our test suite -- I would just search for "pytest.raises".
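
The `pytest.raises` pattern mentioned here looks roughly like the sketch below. `check_gradient` is a hypothetical stand-in for the code under test and raises a plain `ValueError`; the actual tests would presumably call `minimize` and expect optimagic's own error type.

```python
import numpy as np
import pytest


def check_gradient(grad):
    """Hypothetical stand-in: reject gradients with inf or NaN entries."""
    if not np.all(np.isfinite(grad)):
        raise ValueError("Gradient contains infinite or NaN values.")
    return grad


@pytest.mark.parametrize("bad_value", [np.inf, np.nan])
def test_check_gradient_raises(bad_value):
    grad = np.array([1.0, bad_value])
    # The test passes only if the expected error is raised and its message
    # matches the regular expression given via `match`.
    with pytest.raises(ValueError, match="infinite or NaN"):
        check_gradient(grad)
```

Without the check inside `check_gradient`, no error would be raised and `pytest.raises` would fail the test, which is the "fails without your changes" behavior asked for above.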

Thanks again and let me know if you have any questions!

@timmens
Member

timmens commented Apr 28, 2025

Hey @Aziz-Shameem!

Just checking in — do you have a sense of the timeline for finishing up this PR? No rush if you're still working on it; just want to plan ahead a bit.

Thanks a lot!

@Aziz-Shameem
Author

Hey @timmens

Really sorry for the delay. I was at ICLR, which took up most of my time during the day.
Since the conference has now ended, I should be able to finish this up in a day or two at most.

I also wanted to see if I could tackle the second part of this PR (the part that deals with error handling). I would like to take that up after I submit this. Do you think that will be feasible? @janosg did mention it requires more knowledge of the internals, so I will have to read up on it.

Thanks.

@janosg
Member

janosg commented Apr 28, 2025

@Aziz-Shameem let's first finish this PR and then open a separate one to implement the penalty approach instead of throwing the error.

@Aziz-Shameem
Author

Aziz-Shameem commented Apr 29, 2025

Hey @timmens, @janosg
I made the changes and also added several tests covering many input combinations.
This does not yet handle numerical derivatives correctly, even though it checks for inf/nan values in the Jacobian while evaluating numerical derivatives.

I changed the error message to make it more informative. It now reads as follows:
[image: screenshot of the new error message]

@Aziz-Shameem
Author

To handle such cases when using numerical derivatives:
I noticed that assert_finite_jac() already catches some of the cases itself. The others seem to be like the example I attach below:
[image: example case]

Maybe we could fix a certain (±)THRESHOLD value, flag an element of the gradient as infinite when it crosses that value, and throw the same error as in the cases where jac is provided explicitly.
If so, this could be added inside the check function itself.
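
A rough sketch of what such a threshold check could look like. The cutoff value `1e12` and the helper names `exceeds_threshold` and `forward_difference` are assumptions made for illustration; the thread does not fix a value or an implementation.

```python
import numpy as np

# Assumed cutoff for illustration only; the thread does not fix a value.
THRESHOLD = 1e12


def exceeds_threshold(grad, threshold=THRESHOLD):
    """Return True if any entry is NaN, infinite, or beyond +/- threshold."""
    grad = np.asarray(grad, dtype=float)
    with np.errstate(invalid="ignore"):
        too_large = np.abs(grad) > threshold
    return bool(np.any(~np.isfinite(grad) | too_large))


def forward_difference(fun, x, step=1e-8):
    """One-sided finite-difference approximation of a scalar derivative."""
    return (fun(x + step) - fun(x)) / step


# Near a pole, the numerical derivative is finite but extremely large, so a
# plain finiteness check would not flag it while the threshold check does.
steep = forward_difference(lambda x: 1.0 / x, 1e-8)
print(np.isfinite(steep), exceeds_threshold([steep]))
```

One caveat of this design is that any fixed cutoff is problem-dependent: a badly scaled but valid problem could have legitimately large gradient entries that would be flagged as well.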

Member

@timmens timmens left a comment

Thanks a lot for the changes, it looks much better now! There are a few things to consider before going forward though:

Feature:

  • Since _assert_finite_jac does not depend on self, I'd make this a regular function, just like _process_fun_value below.
  • For the error message to be more informative, I think we can go one step further.
    • I would add a keyword argument origin: Literal["numerical", "jac", "fun_and_jac"]. I.e. you write params: PyTree, *, origin: ....
    • For the error message you can then make the distinction whether the error occurred in a user provided derivative (jac or fun_and_jac case), or whether it occurred during a numerical derivative, in which case the user needs to check the criterion function fun.
    • The resulting errors could look like this:
      • The optimization failed because the derivative (via jac) contains infinite or NaN values. Please validate the derivative function.
      • The optimization failed because the derivative (via fun_and_jac) contains infinite or NaN values. Please validate the derivative function.
      • The optimization failed because the numerical derivative (computed using fun) contains infinite or NaN values. Please validate the criterion function.
    • Of course, you would still also print the parameter and Jacobian values.
    • Make sure to consistently capitalize Jacobian!

Tests:

You have many tests for different combinations, which is great! However, there is a lot of code duplication, and in this case one general PyTree parameter should be enough. I have added a proposal below. If you customize the error message per origin, then ideally, you should adjust the message in the tests as well, so that we can additionally test that the correct error message is thrown.

Proposal:

import numpy as np
import pytest

from optimagic.exceptions import UserFunctionRuntimeError
from optimagic.optimization.optimize import minimize


# ======================================================================================
# Test setup:
# --------------------------------------------------------------------------------------
# We test that minimize raises an error if the user function returns a jacobian
# containing invalid values (e.g. np.inf, np.nan). To test that this works not only at
# the start parameters, we create jac functions that return invalid values if the
# parameter norm becomes smaller than 1. For this test, we assume the following
# parameter structure: {"a": 1, "b": np.array([2, 3])}
# ======================================================================================

def sphere(params):
    return params["a"] ** 2 + (params["b"] ** 2).sum()


def sphere_gradient(params):
    return {"a": 2 * params["a"], "b": 2 * params["b"]}


def sphere_and_gradient(params):
    return sphere(params), sphere_gradient(params)


def params_norm(params):
    squared_norm = (
        params["a"] ** 2 + np.linalg.norm(params["b"]) ** 2
    )
    return np.sqrt(squared_norm)


def get_invalid_jac(invalid_jac_value):
    """Get function that returns invalid jac if the parameter norm < 1."""

    def jac(params):
        if params_norm(params) < 1:
            return invalid_jac_value
        else:
            return sphere_gradient(params)

    return jac


def get_invalid_fun_and_jac(invalid_jac_value):
    """Get fun_and_jac that returns an invalid jac if the parameter norm < 1."""

    def fun_and_jac(params):
        if params_norm(params) < 1:
            # The function value stays valid; only the Jacobian is invalid.
            return sphere(params), invalid_jac_value
        else:
            return sphere_and_gradient(params)

    return fun_and_jac


# ======================================================================================
# Tests
# ======================================================================================


INVALID_JACOBIAN_VALUES = [
    {"a": np.inf, "b": 2 * np.array([1, 2])},
    {"a": 1, "b": 2 * np.array([np.inf, 2])},
    {"a": np.nan, "b": 2 * np.array([1, 2])},
    {"a": 1, "b": 2 * np.array([np.nan, 2])},
]


@pytest.mark.parametrize("invalid_jac_value", INVALID_JACOBIAN_VALUES)
def test_minimize_with_invalid_jac(invalid_jac_value):
    params = {
        "a": 1,
        "b": np.array([3, 4]),
    }

    with pytest.raises(
        UserFunctionRuntimeError,
        match="The optimization received Jacobian containing infinite"
    ):
        minimize(
            fun=sphere,
            params=params,
            algorithm="scipy_lbfgsb",
            jac=get_invalid_jac(invalid_jac_value),
        )


@pytest.mark.parametrize("invalid_jac_value", INVALID_JACOBIAN_VALUES)
def test_minimize_with_invalid_fun_and_jac(invalid_jac_value):
    params = {
        "a": 1,
        "b": np.array([3, 4]),
    }

    with pytest.raises(
        UserFunctionRuntimeError,
        match="The optimization received Jacobian containing infinite"
    ):
        minimize(
            params=params,
            algorithm="scipy_lbfgsb",
            fun_and_jac=get_invalid_fun_and_jac(invalid_jac_value),
        )

@Aziz-Shameem
Author

Added the changes.
mypy fails at three places for now, although this is unrelated to the changes I have made.

Also, following up on @timmens's comments:

  1. I changed INVALID_JACOBIAN_VALUES a little to add a dict value as well, so that we have all three most common types as inputs: scalar, list and dict.
  2. When you get inf/nan grads numerically, it sometimes occurs because of a sub-optimal optimizer (as in the case I attached as an image in my previous message). I reflect this in the corresponding error message.
  3. PARAMS was being reused, so I declared it globally. I can change this if it is not good practice.


codecov bot commented May 5, 2025

Codecov Report

Attention: Patch coverage is 86.66667% with 2 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...agic/optimization/internal_optimization_problem.py | 86.66% | 2 Missing ⚠️ |

| Files with missing lines | Coverage Δ |
| --- | --- |
| src/optimagic/parameters/space_conversion.py | 97.52% <ø> (ø) |
| ...agic/optimization/internal_optimization_problem.py | 94.63% <86.66%> (-0.42%) ⬇️ |

Member

@timmens timmens left a comment

Very nice changes, thank you!

I only have a few tiny comments. When these are done, you can request a review from janosg.

If you have any questions, feel free to reach out!

Args:
    out_jac: internal processed Jacobian to check for finiteness.
    jac_value: original Jacobian value as returned by the user function,
    params: user-facing parameter representation at evaluation point.
Member

origin argument description is missing. Please add to the description that it is only used for a more detailed error message.

@@ -543,6 +545,8 @@ def func(x: NDArray[np.float64]) -> SpecificFunctionValue:
warnings.warn(msg)
fun_value, jac_value = self._error_penalty_func(x)

_assert_finite_jac(jac_value, jac_value, params, "numerical")
Member

Can you call this function using the argument names, like so:

_assert_finite_jac(
    out_jac=jac_value,
    jac_value=jac_value,
    params=params,
    origin="numerical"
)

Same for all the other instances where you call it.

Member

  1. If you want to use a global PARAMS object, then I'd recommend defining a pytest.fixture called params. This you can then add to the test functions.
  2. Since PARAMS is already a dictionary, this case is already covered. I prefer the case without the entry "c" inside PARAMS, because it is less complex and easier to grasp.
  3. In my comment I added a description for the test module. I would like to see some sort of description. That can be like the section comment I wrote in my proposal, or a module docstring at the top of the file. In any case, some new developer should be able to understand directly why we do these tests and the specific setup when reading the description.
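
As a sketch of the fixture suggested in point 1 (the test name `test_params_structure` and the parameter values are only illustrative):

```python
import numpy as np
import pytest


@pytest.fixture
def params():
    """Start parameters shared by the tests; rebuilt for every test."""
    return {"a": 1, "b": np.array([3, 4])}


def test_params_structure(params):
    # pytest injects the return value of the `params` fixture by name.
    assert set(params) == {"a", "b"}
    assert params["b"].shape == (2,)
```

Compared to a module-level PARAMS constant, each test receives a fresh object, so a test that mutates the parameters cannot affect the others.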

@@ -508,6 +509,7 @@ def func(x: NDArray[np.float64]) -> SpecificFunctionValue:
p = self._converter.params_from_internal(x)
return self._fun(p)

params = self._converter.params_from_internal(x)
Member

Since params is only needed for _assert_finite_jac here, you could also call the converter during the function call: _assert_finite_jac(..., params=self._converter.params_from_internal(x), ...).
