diff --git a/test/enzyme.jl b/test/enzyme.jl deleted file mode 100644 index 9ffe1501..00000000 --- a/test/enzyme.jl +++ /dev/null @@ -1,30 +0,0 @@ -using MatrixAlgebraKit -using Test -using LinearAlgebra: Diagonal -using CUDA, AMDGPU - -BLASFloats = (ComplexF64,) # full suite is too expensive on CI -GenericFloats = () #(BigFloat,) -@isdefined(TestSuite) || include("testsuite/TestSuite.jl") -using .TestSuite - -is_buildkite = get(ENV, "BUILDKITE", "false") == "true" - -m = 19 -for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) - #=if T ∈ BLASFloats - if CUDA.functional() - TestSuite.test_enzyme(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - #n == m && TestSuite.test_enzyme(Diagonal{T, CuVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) - end - if AMDGPU.functional() - TestSuite.test_enzyme(ROCMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - #TestSuite.test_enzyme(Diagonal{T, ROCVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) - end - end=# - if !is_buildkite - TestSuite.test_enzyme(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - #n == m && TestSuite.test_enzyme(Diagonal{T, Vector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) - end -end diff --git a/test/enzyme/eig.jl b/test/enzyme/eig.jl new file mode 100644 index 00000000..02283d75 --- /dev/null +++ b/test/enzyme/eig.jl @@ -0,0 +1,19 @@ +using MatrixAlgebraKit +using Test +using LinearAlgebra: Diagonal +using CUDA, AMDGPU + +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("../testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...) + TestSuite.seed_rng!(123) + if !is_buildkite + TestSuite.test_enzyme_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + end +end diff --git a/test/enzyme/eigh.jl b/test/enzyme/eigh.jl new file mode 100644 index 00000000..88f52549 --- /dev/null +++ b/test/enzyme/eigh.jl @@ -0,0 +1,19 @@ +using MatrixAlgebraKit +using Test +using LinearAlgebra: Diagonal +using CUDA, AMDGPU + +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("../testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...) + TestSuite.seed_rng!(123) + if !is_buildkite + TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + end +end diff --git a/test/enzyme/lq.jl b/test/enzyme/lq.jl new file mode 100644 index 00000000..6699ddfa --- /dev/null +++ b/test/enzyme/lq.jl @@ -0,0 +1,19 @@ +using MatrixAlgebraKit +using Test +using LinearAlgebra: Diagonal +using CUDA, AMDGPU + +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("../testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) + TestSuite.seed_rng!(123) + if !is_buildkite + TestSuite.test_enzyme_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end +end diff --git a/test/enzyme/orthnull.jl b/test/enzyme/orthnull.jl new file mode 100644 index 00000000..2e7a554d --- /dev/null +++ b/test/enzyme/orthnull.jl @@ -0,0 +1,19 @@ +using MatrixAlgebraKit +using Test +using LinearAlgebra: Diagonal +using CUDA, AMDGPU + +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("../testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) + TestSuite.seed_rng!(123) + if !is_buildkite + TestSuite.test_enzyme_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end +end diff --git a/test/enzyme/polar.jl b/test/enzyme/polar.jl new file mode 100644 index 00000000..31b89907 --- /dev/null +++ b/test/enzyme/polar.jl @@ -0,0 +1,19 @@ +using MatrixAlgebraKit +using Test +using LinearAlgebra: Diagonal +using CUDA, AMDGPU + +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("../testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) + TestSuite.seed_rng!(123) + if !is_buildkite + TestSuite.test_enzyme_polar(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end +end diff --git a/test/enzyme/qr.jl b/test/enzyme/qr.jl new file mode 100644 index 00000000..62a169c7 --- /dev/null +++ b/test/enzyme/qr.jl @@ -0,0 +1,19 @@ +using MatrixAlgebraKit +using Test +using LinearAlgebra: Diagonal +using CUDA, AMDGPU + +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("../testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) + TestSuite.seed_rng!(123) + if !is_buildkite + TestSuite.test_enzyme_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end +end diff --git a/test/enzyme/svd.jl b/test/enzyme/svd.jl new file mode 100644 index 00000000..6143f61e --- /dev/null +++ b/test/enzyme/svd.jl @@ -0,0 +1,19 @@ +using MatrixAlgebraKit +using Test +using LinearAlgebra: Diagonal +using CUDA, AMDGPU + +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("../testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) + TestSuite.seed_rng!(1234) + if !is_buildkite + TestSuite.test_enzyme_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 69c18501..cca71b39 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,11 +26,10 @@ if filter_tests!(testsuite, args) else is_apple_ci = Sys.isapple() && get(ENV, "CI", "false") == "true" if is_apple_ci - delete!(testsuite, "enzyme") filter!(p -> !startswith(first(p), "mooncake/"), testsuite) delete!(testsuite, "chainrules") end - Sys.iswindows() && delete!(testsuite, "enzyme") + (Sys.iswindows() || is_apple_ci) && filter!(p -> !startswith(first(p), "enzyme/"), testsuite) end end diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index 2edd0846..0a475773 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -15,6 +15,7 @@ using LinearAlgebra: Diagonal, norm, istriu, istril, I using Random, StableRNGs using Mooncake using AMDGPU, CUDA +using Enzyme, EnzymeTestUtils const tests = Dict() @@ -117,7 +118,16 @@ include("mooncake/svd.jl") include("mooncake/polar.jl") include("mooncake/orthnull.jl") -include("enzyme.jl") include("chainrules.jl") +# Enzyme +# ------ +include("enzyme/eig.jl") +include("enzyme/eigh.jl") +include("enzyme/qr.jl") +include("enzyme/lq.jl") +include("enzyme/svd.jl") +include("enzyme/polar.jl") +include("enzyme/orthnull.jl") + end diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index fce118a8..4170b216 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -1,3 +1,21 @@ +""" + remove_svd_gauge_dependence!(ΔV, D, V) + +Remove the gauge-dependent part from the cotangents `ΔU` and ΔVᴴ` of the singular vector matrices `U` +and `Vᴴ`. The singular vectors are only determined up to complex phase (and unitary mixing for degenerate +eigenvalues), so the corresponding components of `ΔU` and `ΔVᴴ` are projected out. +""" +function remove_svd_gauge_dependence!( + ΔU, ΔVᴴ, U, S, Vᴴ; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S) + ) + gaugepart = mul!(U' * ΔU, Vᴴ, ΔVᴴ', true, true) + gaugepart = project_antihermitian!(gaugepart) + gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 + mul!(ΔU, U, gaugepart, -1, 1) + return ΔU, ΔVᴴ +end + """ remove_eig_gauge_dependence!(ΔV, D, V) @@ -163,6 +181,8 @@ function call_and_zero!(f!, A, alg) return F′ end +is_cpu(A) = typeof(parent(A)) <: Array + """ eigh_wrapper(f, A, alg) @@ -234,6 +254,11 @@ function ad_qr_compact_setup(A::Diagonal) end function ad_qr_null_setup(A) + m, n = size(A) + minmn = min(m, n) + Q, R = qr_compact(A) + T = eltype(A) + ΔN = Q * randn!(similar(A, T, minmn, max(0, m - minmn))) N = qr_null(A) ΔN = randn!(copy(N)) remove_qr_null_gauge_dependence!(ΔN, A, N) @@ -246,7 +271,6 @@ function ad_qr_full_setup(A) remove_qr_gauge_dependence!(ΔQR..., A, QR...) return QR, ΔQR end - ad_qr_full_setup(A::Diagonal) = ad_qr_compact_setup(A) function ad_qr_rank_deficient_compact_setup(A) @@ -516,8 +540,8 @@ end function ad_left_null_setup(A) m, n = size(A) T = eltype(A) - N = left_orth(A; alg = :qr)[1] * randn!(similar(A, T, min(m, n), m - min(m, n))) - ΔN = left_orth(A; alg = :qr)[1] * randn!(similar(A, T, min(m, n), m - min(m, n))) + N = left_orth(A)[1] * randn!(similar(A, T, min(m, n), m - min(m, n))) + ΔN = left_orth(A)[1] * randn!(similar(A, T, min(m, n), m - min(m, n))) return N, ΔN end @@ -533,7 +557,7 @@ ad_right_orth_setup(A::Diagonal) = ad_left_orth_setup(A) function ad_right_null_setup(A) m, n = size(A) T = eltype(A) - Nᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2] - ΔNᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2] + Nᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A)[2] + ΔNᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A)[2] return Nᴴ, ΔNᴴ end diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl deleted file mode 100644 index e9b07a31..00000000 --- a/test/testsuite/enzyme.jl +++ /dev/null @@ -1,464 +0,0 @@ -using TestExtras -using MatrixAlgebraKit -using Enzyme, EnzymeTestUtils -using MatrixAlgebraKit: diagview, TruncatedAlgorithm -using LinearAlgebra: Diagonal, Hermitian, mul!, BlasFloat -using GenericLinearAlgebra, GenericSchur - -function enz_copy_eigh_full(A, alg) - A = (A + A') / 2 - return eigh_full(A, alg) -end - -function enz_copy_eigh_full!(A, DV::Tuple, alg::MatrixAlgebraKit.AbstractAlgorithm) - A = (A + A') / 2 - return eigh_full!(A, DV, alg) -end - -function enz_copy_eigh_vals(A; kwargs...) - A = (A + A') / 2 - return eigh_vals(A; kwargs...) -end - -function enz_copy_eigh_vals!(A, D; kwargs...) - A = (A + A') / 2 - return eigh_vals!(A, D; kwargs...) -end - -function enz_copy_eigh_vals(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_vals(A, alg; kwargs...) -end - -function enz_copy_eigh_vals!(A, D, alg; kwargs...) - A = (A + A') / 2 - return eigh_vals!(A, D, alg; kwargs...) -end - -function enz_copy_eigh_trunc_no_error(A, alg) - A = (A + A') / 2 - return eigh_trunc_no_error(A, alg) -end - -function enz_copy_eigh_trunc_no_error!(A, DV, alg) - A = (A + A') / 2 - return eigh_trunc_no_error!(A, DV, alg) -end - -# necessary due to name conflict with Mooncake -function enz_test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; ȳ = copy.(Δargs), return_act = Duplicated) - ΔA = randn!(similar(A)) - A_ΔA() = Duplicated(copy(A), copy(ΔA)) - function args_Δargs() - if isnothing(args) - return Const(args) - elseif args isa Tuple && all(isnothing, args) - return Const(args) - else - return Duplicated(copy.(args), copy.(Δargs)) - end - end - copy_activities = isnothing(alg) ? (Const(f), A_ΔA()) : (Const(f), A_ΔA(), Const(alg)) - inplace_activities = isnothing(alg) ? (Const(f!), A_ΔA(), args_Δargs()) : (Const(f!), A_ΔA(), args_Δargs(), Const(alg)) - - mode = EnzymeTestUtils.set_runtime_activity(ReverseSplitWithPrimal, false) - c_act = Const(EnzymeTestUtils.call_with_kwargs) - forward_copy, reverse_copy = autodiff_thunk( - mode, typeof(c_act), return_act, typeof(Const(())), map(typeof, copy_activities)... - ) - forward_inplace, reverse_inplace = autodiff_thunk( - mode, typeof(c_act), return_act, typeof(Const(())), map(typeof, inplace_activities)... - ) - copy_tape, copy_y_ad, copy_shadow_result = forward_copy(c_act, Const(()), copy_activities...) - inplace_tape, inplace_y_ad, inplace_shadow_result = forward_inplace(c_act, Const(()), inplace_activities...) - if !(copy_shadow_result === nothing) - flush(stdout) - EnzymeTestUtils.map_fields_recursive(copyto!, copy_shadow_result, copy.(ȳ)) - end - if !(inplace_shadow_result === nothing) - EnzymeTestUtils.map_fields_recursive(copyto!, inplace_shadow_result, copy.(ȳ)) - end - dx_copy_ad = only(reverse_copy(c_act, Const(()), copy_activities..., copy_tape)) - dx_inplace_ad = only(reverse_inplace(c_act, Const(()), inplace_activities..., inplace_tape)) - # check all returned derivatives between copy & inplace - for (i, (copy_act_i, inplace_act_i)) in enumerate(zip(copy_activities[2:end], inplace_activities[2:end])) - if copy_act_i isa Duplicated && inplace_act_i isa Duplicated - msg_deriv = "shadow derivative for argument $(i - 1) should match between copy and inplace" - EnzymeTestUtils.test_approx(copy_act_i.dval, inplace_act_i.dval, msg_deriv) - end - end - return -end - -function test_enzyme(T::Type, sz; kwargs...) - summary_str = testargs_summary(T, sz) - return @testset "Enzyme AD $summary_str" begin - test_enzyme_qr(T, sz; kwargs...) - test_enzyme_lq(T, sz; kwargs...) - if length(sz) == 1 || sz[1] == sz[2] - test_enzyme_eig(T, sz; kwargs...) - # missing Enzyme rule - eltype(T) <: BlasFloat && test_enzyme_eigh(T, sz; kwargs...) - end - test_enzyme_svd(T, sz; kwargs...) - if eltype(T) <: BlasFloat - test_enzyme_polar(T, sz; kwargs...) - test_enzyme_orthnull(T, sz; kwargs...) - end - end -end - -is_cpu(A) = typeof(parent(A)) <: Array - -function test_enzyme_qr( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "QR Enzyme AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - alg = MatrixAlgebraKit.default_qr_algorithm(A) - @testset "qr_compact" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - QR, ΔQR = ad_qr_compact_setup(A) - eltype(T) <: BlasFloat && test_reverse(qr_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, qr_compact!, qr_compact, A, QR, ΔQR, alg) - end - end - @testset "qr_null" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - N, ΔN = ad_qr_null_setup(A) - eltype(T) <: BlasFloat && test_reverse(qr_null, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔN) - is_cpu(A) && enz_test_pullbacks_match(rng, qr_null!, qr_null, A, N, ΔN, alg) - end - end - @testset "qr_full" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - QR, ΔQR = ad_qr_full_setup(A) - eltype(T) <: BlasFloat && test_reverse(qr_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, qr_full!, qr_full, A, QR, ΔQR, alg) - end - end - @testset "qr_compact - rank-deficient A" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - m, n = size(A) - r = min(m, n) - 5 - Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard) - eltype(T) <: BlasFloat && test_reverse(qr_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, QR, ΔQR, alg) - end - end - end -end - -function test_enzyme_lq( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "LQ Enzyme AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - alg = MatrixAlgebraKit.default_lq_algorithm(A) - fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - @testset "lq_compact" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - LQ, ΔLQ = ad_lq_compact_setup(A) - eltype(T) <: BlasFloat && test_reverse(lq_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, lq_compact!, lq_compact, A, LQ, ΔLQ, alg) - end - end - @testset "lq_null" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - Nᴴ, ΔNᴴ = ad_lq_null_setup(A) - eltype(T) <: BlasFloat && test_reverse(lq_null, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔNᴴ) - is_cpu(A) && enz_test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ, alg) - end - end - @testset "lq_full" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - LQ, ΔLQ = ad_lq_full_setup(A) - eltype(T) <: BlasFloat && test_reverse(lq_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, lq_full!, lq_full, A, LQ, ΔLQ, alg) - end - end - @testset "lq_compact -- rank-deficient A" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - m, n = size(A) - r = min(m, n) - 5 - Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(Ard) - eltype(T) <: BlasFloat && test_reverse(lq_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, LQ, ΔLQ, alg) - end - end - end -end - -function test_enzyme_eig( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "EIG Enzyme AD rules $summary_str" begin - A = make_eig_matrix(T, sz) - m = size(A, 1) - fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - alg = MatrixAlgebraKit.default_eig_algorithm(A) - @testset "eig_full" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - DV, ΔDV = ad_eig_full_setup(A) - if eltype(T) <: BlasFloat - test_reverse(eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔDV, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, eig_full!, eig_full, A, DV, ΔDV, alg) - else - is_cpu(A) && enz_test_pullbacks_match(rng, eig_full!, eig_full, A, (nothing, nothing), (nothing, nothing), alg; ȳ = ΔDV) - end - end - end - @testset "eig_vals" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - D, ΔD = ad_eig_vals_setup(A) - if eltype(T) <: BlasFloat - test_reverse(eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔD, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, eig_vals!, eig_vals, A, D, ΔD, alg) - else - is_cpu(A) && enz_test_pullbacks_match(rng, eig_vals!, eig_vals, A, nothing, nothing, alg; ȳ = ΔD) - end - end - end - @testset "eig_trunc" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - for r in 1:4:m - truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(r; by = abs)) - DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) - if eltype(T) <: BlasFloat - test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg, ȳ = ΔDVtrunc) - else - is_cpu(A) && enz_test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = ΔDVtrunc) - end - end - truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real)) - DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) - if eltype(T) <: BlasFloat - test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg, ȳ = ΔDVtrunc) - else - is_cpu(A) && enz_test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = ΔDVtrunc) - end - end - end - end -end - -function test_enzyme_eigh( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "EIGH Enzyme AD rules $summary_str" begin - A = make_eigh_matrix(T, sz) - m = size(A, 1) - alg = MatrixAlgebraKit.default_eigh_algorithm(A) - fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - @testset "eigh_full" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - DV, ΔDV = ad_eigh_full_setup(A) - if eltype(T) <: BlasFloat - test_reverse(enz_copy_eigh_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔDV, fdm) - test_reverse(enz_copy_eigh_full!, RT, (A, TA), (DV, TA), (alg, Const); atol, rtol, output_tangent = ΔDV, fdm) - end - is_cpu(A) && enz_test_pullbacks_match(rng, enz_copy_eigh_full!, enz_copy_eigh_full, A, DV, ΔDV, alg) - end - end - @testset "eigh_vals" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - D, ΔD = ad_eigh_vals_setup(A) - eltype(T) <: BlasFloat && test_reverse(enz_copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔD, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, enz_copy_eigh_vals!, enz_copy_eigh_vals, A, D, ΔD, alg) - end - end - @testset "eigh_trunc" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - D = eigh_vals(A / 2) - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) - eltype(T) <: BlasFloat && test_reverse(enz_copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, enz_copy_eigh_trunc_no_error!, enz_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg, ȳ = ΔDVtrunc, return_act = RT) - end - truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, D) / 2)) - DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) - eltype(T) <: BlasFloat && test_reverse(enz_copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, enz_copy_eigh_trunc_no_error!, enz_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg, ȳ = ΔDVtrunc, return_act = RT) - end - end - end -end - -function test_enzyme_svd( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "SVD Enzyme AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - minmn = min(size(A)...) - alg = MatrixAlgebraKit.default_svd_algorithm(A) - fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - @testset "svd_compact" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - USVᴴ, ΔUSVᴴ = ad_svd_compact_setup(A) - if eltype(T) <: BlasFloat - test_reverse(svd_compact, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔUSVᴴ, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, svd_compact!, svd_compact, A, USVᴴ, ΔUSVᴴ, alg) - else - USVᴴ = MatrixAlgebraKit.initialize_output(svd_compact!, A, alg) - is_cpu(A) && enz_test_pullbacks_match(rng, svd_compact!, svd_compact, A, USVᴴ, (nothing, nothing, nothing), alg; ȳ = ΔUSVᴴ) - end - end - end - @testset "svd_full" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) - if eltype(T) <: BlasFloat - test_reverse(svd_full, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔUSVᴴ, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, svd_full!, svd_full, A, USVᴴ, ΔUSVᴴ, alg) - else - USVᴴ = MatrixAlgebraKit.initialize_output(svd_full!, A, alg) - is_cpu(A) && enz_test_pullbacks_match(rng, svd_full!, svd_full, A, USVᴴ, (nothing, nothing, nothing), alg; ȳ = ΔUSVᴴ) - end - end - end - @testset "svd_vals" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - S, ΔS = ad_svd_vals_setup(A) - if eltype(T) <: BlasFloat - test_reverse(svd_vals, RT, (A, TA); atol, rtol, fkwargs = (alg = alg,), output_tangent = ΔS, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, ΔS, alg) - else - S = MatrixAlgebraKit.initialize_output(svd_vals!, A, alg) - is_cpu(A) && enz_test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, nothing, alg; ȳ = ΔS) - end - end - end - @testset "svd_trunc" begin - S, ΔS = ad_svd_vals_setup(A) - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - for r in 1:4:minmn - truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), truncrank(r)) - USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) - if eltype(T) <: BlasFloat - test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg, ȳ = ΔUSVᴴtrunc) - else - is_cpu(A) && enz_test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (nothing, nothing, nothing), (nothing, nothing, nothing), truncalg, ȳ = ΔUSVᴴtrunc) - end - end - truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), trunctol(atol = S[1, 1] / 2)) - USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) - if eltype(T) <: BlasFloat - test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg, ȳ = ΔUSVᴴtrunc) - else - is_cpu(A) && enz_test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (nothing, nothing, nothing), (nothing, nothing, nothing), truncalg, ȳ = ΔUSVᴴtrunc) - end - end - end - end -end - -# GLA works with polar, but these tests -# segfault because of Sylvester + BigFloat -function test_enzyme_polar( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "Polar Enzyme AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - m, n = size(A) - alg = MatrixAlgebraKit.default_polar_algorithm(A) - @testset "left_polar" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - if m >= n - WP, ΔWP = ad_left_polar_setup(A) - eltype(T) <: BlasFloat && test_reverse(left_polar, RT, (A, TA), (alg, Const); atol, rtol) - is_cpu(A) && enz_test_pullbacks_match(rng, left_polar!, left_polar, A, WP, ΔWP, alg) - end - end - end - @testset "right_polar" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - if m <= n - PWᴴ, ΔPWᴴ = ad_right_polar_setup(A) - eltype(T) <: BlasFloat && test_reverse(right_polar, RT, (A, TA), (alg, Const); atol, rtol) - is_cpu(A) && enz_test_pullbacks_match(rng, right_polar!, right_polar, A, PWᴴ, ΔPWᴴ, alg) - end - end - end - end -end - -function test_enzyme_orthnull( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "Orthnull Enzyme AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - m, n = size(A) - VC, ΔVC = ad_left_orth_setup(A) - CVᴴ, ΔCVᴴ = ad_right_orth_setup(A) - fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - @testset "left_orth" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - @testset for alg in (:polar, :qr) - n > m && alg == :polar && continue - eltype(T) <: BlasFloat && test_reverse(left_orth, RT, (A, TA); atol, rtol, fkwargs = (alg = alg,), fdm) - left_orth_alg!(A, VC) = left_orth!(A, VC; alg = alg) - left_orth_alg(A) = left_orth(A; alg = alg) - is_cpu(A) && enz_test_pullbacks_match(rng, left_orth_alg!, left_orth_alg, A, VC, ΔVC) - end - end - end - N, ΔN = ad_left_null_setup(A) - @testset "left_null" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - left_null_qr!(A, N) = left_null!(A, N; alg = :qr) - left_null_qr(A) = left_null(A; alg = :qr) - eltype(T) <: BlasFloat && test_reverse(left_null_qr, RT, (A, TA); output_tangent = ΔN, atol, rtol) - is_cpu(A) && enz_test_pullbacks_match(rng, left_null_qr!, left_null_qr, A, N, ΔN) - end - end - @testset "right_orth" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - @testset for alg in (:polar, :lq) - n < m && alg == :polar && continue - eltype(T) <: BlasFloat && test_reverse(right_orth, RT, (A, TA); atol, rtol, fkwargs = (alg = alg,), fdm) - right_orth_alg!(A, CVᴴ) = right_orth!(A, CVᴴ; alg = alg) - right_orth_alg(A) = right_orth(A; alg = alg) - is_cpu(A) && enz_test_pullbacks_match(rng, right_orth_alg!, right_orth_alg, A, CVᴴ, ΔCVᴴ) - end - end - end - Nᴴ, ΔNᴴ = ad_right_null_setup(A) - @testset "right_null" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - right_null_lq!(A, Nᴴ) = right_null!(A, Nᴴ; alg = :lq) - right_null_lq(A) = right_null(A; alg = :lq) - eltype(T) <: BlasFloat && test_reverse(right_null_lq, RT, (A, TA); output_tangent = ΔNᴴ, atol, rtol) - is_cpu(A) && enz_test_pullbacks_match(rng, right_null_lq!, right_null_lq, A, Nᴴ, ΔNᴴ) - end - end - end -end diff --git a/test/testsuite/enzyme/eig.jl b/test/testsuite/enzyme/eig.jl new file mode 100644 index 00000000..55c96aa0 --- /dev/null +++ b/test/testsuite/enzyme/eig.jl @@ -0,0 +1,87 @@ +""" + test_enzyme_eig(T, sz; kwargs...) + +Run all Enzyme AD tests for eigendecompositions of element type `T` and size `sz`. +""" +function test_enzyme_eig(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Enzyme eig $summary_str" begin + test_enzyme_eig_full(T, sz; kwargs...) + test_enzyme_eig_vals(T, sz; kwargs...) + test_enzyme_eig_trunc(T, sz; kwargs...) + end +end + +""" + test_enzyme_eig_full(T, sz; rng, atol, rtol) + +Test the Enzyme reverse-mode AD rule for `eig_full` and its in-place variant. +""" +function test_enzyme_eig_full( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "eig_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = make_eig_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(eig_full, A) + DV, ΔDV = ad_eig_full_setup(A) + test_reverse(eig_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔDV, fdm) + test_reverse(call_and_zero!, RT, (eig_full!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔDV, fdm) + end +end + +""" + test_enzyme_eig_vals(T, sz; rng, atol, rtol) + +Test the Enzyme reverse-mode AD rule for `eig_vals` and its in-place variant. +""" +function test_enzyme_eig_vals( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "eig_vals reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = make_eig_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(eig_vals, A) + D, ΔD = ad_eig_vals_setup(A) + test_reverse(eig_vals, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD, fdm) + test_reverse(call_and_zero!, RT, (eig_vals!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD, fdm) + end +end + +""" + test_enzyme_eig_trunc(T, sz; rng, atol, rtol) + +Test the Enzyme reverse-mode AD rules for `eig_trunc`, `eig_trunc_no_error`, and their +in-place variants, over a range of truncation ranks and a tolerance-based truncation. +""" +function test_enzyme_eig_trunc( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "eig_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = make_eig_matrix(T, sz) + m = size(A, 1) + + alg = MatrixAlgebraKit.select_algorithm(eig_full, A) + @testset "truncrank($r)" for r in round.(Int, range(1, m + 4, 4)) + trunc = truncrank(r; by = abs) + truncalg = TruncatedAlgorithm(alg, trunc) + A = make_eig_matrix(T, sz) + DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) + test_reverse(call_and_zero!, RT, (eig_trunc_no_error!, Const), (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) + end + @testset "trunctol" begin + A = make_eig_matrix(T, sz) + D = eig_vals(A) + trunc = trunctol(atol = maximum(abs, D) / 2; by = abs) + truncalg = TruncatedAlgorithm(alg, trunc) + DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) + test_reverse(call_and_zero!, RT, (eig_trunc_no_error!, Const), (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) + end + end +end diff --git a/test/testsuite/enzyme/eigh.jl b/test/testsuite/enzyme/eigh.jl new file mode 100644 index 00000000..e7f5416f --- /dev/null +++ b/test/testsuite/enzyme/eigh.jl @@ -0,0 +1,87 @@ +""" + test_enzyme_eigh(T, sz; kwargs...) + +Run all Enzyme AD tests for Hermitian eigendecompositions of element type `T` and size `sz`. +""" +function test_enzyme_eigh(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Enzyme eigh $summary_str" begin + test_enzyme_eigh_full(T, sz; kwargs...) + test_enzyme_eigh_vals(T, sz; kwargs...) + test_enzyme_eigh_trunc(T, sz; kwargs...) + end +end + +""" + test_enzyme_eigh_full(T, sz; rng, atol, rtol) + +Test the Enzyme reverse-mode AD rule for `eigh_full` and its in-place variant. +""" +function test_enzyme_eigh_full( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "eigh_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = make_eigh_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(eigh_full, A) + DV, ΔDV = ad_eigh_full_setup(A) + test_reverse(eigh_wrapper, RT, (eigh_full, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔDV, fdm) + test_reverse(eigh!_wrapper, RT, (eigh_full!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔDV, fdm) + end +end + +""" + test_enzyme_eigh_vals(T, sz; rng, atol, rtol) + +Test the Enzyme reverse-mode AD rule for `eigh_vals` and its in-place variant. +""" +function test_enzyme_eigh_vals( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "eigh_vals reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = make_eigh_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(eigh_vals, A) + D, ΔD = ad_eigh_vals_setup(A) + test_reverse(eigh_wrapper, RT, (eigh_vals, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD, fdm) + test_reverse(eigh!_wrapper, RT, (eigh_vals!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD, fdm) + end +end + +""" + test_enzyme_eigh_trunc(T, sz; rng, atol, rtol) + +Test the Enzyme reverse-mode AD rules for `eigh_trunc`, `eigh_trunc_no_error`, and their +in-place variants, over a range of truncation ranks. +""" +function test_enzyme_eigh_trunc( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "eigh_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = make_eigh_matrix(T, sz) + m = size(A, 1) + + alg = MatrixAlgebraKit.select_algorithm(eigh_full, A) + @testset "truncrank($r)" for r in round.(Int, range(1, m + 4, 4)) + trunc = truncrank(r; by = abs) + truncalg = TruncatedAlgorithm(alg, trunc) + A = make_eigh_matrix(T, sz) + DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + test_reverse(eigh_wrapper, RT, (eigh_trunc_no_error, Const), (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) + test_reverse(eigh!_wrapper, RT, (eigh_trunc_no_error!, Const), (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) + end + @testset "trunctol" begin + A = make_eigh_matrix(T, sz) + D = eigh_vals(A / 2, alg) + trunc = trunctol(; atol = maximum(abs, D) / 2) + truncalg = TruncatedAlgorithm(alg, trunc) + DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + test_reverse(eigh_wrapper, RT, (eigh_trunc_no_error, Const), (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) + test_reverse(eigh!_wrapper, RT, (eigh_trunc_no_error!, Const), (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) + end + end +end diff --git a/test/testsuite/enzyme/enzyme.jl b/test/testsuite/enzyme/enzyme.jl new file mode 100644 index 00000000..b65b7617 --- /dev/null +++ b/test/testsuite/enzyme/enzyme.jl @@ -0,0 +1,24 @@ +function call_and_zero!(f!, A, alg) + F′ = f!(A, alg) + MatrixAlgebraKit.zero!(A) + return F′ +end + +function test_enzyme(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Enzyme AD $summary_str" begin + test_enzyme_qr(T, sz; kwargs...) + test_enzyme_lq(T, sz; kwargs...) + if length(sz) == 1 || sz[1] == sz[2] + test_enzyme_eig(T, sz; kwargs...) + test_enzyme_eigh(T, sz; kwargs...) + end + test_enzyme_svd(T, sz; kwargs...) + if eltype(T) <: BlasFloat # no Sylvester for BigFloat + test_enzyme_polar(T, sz; kwargs...) + test_enzyme_orthnull(T, sz; kwargs...) + end + end +end + +is_cpu(A) = typeof(parent(A)) <: Array diff --git a/test/testsuite/enzyme/lq.jl b/test/testsuite/enzyme/lq.jl new file mode 100644 index 00000000..e4571a69 --- /dev/null +++ b/test/testsuite/enzyme/lq.jl @@ -0,0 +1,73 @@ +""" + test_enzyme_lq(T, sz; kwargs...) + +Run all Enzyme AD tests for LQ decompositions of element type `T` and size `sz`. +""" +function test_enzyme_lq(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Enzyme lq $summary_str" begin + test_enzyme_lq_compact(T, sz; kwargs...) + test_enzyme_lq_compact_rank_deficient(T, sz; kwargs...) + test_enzyme_lq_full(T, sz; kwargs...) + test_enzyme_lq_null(T, sz; kwargs...) + end +end + +function test_enzyme_lq_compact( + T::Type, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "lq_compact reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(lq_compact, A) + LQ, ΔLQ = ad_lq_compact_setup(A) + test_reverse(lq_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) + test_reverse(call_and_zero!, RT, (lq_compact!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) + end +end + +function test_enzyme_lq_compact_rank_deficient( + T::Type, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "lq_compact rank deficient A reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + m, n = size(A) + r = min(m, n) - 5 + A = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + alg = MatrixAlgebraKit.select_algorithm(lq_compact, A) + LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(A) + test_reverse(lq_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) + test_reverse(call_and_zero!, RT, (lq_compact!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) + end +end + +function test_enzyme_lq_full( + T::Type, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "lq_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(lq_full, A) + LQ, ΔLQ = ad_lq_full_setup(A) + test_reverse(lq_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) + test_reverse(call_and_zero!, RT, (lq_full!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) + end +end + +function test_enzyme_lq_null( + T::Type, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "lq_null reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(lq_null, A) + Nᴴ, ΔNᴴ = ad_lq_null_setup(A) + test_reverse(lq_null, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔNᴴ) + test_reverse(call_and_zero!, RT, (lq_null!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔNᴴ) + end +end diff --git a/test/testsuite/enzyme/orthnull.jl b/test/testsuite/enzyme/orthnull.jl new file mode 100644 index 00000000..6785d9c0 --- /dev/null +++ b/test/testsuite/enzyme/orthnull.jl @@ -0,0 +1,128 @@ +""" + test_enzyme_orthnull(T, sz; kwargs...) + +Run all Enzyme AD tests for orthogonal basis and null space computations of element type `T` +and size `sz`. +""" +function test_enzyme_orthnull(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Enzyme orthnull $summary_str" begin + test_enzyme_left_orth(T, sz; kwargs...) + test_enzyme_right_orth(T, sz; kwargs...) + test_enzyme_left_null(T, sz; kwargs...) + test_enzyme_right_null(T, sz; kwargs...) + end +end + +""" + test_enzyme_left_orth(T, sz; rng, atol, rtol) + +Test the Enzyme reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`) +algorithms, and their in-place variants. +""" +function test_enzyme_left_orth( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "left_orth reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + m, n = size(A) + + @testset "qr" begin + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(left_orth!, A, :qr) + VC, ΔVC = ad_left_orth_setup(A) + test_reverse(left_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔVC) + test_reverse(call_and_zero!, RT, (left_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔVC) + end + + if m >= n + @testset "polar" begin + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(left_orth!, A, :polar) + VC, ΔVC = ad_left_orth_setup(A) + test_reverse(left_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔVC) + test_reverse(call_and_zero!, RT, (left_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔVC) + end + end + end +end + +""" + test_enzyme_right_orth(T, sz; rng, atol, rtol) + +Test the Enzyme reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`) +algorithms, and their in-place variants. +""" +function test_enzyme_right_orth( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "right_orth reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + m, n = size(A) + @testset "lq" begin + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(right_orth!, A, :lq) + CVᴴ, ΔCVᴴ = ad_right_orth_setup(A) + test_reverse(right_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔCVᴴ) + test_reverse(call_and_zero!, RT, (right_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔCVᴴ) + end + + if m <= n + @testset "polar" begin + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(right_orth!, A, :polar) + CVᴴ, ΔCVᴴ = ad_right_orth_setup(A) + test_reverse(right_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔCVᴴ) + test_reverse(call_and_zero!, RT, (right_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔCVᴴ) + end + end + end +end + +""" + test_enzyme_left_null(T, sz; rng, atol, rtol) + +Test the Enzyme reverse-mode AD rule for `left_null` with the QR algorithm and its +in-place variant. +""" +function test_enzyme_left_null( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "left_null reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + @testset "qr" begin + alg = MatrixAlgebraKit.select_algorithm(left_null!, A, :qr) + N, ΔN = ad_left_null_setup(A) + test_reverse(left_null, RT, (A, TA), (alg, Const); output_tangent = ΔN, atol, rtol) + test_reverse(call_and_zero!, RT, (left_null!, Const), (A, TA), (alg, Const); output_tangent = ΔN, atol, rtol) + end + end +end + +""" + test_enzyme_right_null(T, sz; rng, atol, rtol) + +Test the Enzyme reverse-mode AD rule for `right_null` with the LQ algorithm and its +in-place variant. +""" +function test_enzyme_right_null( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "right_null reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + @testset "lq" begin + alg = MatrixAlgebraKit.select_algorithm(right_null!, A, :lq) + Nᴴ, ΔNᴴ = ad_right_null_setup(A) + test_reverse(right_null, RT, (A, TA), (alg, Const); output_tangent = ΔNᴴ, atol, rtol) + test_reverse(call_and_zero!, RT, (right_null!, Const), (A, TA), (alg, Const); output_tangent = ΔNᴴ, atol, rtol) + end + end +end diff --git a/test/testsuite/enzyme/polar.jl b/test/testsuite/enzyme/polar.jl new file mode 100644 index 00000000..bfc889c2 --- /dev/null +++ b/test/testsuite/enzyme/polar.jl @@ -0,0 +1,56 @@ +""" + test_enzyme_polar(T, sz; kwargs...) + +Run all Enzyme AD tests for polar decompositions of element type `T` and size `sz`. +""" +function test_enzyme_polar(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Enzyme polar $summary_str" begin + test_enzyme_left_polar(T, sz; kwargs...) + test_enzyme_right_polar(T, sz; kwargs...) + end +end + +""" + test_enzyme_left_polar(T, sz; rng, atol, rtol) + +Test the Enzyme reverse-mode AD rule for `left_polar` and its in-place variant. Only runs +for tall or square matrices (`m >= n`). +""" +function test_enzyme_left_polar( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "left_polar reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + m, n = size(A) + if m >= n + alg = MatrixAlgebraKit.select_algorithm(left_polar, A) + WP, ΔWP = ad_left_polar_setup(A) + test_reverse(left_polar, RT, (A, TA), (alg, Const); atol, rtol) + test_reverse(call_and_zero!, RT, (left_polar!, Const), (A, TA), (alg, Const); atol, rtol) + end + end +end + +""" + test_enzyme_right_polar(T, sz; rng, atol, rtol) + +Test the Enzyme reverse-mode AD rule for `right_polar` and its in-place variant. Only runs +for wide or square matrices (`m <= n`). +""" +function test_enzyme_right_polar( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "right_polar reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + m, n = size(A) + if m <= n + alg = MatrixAlgebraKit.select_algorithm(right_polar, A) + PWᴴ, ΔPWᴴ = ad_right_polar_setup(A) + test_reverse(right_polar, RT, (A, TA), (alg, Const); atol, rtol) + test_reverse(call_and_zero!, RT, (right_polar!, Const), (A, TA), (alg, Const); atol, rtol) + end + end +end diff --git a/test/testsuite/enzyme/qr.jl b/test/testsuite/enzyme/qr.jl new file mode 100644 index 00000000..960d5915 --- /dev/null +++ b/test/testsuite/enzyme/qr.jl @@ -0,0 +1,73 @@ +""" + test_enzyme_qr(T, sz; kwargs...) + +Run all Enzyme AD tests for QR decompositions of element type `T` and size `sz`. +""" +function test_enzyme_qr(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Enzyme qr $summary_str" begin + test_enzyme_qr_compact(T, sz; kwargs...) + test_enzyme_qr_compact_rank_deficient(T, sz; kwargs...) + test_enzyme_qr_full(T, sz; kwargs...) + test_enzyme_qr_null(T, sz; kwargs...) + end +end + +function test_enzyme_qr_compact( + T::Type, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "qr_compact reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(qr_compact, A) + QR, ΔQR = ad_qr_compact_setup(A) + test_reverse(qr_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) + test_reverse(call_and_zero!, RT, (qr_compact!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) + end +end + +function test_enzyme_qr_compact_rank_deficient( + T::Type, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "qr_compact rank deficient A reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + m, n = size(A) + r = min(m, n) - 5 + A = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + alg = MatrixAlgebraKit.select_algorithm(qr_compact, A) + QR, ΔQR = ad_qr_rank_deficient_compact_setup(A) + test_reverse(qr_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) + test_reverse(call_and_zero!, RT, (qr_compact!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) + end +end + +function test_enzyme_qr_full( + T::Type, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "qr_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(qr_full, A) + QR, ΔQR = ad_qr_full_setup(A) + test_reverse(qr_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) + test_reverse(call_and_zero!, RT, (qr_full!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) + end +end + +function test_enzyme_qr_null( + T::Type, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "qr_null reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(qr_null, A) + N, ΔN = ad_qr_null_setup(A) + test_reverse(qr_null, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔN) + test_reverse(call_and_zero!, RT, (qr_null!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔN) + end +end diff --git a/test/testsuite/enzyme/svd.jl b/test/testsuite/enzyme/svd.jl new file mode 100644 index 00000000..daa69c0f --- /dev/null +++ b/test/testsuite/enzyme/svd.jl @@ -0,0 +1,81 @@ +function test_enzyme_svd(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Enzyme svd $summary_str" begin + test_enzyme_svd_compact(T, sz; kwargs...) + test_enzyme_svd_full(T, sz; kwargs...) + test_enzyme_svd_vals(T, sz; kwargs...) + test_enzyme_svd_trunc(T, sz; kwargs...) + end +end + +function test_enzyme_svd_compact( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "svd_compact reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(svd_compact, A) + USVᴴ, ΔUSVᴴ = ad_svd_compact_setup(A) + test_reverse(svd_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) + test_reverse(call_and_zero!, RT, (svd_compact!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) + end +end + +function test_enzyme_svd_full( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "svd_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(svd_full, A) + USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) + test_reverse(svd_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) + test_reverse(call_and_zero!, RT, (svd_full!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) + end +end + +function test_enzyme_svd_vals( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "svd_vals reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(svd_vals, A) + S, ΔS = ad_svd_vals_setup(A) + test_reverse(svd_vals, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm) + test_reverse(call_and_zero!, RT, (svd_vals!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm) + end +end + +function test_enzyme_svd_trunc( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + ) + return @testset "svd_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + A = instantiate_matrix(T, sz) + m, n = size(A) + minmn = min(m, n) + alg = MatrixAlgebraKit.select_algorithm(svd_compact, A) + @testset "truncrank($r)" for r in round.(Int, range(1, minmn + 4, 4)) + A = instantiate_matrix(T, sz) + trunc = truncrank(r) + truncalg = TruncatedAlgorithm(alg, trunc) + USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) + test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) + end + @testset "trunctol" begin + A = instantiate_matrix(T, sz) + S = svd_vals(A, alg) + trunc = trunctol(atol = S[1] / 2) + truncalg = TruncatedAlgorithm(alg, trunc) + USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) + test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) + end + end +end diff --git a/test/testsuite/mooncake/eig.jl b/test/testsuite/mooncake/eig.jl index 88042df2..5b0f1df0 100644 --- a/test/testsuite/mooncake/eig.jl +++ b/test/testsuite/mooncake/eig.jl @@ -50,7 +50,7 @@ function test_mooncake_eig_vals( return @testset "eig_vals" begin A = make_eig_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(eig_vals, A) - D = eig_vals(A, alg) + D = eig_vals(A) output_tangent = Mooncake.randn_tangent(rng, D) Mooncake.TestUtils.test_rule( diff --git a/test/testsuite/mooncake/eigh.jl b/test/testsuite/mooncake/eigh.jl index 62952aa8..5a1c74ea 100644 --- a/test/testsuite/mooncake/eigh.jl +++ b/test/testsuite/mooncake/eigh.jl @@ -50,8 +50,8 @@ function test_mooncake_eigh_vals( return @testset "eigh_vals" begin A = make_eigh_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(eigh_vals, A) - D = eigh_vals(A, alg) - output_tangent = Mooncake.randn_tangent(rng, D) + D, ΔD = ad_eigh_vals_setup(A) + output_tangent = Mooncake.primal_to_tangent!!(Mooncake.zero_tangent(D), ΔD) Mooncake.TestUtils.test_rule( rng, eigh_wrapper, eigh_vals, A, alg;