diff --git a/src/L2ODLL.jl b/src/L2ODLL.jl
index 93a2dd1..856acec 100644
--- a/src/L2ODLL.jl
+++ b/src/L2ODLL.jl
@@ -178,6 +178,9 @@ end
 function y_shape(cache::DLLCache)
     return length.(get_y_dual(cache.dual_model, cache.decomposition))
 end
+function y_shape(dual_model::JuMP.Model, decomposition::AbstractDecomposition)
+    return length.(get_y_dual(dual_model, decomposition))
+end
 
 """
     flatten_y(y)
diff --git a/src/projection.jl b/src/projection.jl
index f24ef99..f308885 100644
--- a/src/projection.jl
+++ b/src/projection.jl
@@ -32,3 +32,40 @@ function get_y_sets(dual_model, decomposition)
         for set in get_y_constraint(dual_model, decomposition)
     ]
 end
+
+function make_jump_proj_fn(decomposition::AbstractDecomposition, dual_model::JuMP.Model, optimizer; silent=true)
+    sets = get_y_sets(dual_model, decomposition)
+    shapes = y_shape(dual_model, decomposition)
+
+    proj_model = JuMP.Model(optimizer)
+
+    idxs = [(i, ji) for (i, j) in enumerate(shapes) for ji in 1:j]
+    JuMP.@variable(proj_model, y[idxs])
+
+    for (i, set) in enumerate(sets)
+        isnothing(set) && continue
+        y_vars = filter(ij -> first(ij) == i, idxs)
+        if length(y_vars) == 1
+            JuMP.@constraint(proj_model, y[only(y_vars)] ∈ set)
+        else
+            JuMP.@constraint(proj_model, y[y_vars] ∈ set)
+        end
+    end
+
+    silent && JuMP.set_silent(proj_model)
+    proj_model.ext[:🔒] = ReentrantLock()
+    # TODO: define frule/rrule using Moreau
+    return (y_prediction) -> begin
+        lock(proj_model.ext[:🔒])
+        try
+            JuMP.set_objective_function(proj_model, sum((y .- flatten_y(y_prediction)).^2))
+            JuMP.set_objective_sense(proj_model, MOI.MIN_SENSE)
+            JuMP.optimize!(proj_model)
+            JuMP.assert_is_solved_and_feasible(proj_model)
+
+            JuMP.value.(y)
+        finally
+            unlock(proj_model.ext[:🔒])
+        end
+    end
+end
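
For review context, here is a minimal usage sketch of the new projection factory. This is not part of the patch: `decomposition` and `dual_model` are assumed to come from the usual L2ODLL setup, `y_prediction` is a hypothetical network output shaped like `y_shape(dual_model, decomposition)`, and Clarabel is only one example of an optimizer that can handle a quadratic objective over the dual constraint sets.

```julia
using JuMP
import Clarabel  # example solver; any optimizer supporting a QP objective over the y-sets should work

# Build the projection closure once. It owns its own JuMP model and a ReentrantLock,
# so the same closure can be reused (including from multiple threads) during training.
proj = make_jump_proj_fn(decomposition, dual_model, Clarabel.Optimizer)

# `y_prediction` is assumed to be a vector of vectors matching y_shape(dual_model, decomposition).
# The closure minimizes ‖y - flatten_y(y_prediction)‖² subject to each dual block lying in its
# constraint set, and returns the projected y values.
y_projected = proj(y_prediction)
```

A note on the design as I read it: the constrained model is assembled once at construction time, and each call only swaps the quadratic objective before re-solving, so the per-call cost is a single small QP rather than a model rebuild.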