Potential precision loss in AutoDiff through the loss function #931

Open · IvanBioli opened this issue Mar 15, 2025 · 0 comments

Description
I have been testing the numerical accuracy of automatic differentiation in NeuralPDE.jl and observed significant numerical errors when differentiating through the loss function that NeuralPDE.jl builds. Specifically, forward-mode Jacobian-vector products (JVPs) of the residual vector that defines the loss exhibit relative errors on the order of 1e-8, while other related computations (such as direct evaluations of the model) remain at the expected machine precision (~1e-16).

This suggests that some internal operation in NeuralPDE.jl's loss formulation may inadvertently be carried out in lower-precision arithmetic (Float32 instead of Float64), or may otherwise introduce unexpected numerical error.
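
For completeness, a quick way to rule out an obvious eltype mismatch on the user-facing side is to inspect the element types of the parameters and training points. The sketch below reuses `prob`, `symprob`, `strategy`, and the `get_quadpoints` helper from the MWE further down, so it only checks what is visible from the outside:

using NeuralPDE: recursive_eltype

# Given the explicit Float64 conversion of `ps` in the MWE, all of these are
# expected to report Float64, so any Float32 arithmetic would have to be
# introduced somewhere internally.
println(eltype(prob.u0))                              # parameters passed to the optimizer
println(recursive_eltype(symprob.flat_init_params))   # flattened initial parameters
println(eltype(get_quadpoints(symprob, strategy)))    # training / quadrature points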

Additional context and expected behavior
To investigate this issue, I implemented functions that explicitly define the loss as $L(\theta) = \sum_{i=1}^n r_i(\theta)^2$, where $\theta$ are the neural network parameters, $u_\theta$ is the network approximation, and the residuals are defined as:

  • For internal points: $r_i(\theta) = \mathcal{D}u_\theta(x_i) - f(x_i)$, where $\mathcal{D}$ is the differential operator.
  • For boundary points: $r_i(\theta) = u_\theta(x_i) - g(x_i)$ (with optional weighting constants; the exact weighting used is spelled out after this list).
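
To make this hand-rolled loss directly comparable with the weighted mean-of-squares loss that NeuralPDE.jl assembles (this is what the sanity check below verifies), each residual block $i$ (one block per PDE equation or boundary condition, with $n_i$ training points and weight $w_i$) is scaled by $\sqrt{w_i}/\sqrt{n_i}$ in the residual vector, so that

$$L(\theta) = \sum_i \frac{w_i}{n_i} \sum_{j=1}^{n_i} r_{i,j}(\theta)^2 = \sum_i \sum_{j=1}^{n_i} \left(\frac{\sqrt{w_i}}{\sqrt{n_i}}\, r_{i,j}(\theta)\right)^2 .$$

This is what the full_residual closure in the MWE implements.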

I have attached a minimal working example below that demonstrates the issue. The key observations are:

  1. As a sanity check of my implementation, the relative error between loss(θ) and loss_neuralpdes(θ) is on the order of 1e-16.
  2. The error when computing JVPs via AutoForwardDiff for direct neural network evaluations is also on the order of 1e-16.
  3. However, the error in JVPs computed for the residual function (which contributes to the loss) is on the order of 1e-8, indicating a significant loss of numerical precision.

Minimal Reproducible Example 👇

using NeuralPDE, Lux, LuxCUDA, Random, ComponentArrays
using Optimization
using OptimizationOptimisers
import ModelingToolkit: Interval
using Plots
using Printf

################################# AUXILIARY FUNCTIONS ####################################
using Optimization: OptimizationProblem
using NeuralPDE: NeuralPDE, PINNRepresentation, recursive_eltype, EltypeAdaptor, safe_get_device, GridTraining
using Statistics: Statistics, mean

# Analogue of NeuralPDE's internal strategy handling, but returning per-point residual
# closures instead of scalar losses (only for PhysicsInformedNN with GridTraining)
function merge_strategy_with_residual_vector(pinnrep::PINNRepresentation,
        strategy::GridTraining, datafree_pde_loss_function, datafree_bc_loss_function)
    (; domains, eqs, bcs, dict_indvars, dict_depvars) = pinnrep
    eltypeθ = recursive_eltype(pinnrep.flat_init_params)
    adaptor = EltypeAdaptor{eltypeθ}()

    train_sets = generate_training_sets(domains, strategy.dx, eqs, bcs, eltypeθ,
        dict_indvars, dict_depvars)

    # the points in the domain and on the boundary
    pde_train_sets, bcs_train_sets = train_sets |> adaptor
    pde_loss_functions = [get_residual_vector(pinnrep, _loss, _set, eltypeθ, strategy)
                        for (_loss, _set) in zip(
        datafree_pde_loss_function, pde_train_sets)]

    bc_loss_functions = [get_residual_vector(pinnrep, _loss, _set, eltypeθ, strategy)
                        for (_loss, _set) in zip(datafree_bc_loss_function, bcs_train_sets)]

    return pde_loss_functions, bc_loss_functions
end

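# Return a closure θ -> residuals of `loss_function` at the training points (no reduction applied)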
function get_residual_vector(
        init_params, loss_function, train_set, eltype0, ::GridTraining; τ = nothing)
    init_params = init_params isa PINNRepresentation ? init_params.init_params : init_params
    train_set = train_set |> safe_get_device(init_params) |> EltypeAdaptor{eltype0}()
    return θ -> loss_function(train_set, θ)
end


function get_full_residual(prob::OptimizationProblem, symprob::PINNRepresentation)
    # Get PDE and BC residuals
    pde_residuals, bc_residuals = merge_strategy_with_residual_vector(symprob,
        symprob.strategy, symprob.loss_functions.datafree_pde_loss_functions, symprob.loss_functions.datafree_bc_loss_functions)

    # Setup weights for PDE and BCs
    flat_init_params = prob.u0
    adaloss = discretization.adaptive_loss # NOTE: relies on the global `discretization` defined below
    @assert isnothing(adaloss) # FIXME: assuming no adaptive loss
    num_additional_loss = 0

    adaloss === nothing && (adaloss = NonAdaptiveLoss{eltype(flat_init_params)}())
    
    # setup for all adaptive losses
    num_pde_losses = length(pde_residuals)
    num_bc_losses = length(bc_residuals)
    adaloss_T = eltype(adaloss.pde_loss_weights)

    # this will error if the user has provided a number of initial weights that is more than 1 and doesn't match the number of loss functions
    adaloss.pde_loss_weights = ones(adaloss_T, num_pde_losses) .* adaloss.pde_loss_weights
    adaloss.bc_loss_weights = ones(adaloss_T, num_bc_losses) .* adaloss.bc_loss_weights
    adaloss.additional_loss_weights = ones(adaloss_T, num_additional_loss) .*
                                    adaloss.additional_loss_weights

    function full_residual(θ)
        pde_losses = [pde_residual(θ) for pde_residual in pde_residuals]
        bc_losses = [bc_residual(θ) for bc_residual in bc_residuals]

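        # Scale each residual block by sqrt(weight)/sqrt(n) so that sum(abs2, full_res)
        # matches the weighted mean-of-squares loss checked in the sanity check below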
        weighted_pde_losses = sqrt.(adaloss.pde_loss_weights) .* pde_losses ./ sqrt.(length.(pde_losses))
        weighted_bc_losses = sqrt.(adaloss.bc_loss_weights) .* bc_losses ./ sqrt.(length.(bc_losses))

        # full_res = hcat(Iterators.flatten((weighted_pde_losses, weighted_bc_losses))...)
        full_res = hcat(hcat(weighted_pde_losses...), hcat(weighted_bc_losses...))
        return full_res
    end

    return full_residual
end

function get_quadpoints(symprob::PINNRepresentation, strategy::GridTraining)
    (; domains, eqs, dict_indvars, dict_depvars) = symprob
    eltypeθ = recursive_eltype(symprob.flat_init_params)

    train_sets = hcat(generate_training_sets(domains, strategy.dx, eqs, [], eltypeθ,
        dict_indvars, dict_depvars)[1]...)
    return train_sets
end

################################# NEURALPDES TUTORIAL ####################################
@parameters t x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2
Dt = Differential(t)
t_min = 0.0
t_max = 2.0
x_min = 0.0
x_max = 2.0
y_min = 0.0
y_max = 2.0

# 2D PDE
eq = Dt(u(t, x, y)) ~ Dxx(u(t, x, y)) + Dyy(u(t, x, y))

analytic_sol_func(t, x, y) = exp(x + y) * cos(x + y + 4t)
# Initial and boundary conditions
bcs = [u(t_min, x, y) ~ analytic_sol_func(t_min, x, y),
    u(t, x_min, y) ~ analytic_sol_func(t, x_min, y),
    u(t, x_max, y) ~ analytic_sol_func(t, x_max, y),
    u(t, x, y_min) ~ analytic_sol_func(t, x, y_min),
    u(t, x, y_max) ~ analytic_sol_func(t, x, y_max)]

# Space and time domains
domains = [t ∈ Interval(t_min, t_max),
    x ∈ Interval(x_min, x_max),
    y ∈ Interval(y_min, y_max)]

# Neural network
inner = 25
chain = Chain(Dense(3, inner, σ), Dense(inner, 1))

strategy = GridTraining(0.1)
ps, st = Lux.setup(Random.default_rng(), chain)
ps = ps |> ComponentArray .|> Float64
discretization = PhysicsInformedNN(chain, strategy; init_params = ps)

@named pde_system = PDESystem(eq, bcs, domains, [t, x, y], [u(t, x, y)])
prob = discretize(pde_system, discretization)
symprob = symbolic_discretize(pde_system, discretization)

callback = function (p, l)
    println("Current loss is: $l")
    return false
end

# Definition of the residual vector
residual = get_full_residual(prob, symprob)
loss = θ -> sum(abs2, residual(θ))
loss_neuralpdes = θ -> prob.f(θ, prob.p)

################################# TESTS ON THE ACCURACY ####################################
using ForwardDiff, Zygote, DifferentiationInterface, LinearAlgebra

# Sanity check
θ = prob.u0
rel_err = (loss_neuralpdes(θ) - loss(θ)) / loss_neuralpdes(θ)
println("Error on the loss: \t\t\t $rel_err") # In the order of 1e-16

# Test of AutoForwardDiff for differentiation of model evaluations
x = get_quadpoints(symprob, strategy)
fun = ps -> chain(x, ps, st)[1]
v = randn(length(θ))
J_fwd = ForwardDiff.jacobian(fun, θ)
jvp_explicit = J_fwd * v
jvp_pushforward = DifferentiationInterface.pushforward(
    fun,
    AutoForwardDiff(),
    θ,
    (v,),
)[1]
println("AutoForwardDiff error on model jvp:\t $(norm(jvp_explicit - jvp_pushforward[:]) / norm(jvp_explicit))") # In the order of 1e-16

# Check with JVPs
v = randn(length(θ))
J_fwd = ForwardDiff.jacobian(residual, θ)
jvp_explicit = J_fwd * v
jvp_pushforward = DifferentiationInterface.pushforward(
    residual,
    AutoForwardDiff(),
    θ,
    (v,),
)[1]
println("AutoForwardDiff error on residual jvp:\t $(norm(jvp_explicit - jvp_pushforward[:]) / norm(jvp_explicit))") # In the order of 1e-8!!!
IvanBioli added the bug label Mar 15, 2025