Skip to content

Commit b5b022b

Browse files
Merge pull request #13 from LucasMSpereira/mt/fix_zygote
Use Zygote-over-Zygote and format
2 parents c0179f3 + 422e561 commit b5b022b

File tree

5 files changed

+74
-64
lines changed

5 files changed

+74
-64
lines changed

.JuliaFormatter.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
remove_extra_newlines = true
2+
always_for_in = true
3+
conditional_to_if = true

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@ version = "0.1.0"
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
88
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
99
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
10-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1110
ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207"
11+
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
1212
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
NonconvexIpopt = "bf347577-a06d-49ad-a669-8c0e005493b8"
1414
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
15+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1516
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1617
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
1718
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -20,8 +21,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2021
ChainRulesCore = "1"
2122
Distributions = "0.25"
2223
DistributionsAD = "0.6"
23-
ForwardDiff = "0.10"
2424
ImplicitDifferentiation = "0.2"
25+
JuliaFormatter = "1"
2526
NonconvexIpopt = "0.4"
2627
Reexport = "1"
2728
UnPack = "1"

src/Uncertainty.jl

Lines changed: 65 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,95 +1,101 @@
11
module Uncertainty
22

3-
using ImplicitDifferentiation, Zygote, LinearAlgebra, ChainRulesCore, ForwardDiff
3+
using ImplicitDifferentiation, Zygote, LinearAlgebra, ChainRulesCore, SparseArrays
44
using UnPack, NonconvexIpopt, Statistics, Distributions, Reexport, DistributionsAD
55
@reexport using LinearAlgebra
66
@reexport using Statistics
77
export RandomFunction, FORM, RIA, MvNormal
88

99
struct RandomFunction{F,P,M}
10-
f::F
11-
p::P
12-
method::M
10+
f::F
11+
p::P
12+
method::M
1313
end
1414

1515
struct FORM{M}
16-
method::M
16+
method::M
1717
end
1818
struct RIA end
1919

2020
function get_forward(f, p, ::FORM{<:RIA})
21-
function forward(x)
22-
# gets an objective function of p
23-
obj = pc -> begin
24-
_p = pc[1:end-1]
25-
c = pc[end]
26-
return _p'*_p + c^2
27-
end
28-
# gets the constraint on p
29-
constr = pc -> begin
30-
_p = pc[1:end-1]
31-
c = pc[end]
32-
f(x, _p) .- c
33-
end
34-
# solve the RIA problem to find p0
35-
# should use and reuse the VecModel
36-
innerOptModel = Model(obj)
37-
n = size(p)[1]
38-
addvar!(innerOptModel, fill(-Inf, n + 1), fill(Inf, n + 1))
39-
add_eq_constraint!(innerOptModel, constr)
40-
result = optimize(innerOptModel, IpoptAlg(), [mean(p); 0.0], options = IpoptOptions(print_level = 0))
41-
return vcat(result.minimizer, result.problem.mult_g[1])
42-
end
43-
return forward
44-
end
45-
46-
# this is inefficient for multiple reasons
47-
# first is the use of ForwardDiff because Zygote over Zygote fails
48-
# second is that we are taking the jacobian wrt to both x and p because
49-
# of a limitation in Zygote over ForwardDiff
50-
function jac2(f, x, p)
51-
ForwardDiff.jacobian(
52-
xp -> f(xp[1:length(x)], xp[length(x)+1:end]),
53-
vcat(x, p),
54-
)[:,length(x)+1:end]
21+
function forward(x)
22+
# gets an objective function of p
23+
obj = pc -> begin
24+
_p = pc[1:end-1]
25+
c = pc[end]
26+
return _p' * _p + c^2
27+
end
28+
# gets the constraint on p
29+
constr = pc -> begin
30+
_p = pc[1:end-1]
31+
c = pc[end]
32+
f(x, _p) .- c
33+
end
34+
# solve the RIA problem to find p0
35+
# should use and reuse the VecModel
36+
innerOptModel = Model(obj)
37+
n = size(p)[1]
38+
addvar!(innerOptModel, fill(-Inf, n + 1), fill(Inf, n + 1))
39+
add_eq_constraint!(innerOptModel, constr)
40+
result = optimize(
41+
innerOptModel,
42+
IpoptAlg(),
43+
[mean(p); 0.0],
44+
options = IpoptOptions(print_level = 0),
45+
)
46+
return vcat(result.minimizer, result.problem.mult_g[1])
47+
end
48+
return forward
5549
end
5650

5751
function get_conditions(f, ::FORM{<:RIA})
58-
function kkt_conditions(x, pcmult)
59-
p = pcmult[1:end-2]
60-
c = pcmult[end-1]
61-
mult = pcmult[end]
62-
return vcat(2 * p + jac2(f, x, p)' * mult, 2c - mult, f(x, p) .- c)
63-
end
52+
function kkt_conditions(x, pcmult)
53+
p = pcmult[1:end-2]
54+
c = pcmult[end-1]
55+
mult = pcmult[end]
56+
return vcat(
57+
2 * p + Zygote.pullback(p -> f(x, p), p)[2](mult)[1],
58+
2c - mult,
59+
f(x, p) .- c,
60+
)
61+
end
6462
end
6563

6664
function get_implicit(f, p, method)
67-
forward = get_forward(f, p, method)
68-
kkt_conditions = get_conditions(f, method)
69-
return ImplicitFunction(forward, kkt_conditions)
65+
forward = get_forward(f, p, method)
66+
kkt_conditions = get_conditions(f, method)
67+
return ImplicitFunction(forward, kkt_conditions)
7068
end
7169

7270
function getp0(f, x, p, method::FORM{<:RIA})
73-
implicit_f = get_implicit(f, p, method)
74-
return implicit_f(x)[1:size(p)[1]]
71+
implicit_f = get_implicit(f, p, method)
72+
return implicit_f(x)[1:size(p)[1]]
7573
end
7674

7775
function RandomFunction(f, p; method = FORM(RIA()))
78-
return RandomFunction(f, p, method)
76+
return RandomFunction(f, p, method)
7977
end
8078

8179
_vec(x::Real) = [x]
8280
_vec(x) = x
8381

82+
function _jacobian(f, x)
83+
val, pb = Zygote.pullback(f, x)
84+
M = length(val)
85+
vecs = [Vector(sparsevec([i], [true], M)) for i in 1:M]
86+
Jt = reduce(hcat, first.(pb.(vecs)))
87+
return copy(Jt')
88+
end
89+
8490
function (f::RandomFunction)(x)
85-
mup = mean(f.p)
86-
covp = cov(f.p)
87-
p0 = getp0(f.f, x, f.p, f.method)
88-
dfdp0 = jac2(f.f, x, p0)
89-
fp0 = f.f(x, p0)
90-
muf = _vec(fp0) + dfdp0 * (mup - p0)
91-
covf = dfdp0 * covp * dfdp0'
92-
return MvNormal(muf, covf)
91+
mup = mean(f.p)
92+
covp = cov(f.p)
93+
p0 = getp0(f.f, x, f.p, f.method)
94+
dfdp0 = _jacobian(p -> f.f(x, p), p0)
95+
fp0 = f.f(x, p0)
96+
muf = _vec(fp0) + dfdp0 * (mup - p0)
97+
covf = dfdp0 * covp * dfdp0'
98+
return MvNormal(muf, covf)
9399
end
94100

95101
end

src/results.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using Nonconvex
22
Nonconvex.@load Ipopt
33

44
f(x) = sqrt(x[2])
5-
g(x, a, b) = (a*x[1] + b)^3 - x[2]
5+
g(x, a, b) = (a * x[1] + b)^3 - x[2]
66

77
model = Model(f)
88
addvar!(model, [0.0, 0.0], [10.0, 10.0])
@@ -16,4 +16,4 @@ propertynames(r)
1616
println(propertynames(r.problem))
1717
typeof(r.problem.intermediate)
1818
propertynames(r.problem.intermediate)
19-
r.problem
19+
r.problem

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using Uncertainty, Test, FiniteDifferences, Zygote
1212
d = rf(x)
1313
function obj(x)
1414
dist = rf(x)
15-
mean(dist)[1] + 2 * sqrt(cov(dist)[1,1])
15+
mean(dist)[1] + 2 * sqrt(cov(dist)[1, 1])
1616
end
1717
obj(x)
1818
g1 = FiniteDifferences.grad(central_fdm(5, 1), obj, x)[1]

0 commit comments

Comments
 (0)