diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..15f09c8 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,6 @@ +[deps] +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +MAT = "23992714-dd62-5051-b70f-ba57cb901cac" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/burgerset.mat b/test/burgerset.mat new file mode 100644 index 0000000..429ebce Binary files /dev/null and b/test/burgerset.mat differ diff --git a/test/deeponet.jl b/test/deeponet.jl index 0f3a482..c8fb823 100644 --- a/test/deeponet.jl +++ b/test/deeponet.jl @@ -1,4 +1,4 @@ -using Test, Random, Flux +using Test, Random, Flux, MAT @testset "DeepONet" begin @testset "dimensions" begin @@ -14,4 +14,50 @@ using Test, Random, Flux # Accept only Int as architecture parameters @test_throws MethodError DeepONet((32.5,64,72), (24,48,72), σ, tanh) @test_throws MethodError DeepONet((32,64,72), (24.1,48,72)) -end \ No newline at end of file +end + +#Just the first 16 datapoints from the Burgers' equation dataset +a = [0.83541104, 0.83479851, 0.83404712, 0.83315711, 0.83212979, 0.83096755, 0.82967374, 0.82825263, 0.82670928, 0.82504949, 0.82327962, 0.82140651, 0.81943734, 0.81737952, 0.8152405, 0.81302771] +sensors = collect(range(0, 1, length=16))' + +model = DeepONet((16, 22, 30), (1, 16, 24, 30), σ, tanh; init_branch=Flux.glorot_normal, bias_trunk=false) + +model(a,sensors) + +#forward pass +@test size(model(a, sensors)) == (1, 16) + +mgrad = Flux.Zygote.gradient((x,p)->sum(model(x,p)),a,sensors) + +#gradients +@test !iszero(Flux.Zygote.gradient((x,p)->sum(model(x,p)),a,sensors)[1]) +@test !iszero(Flux.Zygote.gradient((x,p)->sum(model(x,p)),a,sensors)[2]) + +#training +vars = matread("burgerset.mat") + +xtrain = vars["a"][1:280, :]' +xval = vars["a"][end-19:end, :]' + +ytrain = vars["u"][1:280, :] +yval = vars["u"][end-19:end, :] + +grid = collect(range(0, 1, length=1024))' +model = DeepONet((1024,1024,1024),(1,1024,1024),gelu,gelu) + +learning_rate = 0.001 +opt = ADAM(learning_rate) + +parameters = params(model) + +loss(xtrain,ytrain,sensor) = Flux.Losses.mse(model(xtrain,sensor),ytrain) + +evalcb() = @show(loss(xval,yval,grid)) + +Flux.@epochs 400 Flux.train!(loss, parameters, [(xtrain,ytrain,grid)], opt, cb = evalcb) + +ỹ = model(xval, grid) + +diffvec = vec(abs.((yval .- ỹ))) +mean_diff = sum(diffvec)/length(diffvec) +@test mean_diff < 0.4