Skip to content

Commit

Permalink
Add tests for BLAS.dot, BLAS.dotc, and BLAS.dotu (#842)
Browse files Browse the repository at this point in the history
* Add tests for BLAS rules

* Make checks approximate

* Fix approx check for tuples

* Apply suggestion

* Add EnzymeTestUtils as test dependency

* Refactor to use EnzymeTestUtils

Also removes checks for consistency with 2-arg dot and pointer tests, since these should be covered by the base case.

* Make sure dev version of EnzymeTestUtils used in testing

* Move blas.jl to test dir

* Update path

* Try to fix deving of EnzymeTestUtils

* Dev EnzymeTestUtils while running tests

* Remove EnzymeTestUtils as test dep

* Load Pkg

* Remove workaround in CI

---------

Co-authored-by: William S. Moses <gh@wsmoses.com>
  • Loading branch information
sethaxen and wsmoses authored Sep 10, 2023
1 parent 04ced8a commit 744ffce
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
50 changes: 50 additions & 0 deletions test/blas.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
using Enzyme
using EnzymeTestUtils
using LinearAlgebra
using Test

@testset "BLAS rules" begin
RTs = (Float32, Float64)
RCs = (ComplexF32, ComplexF64)
n = 10

@testset for fun in (BLAS.dot, BLAS.dotu, BLAS.dotc)
@testset "forward" begin
@testset for Tret in (
Const,
Duplicated,
DuplicatedNoNeed,
BatchDuplicated,
BatchDuplicatedNoNeed,
),
Tx in (Const, Duplicated, BatchDuplicated),
Ty in (Const, Duplicated, BatchDuplicated),
T in (fun == BLAS.dot ? RTs : RCs),
(sz, inc) in ((10, 1), ((2, 20), -2))

are_activities_compatible(Tret, Tx, Ty) || continue

x = randn(T, sz)
y = randn(T, sz)
atol = rtol = sqrt(eps(real(T)))
test_forward(fun, Tret, n, (x, Tx), inc, (y, Ty), inc; atol, rtol)
end
end

@testset "reverse" begin
@testset for Tret in (Const, Active),
Tx in (Const, Duplicated, BatchDuplicated),
Ty in (Const, Duplicated, BatchDuplicated),
T in (fun == BLAS.dot ? RTs : RCs),
(sz, inc) in ((10, 1), ((2, 20), -2))

are_activities_compatible(Tret, Tx, Ty) || continue

x = randn(T, sz)
y = randn(T, sz)
atol = rtol = sqrt(eps(real(T)))
test_reverse(fun, Tret, n, (x, Tx), inc, (y, Ty), inc; atol, rtol)
end
end
end
end
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ if isfile(preferences_file) && !isfile(test_preferences_file)
end
end

# work around https://github.com/JuliaLang/Pkg.jl/issues/1585
using Pkg
Pkg.develop(PackageSpec(; path=joinpath(dirname(@__DIR__), "lib", "EnzymeTestUtils")))

using GPUCompiler
using Enzyme
using Test
Expand Down Expand Up @@ -78,6 +82,7 @@ include("typetree.jl")
include("ruleinvalidation.jl")
end
end
include("blas.jl")

f0(x) = 1.0 + x
function vrec(start, x)
Expand Down

0 comments on commit 744ffce

Please sign in to comment.