@@ -540,196 +540,35 @@ function init_optimization!(
             JuMP.set_attribute(optim, "nlp_scaling_max_gradient", 10.0/C)
         end
     end
-    if JuMP.solver_name(optim) ≠ "Ipopt"
-        # everything with the splatting syntax:
-        J_func, ∇J_func!, g_funcs, ∇g_funcs!, geq_funcs, ∇geq_funcs! = get_optim_functions(
-            mpc, optim
-        )
-    else
-        # constraints with vector nonlinear oracle, objective function with splatting:
-        g_oracle, geq_oracle, J_func, ∇J_func! = get_nonlinops(mpc, optim)
-    end
-    @operator(optim, J, nZ̃, J_func, ∇J_func!)
-    @objective(optim, Min, J(Z̃var...))
-    if JuMP.solver_name(optim) ≠ "Ipopt"
-        init_nonlincon!(mpc, model, transcription, g_funcs, ∇g_funcs!, geq_funcs, ∇geq_funcs!)
-        set_nonlincon!(mpc, model, transcription, optim)
-    else
-        set_nonlincon_exp!(mpc, g_oracle, geq_oracle)
-    end
+    # constraints with vector nonlinear oracle, objective function with splatting:
+    g_oracle, geq_oracle, J_op = get_nonlinops(mpc, optim)
+    optim[:J_op] = J_op
+    @objective(optim, Min, J_op(Z̃var...))
+    set_nonlincon!(mpc, optim, g_oracle, geq_oracle)
     return nothing
 end
 
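For readers unfamiliar with the functional operator API used above, here is a minimal, self-contained sketch of the same pattern: register a splatted Julia function with `add_nonlinear_operator`, then call it on the decision variables inside `@objective`. The model and the `rosenbrock` function are illustrative stand-ins, not part of this package:

```julia
using JuMP, Ipopt

model = Model(Ipopt.Optimizer)
@variable(model, x[1:2])
# nonlinear operators receive scalar arguments, hence the splatted signature:
rosenbrock(x...) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
# functional equivalent of @operator; derivatives default to ForwardDiff here:
op = add_nonlinear_operator(model, 2, rosenbrock; name = :op_rosenbrock)
@objective(model, Min, op(x...))
optimize!(model)
```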
 """
-    get_optim_functions(
-        mpc::NonLinMPC, optim::JuMP.GenericModel
-    ) -> J_func, ∇J_func!, g_funcs, ∇g_funcs!, geq_funcs, ∇geq_funcs!
+    get_nonlinops(mpc::NonLinMPC, optim) -> g_oracle, geq_oracle, J_op
 
-Return the functions for the nonlinear optimization of `mpc` [`NonLinMPC`](@ref) controller.
+Return the operators for the nonlinear optimization of the `mpc` [`NonLinMPC`](@ref) controller.
 
-Return the nonlinear objective `J_func` function, and `∇J_func!`, to compute its gradient.
-Also return vectors with the nonlinear inequality constraint functions `g_funcs`, and
-`∇g_funcs!`, for the associated gradients. Lastly, also return vectors with the nonlinear
-equality constraint functions `geq_funcs` and gradients `∇geq_funcs!`.
-
-This method is really intricate and I'm not proud of it. That's because of 3 elements:
+Return `g_oracle` and `geq_oracle`, the inequality and equality [`VectorNonlinearOracle`](@extref MathOptInterface MathOptInterface.VectorNonlinearOracle)
+sets for the two respective constraint types. Note that `g_oracle` only includes the
+non-`Inf` inequality constraints, so it must be re-constructed whenever they change. Also
+return `J_op`, the [`NonlinearOperator`](@extref JuMP NonlinearOperator) for the objective
+function, based on the splatting syntax. This method is really intricate because of three
+elements:
 
 - These functions are used inside the nonlinear optimization, so they must be type-stable
   and as efficient as possible. All the function outputs and derivatives are cached and
   updated in-place if required to use the efficient [`value_and_jacobian!`](@extref DifferentiationInterface DifferentiationInterface.value_and_jacobian!).
-- The `JuMP` NLP syntax forces splatting for the decision variable, which implies use
-  of `Vararg{T,N}` (see the [performance tip](@extref Julia Be-aware-of-when-Julia-avoids-specializing))
+- The splatting syntax for objective functions implies the use of `Vararg{T,N}` (see the [performance tip](@extref Julia Be-aware-of-when-Julia-avoids-specializing))
   and memoization to avoid redundant computations. This is already complex, but it's even
-  worse knowing that most automatic differentiation tools do not support splatting.
+  worse given that the automatic differentiation tools do not support splatting.
 - The signature of gradient and hessian functions is not the same for univariate (`nZ̃ == 1`)
   and multivariate (`nZ̃ > 1`) operators in `JuMP`. Both must be defined.
-
-Inspired from: [User-defined operators with vector outputs](@extref JuMP User-defined-operators-with-vector-outputs)
 """
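To make the last two bullets concrete, the sketch below (illustrative only; `make_memoized_operator` and its placeholder objective are not part of this package) shows the memoized `Vararg` pattern together with the univariate and multivariate gradient signatures that `JuMP` expects:

```julia
# Illustrative sketch of the memoization and dual-signature pattern; the
# objective here is a placeholder, and `nx` plays the role of `nZ̃`.
function make_memoized_operator(nx)
    x_last = fill(NaN, nx)  # NaN ≠ NaN, so the first call always recomputes
    f_val  = Ref(0.0)
    ∇f_val = zeros(nx)
    function update!(xargs)
        if any(xargs .!= x_last)         # memoization: skip redundant work
            x_last .= xargs
            f_val[] = sum(abs2, x_last)  # placeholder objective f(x) = ‖x‖²
            ∇f_val .= 2 .* x_last        # ...and its exact gradient
        end
        return nothing
    end
    # value: the splatted call forces a `Vararg{T,N}` signature
    f(x::Vararg{T,N}) where {N,T<:Real} = (update!(x); f_val[])
    # gradient: univariate operators return the derivative, multivariate
    # operators fill a buffer in-place (see the JuMP.@operator documentation)
    ∇f! = if nx == 1
        x -> (update!(x); ∇f_val[begin])
    else
        function (g::AbstractVector{T}, x::Vararg{T,N}) where {N,T<:Real}
            update!(x)
            return g .= ∇f_val
        end
    end
    return f, ∇f!
end
```

The vector oracles introduced below sidestep this scalar splatting for the constraints.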
-function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT<:Real
-    # ----------- common cache for Jfunc, gfuncs and geqfuncs ----------------------------
-    model = mpc.estim.model
-    transcription = mpc.transcription
-    grad, jac = mpc.gradient, mpc.jacobian
-    nu, ny, nx̂, nϵ = model.nu, model.ny, mpc.estim.nx̂, mpc.nϵ
-    nk = get_nk(model, transcription)
-    Hp, Hc = mpc.Hp, mpc.Hc
-    ng, nc, neq = length(mpc.con.i_g), mpc.con.nc, mpc.con.neq
-    nZ̃, nU, nŶ, nX̂, nK = length(mpc.Z̃), Hp*nu, Hp*ny, Hp*nx̂, Hp*nk
-    nΔŨ, nUe, nŶe = nu*Hc + nϵ, nU + nu, nŶ + ny
-    strict = Val(true)
-    myNaN = convert(JNT, NaN)
-    J::Vector{JNT}                   = zeros(JNT, 1)
-    ΔŨ::Vector{JNT}                  = zeros(JNT, nΔŨ)
-    x̂0end::Vector{JNT}               = zeros(JNT, nx̂)
-    K0::Vector{JNT}                  = zeros(JNT, nK)
-    Ue::Vector{JNT}, Ŷe::Vector{JNT} = zeros(JNT, nUe), zeros(JNT, nŶe)
-    U0::Vector{JNT}, Ŷ0::Vector{JNT} = zeros(JNT, nU),  zeros(JNT, nŶ)
-    Û0::Vector{JNT}, X̂0::Vector{JNT} = zeros(JNT, nU),  zeros(JNT, nX̂)
-    gc::Vector{JNT}, g::Vector{JNT}  = zeros(JNT, nc),  zeros(JNT, ng)
-    geq::Vector{JNT}                 = zeros(JNT, neq)
-    # ---------------------- objective function -------------------------------------------
-    function Jfunc!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq)
-        update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
-        return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
-    end
-    Z̃_∇J = fill(myNaN, nZ̃)      # NaN to force update_predictions! at first call
-    ∇J_context = (
-        Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
-        Cache(Û0), Cache(K0), Cache(X̂0),
-        Cache(gc), Cache(g), Cache(geq),
-    )
-    ∇J_prep = prepare_gradient(Jfunc!, grad, Z̃_∇J, ∇J_context...; strict)
-    ∇J = Vector{JNT}(undef, nZ̃)
-    function update_objective!(J, ∇J, Z̃_∇J, Z̃arg)
-        if isdifferent(Z̃arg, Z̃_∇J)
-            Z̃_∇J .= Z̃arg
-            J[], _ = value_and_gradient!(Jfunc!, ∇J, ∇J_prep, grad, Z̃_∇J, ∇J_context...)
-        end
-    end
-    function J_func(Z̃arg::Vararg{T, N}) where {N, T<:Real}
-        update_objective!(J, ∇J, Z̃_∇J, Z̃arg)
-        return J[]::T
-    end
-    ∇J_func! = if nZ̃ == 1        # univariate syntax (see JuMP.@operator doc):
-        function (Z̃arg)
-            update_objective!(J, ∇J, Z̃_∇J, Z̃arg)
-            return ∇J[begin]
-        end
-    else                         # multivariate syntax (see JuMP.@operator doc):
-        function (∇Jarg::AbstractVector{T}, Z̃arg::Vararg{T, N}) where {N, T<:Real}
-            update_objective!(J, ∇J, Z̃_∇J, Z̃arg)
-            return ∇Jarg .= ∇J
-        end
-    end
-    # --------------------- inequality constraint functions -------------------------------
-    function gfunc!(g, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq)
-        update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
-        return g
-    end
-    Z̃_∇g = fill(myNaN, nZ̃)      # NaN to force update_predictions! at first call
-    ∇g_context = (
-        Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
-        Cache(Û0), Cache(K0), Cache(X̂0),
-        Cache(gc), Cache(geq),
-    )
-    # temporarily enable all the inequality constraints for sparsity detection:
-    mpc.con.i_g[1:end-nc] .= true
-    ∇g_prep = prepare_jacobian(gfunc!, g, jac, Z̃_∇g, ∇g_context...; strict)
-    mpc.con.i_g[1:end-nc] .= false
-    ∇g = init_diffmat(JNT, jac, ∇g_prep, nZ̃, ng)
-    function update_con!(g, ∇g, Z̃_∇g, Z̃arg)
-        if isdifferent(Z̃arg, Z̃_∇g)
-            Z̃_∇g .= Z̃arg
-            value_and_jacobian!(gfunc!, g, ∇g, ∇g_prep, jac, Z̃_∇g, ∇g_context...)
-        end
-    end
-    g_funcs = Vector{Function}(undef, ng)
-    for i in eachindex(g_funcs)
-        gfunc_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
-            update_con!(g, ∇g, Z̃_∇g, Z̃arg)
-            return g[i]::T
-        end
-        g_funcs[i] = gfunc_i
-    end
-    ∇g_funcs! = Vector{Function}(undef, ng)
-    for i in eachindex(∇g_funcs!)
-        ∇gfuncs_i! = if nZ̃ == 1     # univariate syntax (see JuMP.@operator doc):
-            function (Z̃arg::T) where T<:Real
-                update_con!(g, ∇g, Z̃_∇g, Z̃arg)
-                return ∇g[i, begin]
-            end
-        else                        # multivariate syntax (see JuMP.@operator doc):
-            function (∇g_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
-                update_con!(g, ∇g, Z̃_∇g, Z̃arg)
-                return ∇g_i .= @views ∇g[i, :]
-            end
-        end
-        ∇g_funcs![i] = ∇gfuncs_i!
-    end
-    # --------------------- equality constraint functions ---------------------------------
-    function geqfunc!(geq, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g)
-        update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
-        return geq
-    end
-    Z̃_∇geq = fill(myNaN, nZ̃)    # NaN to force update_predictions! at first call
-    ∇geq_context = (
-        Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
-        Cache(Û0), Cache(K0),   Cache(X̂0),
-        Cache(gc), Cache(g)
-    )
-    ∇geq_prep = prepare_jacobian(geqfunc!, geq, jac, Z̃_∇geq, ∇geq_context...; strict)
-    ∇geq = init_diffmat(JNT, jac, ∇geq_prep, nZ̃, neq)
-    function update_con_eq!(geq, ∇geq, Z̃_∇geq, Z̃arg)
-        if isdifferent(Z̃arg, Z̃_∇geq)
-            Z̃_∇geq .= Z̃arg
-            value_and_jacobian!(geqfunc!, geq, ∇geq, ∇geq_prep, jac, Z̃_∇geq, ∇geq_context...)
-        end
-    end
-    geq_funcs = Vector{Function}(undef, neq)
-    for i in eachindex(geq_funcs)
-        geqfunc_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
-            update_con_eq!(geq, ∇geq, Z̃_∇geq, Z̃arg)
-            return geq[i]::T
-        end
-        geq_funcs[i] = geqfunc_i
-    end
-    ∇geq_funcs! = Vector{Function}(undef, neq)
-    for i in eachindex(∇geq_funcs!)
-        # only multivariate syntax, univariate is impossible since nonlinear equality
-        # constraints imply MultipleShooting, thus input increment ΔU and state X̂0 in Z̃:
-        ∇geqfuncs_i! =
-            function (∇geq_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
-                update_con_eq!(geq, ∇geq, Z̃_∇geq, Z̃arg)
-                return ∇geq_i .= @views ∇geq[i, :]
-            end
-        ∇geq_funcs![i] = ∇geqfuncs_i!
-    end
-    return J_func, ∇J_func!, g_funcs, ∇g_funcs!, geq_funcs, ∇geq_funcs!
-end
-
-# TODO: move docstring of method above here and re-work it
-function get_nonlinops(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT<:Real
+function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<:Real
     # ----------- common cache for all functions ----------------------------------------
     model = mpc.estim.model
     transcription = mpc.transcription
@@ -785,7 +624,7 @@ function get_nonlinops(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT<:Real
     gi_min = fill(-myInf, ngi)
     gi_max = zeros(JNT,   ngi)
     ∇gi_structure = init_diffstructure(∇gi)
-    g_oracle = Ipopt._VectorNonlinearOracle(;
+    g_oracle = MOI.VectorNonlinearOracle(;
         dimension = nZ̃,
         l = gi_min,
         u = gi_max,
@@ -823,7 +662,7 @@ function get_nonlinops(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT<:Real
     end
     geq_min = geq_max = zeros(JNT, neq)
     ∇geq_structure = init_diffstructure(∇geq)
-    geq_oracle = Ipopt._VectorNonlinearOracle(;
+    geq_oracle = MOI.VectorNonlinearOracle(;
         dimension = nZ̃,
         l = geq_min,
         u = geq_max,
@@ -865,10 +704,10 @@ function get_nonlinops(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT<:Real
             return ∇J_arg .= ∇J
         end
     end
-    return g_oracle, geq_oracle, J_func, ∇J_func!
+    J_op = JuMP.add_nonlinear_operator(optim, nZ̃, J_func, ∇J_func!, name = :J_op)
+    return g_oracle, geq_oracle, J_op
 end
 
-
 """
     update_predictions!(
         ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq,
@@ -895,19 +734,20 @@ function update_predictions!(
 end
 
 """
-    set_nonlincon_exp!(mpc::NonLinMPC, g_oracle, geq_oracle)
+    set_nonlincon!(mpc::NonLinMPC, optim, g_oracle, geq_oracle)
 
 Set the nonlinear inequality and equality constraints for `NonLinMPC`, if any.
 """
-function set_nonlincon_exp!(
-    mpc::NonLinMPC, g_oracle, geq_oracle
-)
-    optim = mpc.optim
+function set_nonlincon!(
+    mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}, g_oracle, geq_oracle
+) where JNT<:Real
     Z̃var = optim[:Z̃var]
     nonlin_constraints = JuMP.all_constraints(
-        optim, JuMP.Vector{JuMP.VariableRef}, Ipopt._VectorNonlinearOracle
+        optim, JuMP.Vector{JuMP.VariableRef}, MOI.VectorNonlinearOracle{JNT}
    )
     map(con_ref -> JuMP.delete(optim, con_ref), nonlin_constraints)
+    optim[:g_oracle]   = g_oracle
+    optim[:geq_oracle] = geq_oracle
     any(mpc.con.i_g) && @constraint(optim, Z̃var in g_oracle)
     mpc.con.neq > 0  && @constraint(optim, Z̃var in geq_oracle)
     return nothing
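Since `g_oracle` only covers the currently finite bounds (per the docstring above), the function rebuilds the constraint from scratch rather than mutating it. In isolation, the underlying JuMP pattern is simply the following sketch, where `model`, `x`, and `new_oracle` are hypothetical:

```julia
# Hypothetical sketch of the delete-then-re-add swap performed above:
old = all_constraints(model, Vector{VariableRef}, MOI.VectorNonlinearOracle{Float64})
foreach(c -> delete(model, c), old)  # drop the stale oracle constraints
@constraint(model, x in new_oracle)  # attach the freshly built oracle
```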