| 
1 | 1 | module Uncertainty  | 
2 | 2 | 
 
  | 
3 |  | -using ImplicitDifferentiation, Zygote, LinearAlgebra, ChainRulesCore, ForwardDiff  | 
 | 3 | +using ImplicitDifferentiation, Zygote, LinearAlgebra, ChainRulesCore, SparseArrays  | 
4 | 4 | using UnPack, NonconvexIpopt, Statistics, Distributions, Reexport, DistributionsAD  | 
5 | 5 | @reexport using LinearAlgebra  | 
6 | 6 | @reexport using Statistics  | 
7 | 7 | export RandomFunction, FORM, RIA, MvNormal  | 
8 | 8 | 
 
  | 
9 | 9 | struct RandomFunction{F,P,M}  | 
10 |  | -  f::F  | 
11 |  | -  p::P  | 
12 |  | -  method::M  | 
 | 10 | +    f::F  | 
 | 11 | +    p::P  | 
 | 12 | +    method::M  | 
13 | 13 | end  | 
14 | 14 | 
 
  | 
15 | 15 | struct FORM{M}  | 
16 |  | -	method::M  | 
 | 16 | +    method::M  | 
17 | 17 | end  | 
18 | 18 | struct RIA end  | 
19 | 19 | 
 
  | 
20 | 20 | function get_forward(f, p, ::FORM{<:RIA})  | 
21 |  | -	function forward(x)  | 
22 |  | -		# gets an objective function of p  | 
23 |  | -		obj = pc -> begin  | 
24 |  | -		   _p = pc[1:end-1]  | 
25 |  | -		   c = pc[end]  | 
26 |  | -		   return _p'*_p + c^2  | 
27 |  | -	   end  | 
28 |  | -	   # gets the constraint on p  | 
29 |  | -	   constr = pc -> begin  | 
30 |  | -		   _p = pc[1:end-1]  | 
31 |  | -		   c = pc[end]  | 
32 |  | -		   f(x, _p) .- c  | 
33 |  | -	   end  | 
34 |  | -	   # solve the RIA problem to find p0  | 
35 |  | -	   # should use and reuse the VecModel  | 
36 |  | -	   innerOptModel = Model(obj)  | 
37 |  | -	   n = size(p)[1]  | 
38 |  | -	   addvar!(innerOptModel, fill(-Inf, n + 1), fill(Inf, n + 1))  | 
39 |  | -	   add_eq_constraint!(innerOptModel, constr)  | 
40 |  | -	   result = optimize(innerOptModel, IpoptAlg(), [mean(p); 0.0], options = IpoptOptions(print_level = 0))  | 
41 |  | -	   return vcat(result.minimizer, result.problem.mult_g[1])  | 
42 |  | -   end  | 
43 |  | -   return forward  | 
44 |  | -end  | 
45 |  | - | 
46 |  | -# this is inefficient for multiple reasons  | 
47 |  | -# first is the use of ForwardDiff because Zygote over Zygote fails  | 
48 |  | -# second is that we are taking the jacobian wrt to both x and p because  | 
49 |  | -# of a limitation in Zygote over ForwardDiff  | 
50 |  | -function jac2(f, x, p)  | 
51 |  | -	ForwardDiff.jacobian(  | 
52 |  | -		xp -> f(xp[1:length(x)], xp[length(x)+1:end]),  | 
53 |  | -		vcat(x, p),  | 
54 |  | -	)[:,length(x)+1:end]  | 
 | 21 | +    function forward(x)  | 
 | 22 | +        # gets an objective function of p  | 
 | 23 | +        obj = pc -> begin  | 
 | 24 | +            _p = pc[1:end-1]  | 
 | 25 | +            c = pc[end]  | 
 | 26 | +            return _p' * _p + c^2  | 
 | 27 | +        end  | 
 | 28 | +        # gets the constraint on p  | 
 | 29 | +        constr = pc -> begin  | 
 | 30 | +            _p = pc[1:end-1]  | 
 | 31 | +            c = pc[end]  | 
 | 32 | +            f(x, _p) .- c  | 
 | 33 | +        end  | 
 | 34 | +        # solve the RIA problem to find p0  | 
 | 35 | +        # should use and reuse the VecModel  | 
 | 36 | +        innerOptModel = Model(obj)  | 
 | 37 | +        n = size(p)[1]  | 
 | 38 | +        addvar!(innerOptModel, fill(-Inf, n + 1), fill(Inf, n + 1))  | 
 | 39 | +        add_eq_constraint!(innerOptModel, constr)  | 
 | 40 | +        result = optimize(  | 
 | 41 | +            innerOptModel,  | 
 | 42 | +            IpoptAlg(),  | 
 | 43 | +            [mean(p); 0.0],  | 
 | 44 | +            options = IpoptOptions(print_level = 0),  | 
 | 45 | +        )  | 
 | 46 | +        return vcat(result.minimizer, result.problem.mult_g[1])  | 
 | 47 | +    end  | 
 | 48 | +    return forward  | 
55 | 49 | end  | 
56 | 50 | 
 
  | 
57 | 51 | function get_conditions(f, ::FORM{<:RIA})  | 
58 |  | -	function kkt_conditions(x, pcmult)  | 
59 |  | -		p = pcmult[1:end-2]  | 
60 |  | -		c = pcmult[end-1]  | 
61 |  | -		mult = pcmult[end]  | 
62 |  | -		return vcat(2 * p + jac2(f, x, p)' * mult, 2c - mult, f(x, p) .- c)  | 
63 |  | -	end  | 
 | 52 | +    function kkt_conditions(x, pcmult)  | 
 | 53 | +        p = pcmult[1:end-2]  | 
 | 54 | +        c = pcmult[end-1]  | 
 | 55 | +        mult = pcmult[end]  | 
 | 56 | +        return vcat(  | 
 | 57 | +            2 * p + Zygote.pullback(p -> f(x, p), p)[2](mult)[1],  | 
 | 58 | +            2c - mult,  | 
 | 59 | +            f(x, p) .- c,  | 
 | 60 | +        )  | 
 | 61 | +    end  | 
64 | 62 | end  | 
65 | 63 | 
 
  | 
66 | 64 | function get_implicit(f, p, method)  | 
67 |  | -	forward = get_forward(f, p, method)  | 
68 |  | -	kkt_conditions = get_conditions(f, method)  | 
69 |  | -	return ImplicitFunction(forward, kkt_conditions)  | 
 | 65 | +    forward = get_forward(f, p, method)  | 
 | 66 | +    kkt_conditions = get_conditions(f, method)  | 
 | 67 | +    return ImplicitFunction(forward, kkt_conditions)  | 
70 | 68 | end  | 
71 | 69 | 
 
  | 
72 | 70 | function getp0(f, x, p, method::FORM{<:RIA})  | 
73 |  | -	implicit_f = get_implicit(f, p, method)  | 
74 |  | -	return implicit_f(x)[1:size(p)[1]]  | 
 | 71 | +    implicit_f = get_implicit(f, p, method)  | 
 | 72 | +    return implicit_f(x)[1:size(p)[1]]  | 
75 | 73 | end  | 
76 | 74 | 
 
  | 
77 | 75 | function RandomFunction(f, p; method = FORM(RIA()))  | 
78 |  | -	return RandomFunction(f, p, method)  | 
 | 76 | +    return RandomFunction(f, p, method)  | 
79 | 77 | end  | 
80 | 78 | 
 
  | 
81 | 79 | _vec(x::Real) = [x]  | 
82 | 80 | _vec(x) = x  | 
83 | 81 | 
 
  | 
 | 82 | +function _jacobian(f, x)  | 
 | 83 | +    val, pb = Zygote.pullback(f, x)  | 
 | 84 | +    M = length(val)  | 
 | 85 | +    vecs = [Vector(sparsevec([i], [true], M)) for i in 1:M]  | 
 | 86 | +    Jt = reduce(hcat, first.(pb.(vecs)))  | 
 | 87 | +    return copy(Jt')  | 
 | 88 | +end  | 
 | 89 | + | 
84 | 90 | function (f::RandomFunction)(x)  | 
85 |  | -	mup = mean(f.p)  | 
86 |  | -	covp = cov(f.p)  | 
87 |  | -	p0 = getp0(f.f, x, f.p, f.method)  | 
88 |  | -	dfdp0 = jac2(f.f, x, p0)  | 
89 |  | -	fp0 = f.f(x, p0)  | 
90 |  | -	muf = _vec(fp0) + dfdp0 * (mup - p0)  | 
91 |  | -	covf = dfdp0 * covp * dfdp0'  | 
92 |  | -	return MvNormal(muf, covf)  | 
 | 91 | +    mup = mean(f.p)  | 
 | 92 | +    covp = cov(f.p)  | 
 | 93 | +    p0 = getp0(f.f, x, f.p, f.method)  | 
 | 94 | +    dfdp0 = _jacobian(p -> f.f(x, p), p0)  | 
 | 95 | +    fp0 = f.f(x, p0)  | 
 | 96 | +    muf = _vec(fp0) + dfdp0 * (mup - p0)  | 
 | 97 | +    covf = dfdp0 * covp * dfdp0'  | 
 | 98 | +    return MvNormal(muf, covf)  | 
93 | 99 | end  | 
94 | 100 | 
 
  | 
95 | 101 | end  | 
0 commit comments