diff --git a/solution/solution.go b/solution/solution.go index 692d227..bae1285 100644 --- a/solution/solution.go +++ b/solution/solution.go @@ -44,8 +44,8 @@ func ExtractValueOfVariable(s Solution, v symbolic.Variable) (float64, error) { // FindValueOfExpression evaluates a symbolic expression using the values from a solution. // It substitutes all variables in the expression with their values from the solution -// and returns the resulting scalar value. -func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error) { +// and returns the resulting symbolic expression (typically a constant). +func FindValueOfExpression(s Solution, expr symbolic.Expression) (symbolic.Expression, error) { // Get all variables in the expression vars := expr.Variables() @@ -54,7 +54,7 @@ func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error for _, v := range vars { val, err := ExtractValueOfVariable(s, v) if err != nil { - return 0.0, fmt.Errorf( + return nil, fmt.Errorf( "failed to extract value for variable %v: %w", v.ID, err, @@ -66,6 +66,31 @@ func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error // Substitute all variables with their values resultExpr := expr.SubstituteAccordingTo(subMap) + return resultExpr, nil +} + +// GetOptimalObjectiveValue evaluates the objective function of an optimization problem +// at the solution point. It uses the FindValueOfExpression function to compute the value +// of the objective expression using the variable values from the solution. +func GetOptimalObjectiveValue(sol Solution) (float64, error) { + // Get the problem from the solution + prob := sol.GetProblem() + if prob == nil { + return 0.0, fmt.Errorf("solution does not have an associated problem") + } + + // Get the objective expression from the problem + objectiveExpr := prob.Objective.Expression + if objectiveExpr == nil { + return 0.0, fmt.Errorf("problem does not have a defined objective") + } + + // Use FindValueOfExpression to evaluate the objective at the solution point + resultExpr, err := FindValueOfExpression(sol, objectiveExpr) + if err != nil { + return 0.0, fmt.Errorf("failed to evaluate objective expression: %w", err) + } + // Type assert to K (constant) to extract the float64 value resultK, ok := resultExpr.(symbolic.K) if !ok { diff --git a/testing/solution/solution_test.go b/testing/solution/solution_test.go index 4ad9211..c3a4f71 100644 --- a/testing/solution/solution_test.go +++ b/testing/solution/solution_test.go @@ -17,6 +17,15 @@ Description: (This seems like it is highly representative of the Gurobi solver; is there a reason to make it this way?) */ +// Helper function to convert a symbolic.Expression to float64 +func exprToFloat64(t *testing.T, expr symbolic.Expression) float64 { + resultK, ok := expr.(symbolic.K) + if !ok { + t.Fatalf("Expected result to be a constant, got type %T", expr) + } + return float64(resultK) +} + func TestSolution_ToMessage1(t *testing.T) { // Constants tempSol := solution.DummySolution{ @@ -161,11 +170,13 @@ func TestSolution_FindValueOfExpression1(t *testing.T) { expr := v1.Multiply(symbolic.K(2.0)).Plus(v2.Multiply(symbolic.K(3.0))) // Algorithm - result, err := solution.FindValueOfExpression(&tempSol, expr) + resultExpr, err := solution.FindValueOfExpression(&tempSol, expr) if err != nil { t.Errorf("FindValueOfExpression returned an error: %v", err) } + result := exprToFloat64(t, resultExpr) + expected := 13.0 if result != expected { t.Errorf( @@ -194,11 +205,13 @@ func TestSolution_FindValueOfExpression2(t *testing.T) { expr := symbolic.K(42.0) // Algorithm - result, err := solution.FindValueOfExpression(&tempSol, expr) + resultExpr, err := solution.FindValueOfExpression(&tempSol, expr) if err != nil { t.Errorf("FindValueOfExpression returned an error: %v", err) } + result := exprToFloat64(t, resultExpr) + expected := 42.0 if result != expected { t.Errorf( @@ -231,11 +244,13 @@ func TestSolution_FindValueOfExpression3(t *testing.T) { expr := v1.Plus(symbolic.K(10.0)) // Algorithm - result, err := solution.FindValueOfExpression(&tempSol, expr) + resultExpr, err := solution.FindValueOfExpression(&tempSol, expr) if err != nil { t.Errorf("FindValueOfExpression returned an error: %v", err) } + result := exprToFloat64(t, resultExpr) + expected := 15.5 if result != expected { t.Errorf( @@ -304,11 +319,13 @@ func TestSolution_FindValueOfExpression5(t *testing.T) { expr := v1.Plus(v2).Multiply(v3).Plus(symbolic.K(5.0)) // Algorithm - result, err := solution.FindValueOfExpression(&tempSol, expr) + resultExpr, err := solution.FindValueOfExpression(&tempSol, expr) if err != nil { t.Errorf("FindValueOfExpression returned an error: %v", err) } + result := exprToFloat64(t, resultExpr) + expected := 14.0 if result != expected { t.Errorf( @@ -377,3 +394,173 @@ func TestSolution_GetProblem2(t *testing.T) { t.Errorf("Expected GetProblem to return nil when no problem is set") } } + +/* +TestSolution_GetOptimalObjectiveValue1 +Description: + + This function tests whether we can compute the objective value at the solution point + for a simple linear objective. +*/ +func TestSolution_GetOptimalObjectiveValue1(t *testing.T) { + // Constants + p := problem.NewProblem("TestProblem") + v1 := p.AddVariable() + v2 := p.AddVariable() + + // Set objective: 2*v1 + 3*v2 + objectiveExpr := v1.Multiply(symbolic.K(2.0)).Plus(v2.Multiply(symbolic.K(3.0))) + err := p.SetObjective(objectiveExpr, problem.SenseMinimize) + if err != nil { + t.Errorf("Failed to set objective: %v", err) + } + + // Create solution with v1=2.0, v2=3.0 + // Expected objective value: 2*2.0 + 3*3.0 = 4.0 + 9.0 = 13.0 + tempSol := solution.DummySolution{ + Values: map[uint64]float64{ + v1.ID: 2.0, + v2.ID: 3.0, + }, + Objective: 13.0, + Status: solution_status.OPTIMAL, + Problem: p, + } + + // Algorithm + objectiveValue, err := solution.GetOptimalObjectiveValue(&tempSol) + if err != nil { + t.Errorf("GetOptimalObjectiveValue returned an error: %v", err) + } + + expected := 13.0 + if objectiveValue != expected { + t.Errorf( + "Expected objective value to be %v; received %v", + expected, + objectiveValue, + ) + } +} + +/* +TestSolution_GetOptimalObjectiveValue2 +Description: + + This function tests whether GetOptimalObjectiveValue returns an error + when the solution has no associated problem. +*/ +func TestSolution_GetOptimalObjectiveValue2(t *testing.T) { + // Constants + v1 := symbolic.NewVariable() + + tempSol := solution.DummySolution{ + Values: map[uint64]float64{ + v1.ID: 2.0, + }, + Objective: 2.3, + Status: solution_status.OPTIMAL, + Problem: nil, + } + + // Algorithm + _, err := solution.GetOptimalObjectiveValue(&tempSol) + if err == nil { + t.Errorf("Expected GetOptimalObjectiveValue to return an error for nil problem, but got nil") + } +} + +/* +TestSolution_GetOptimalObjectiveValue3 +Description: + + This function tests whether we can compute the objective value + for a constant objective function. +*/ +func TestSolution_GetOptimalObjectiveValue3(t *testing.T) { + // Constants + p := problem.NewProblem("TestProblem") + v1 := p.AddVariable() + + // Set constant objective: 42.0 + objectiveExpr := symbolic.K(42.0) + err := p.SetObjective(objectiveExpr, problem.SenseMaximize) + if err != nil { + t.Errorf("Failed to set objective: %v", err) + } + + // Create solution + tempSol := solution.DummySolution{ + Values: map[uint64]float64{ + v1.ID: 1.0, + }, + Objective: 42.0, + Status: solution_status.OPTIMAL, + Problem: p, + } + + // Algorithm + objectiveValue, err := solution.GetOptimalObjectiveValue(&tempSol) + if err != nil { + t.Errorf("GetOptimalObjectiveValue returned an error: %v", err) + } + + expected := 42.0 + if objectiveValue != expected { + t.Errorf( + "Expected objective value to be %v; received %v", + expected, + objectiveValue, + ) + } +} + +/* +TestSolution_GetOptimalObjectiveValue4 +Description: + + This function tests whether we can compute the objective value + for a more complex objective with multiple variables and operations. +*/ +func TestSolution_GetOptimalObjectiveValue4(t *testing.T) { + // Constants + p := problem.NewProblem("TestProblem") + v1 := p.AddVariable() + v2 := p.AddVariable() + v3 := p.AddVariable() + + // Set objective: (v1 + v2) * v3 + 5 + objectiveExpr := v1.Plus(v2).Multiply(v3).Plus(symbolic.K(5.0)) + err := p.SetObjective(objectiveExpr, problem.SenseMinimize) + if err != nil { + t.Errorf("Failed to set objective: %v", err) + } + + // Create solution with v1=1.0, v2=2.0, v3=3.0 + // Expected objective: (1.0 + 2.0) * 3.0 + 5 = 3.0 * 3.0 + 5 = 14.0 + tempSol := solution.DummySolution{ + Values: map[uint64]float64{ + v1.ID: 1.0, + v2.ID: 2.0, + v3.ID: 3.0, + }, + Objective: 14.0, + Status: solution_status.OPTIMAL, + Problem: p, + } + + // Algorithm + objectiveValue, err := solution.GetOptimalObjectiveValue(&tempSol) + if err != nil { + t.Errorf("GetOptimalObjectiveValue returned an error: %v", err) + } + + expected := 14.0 + if objectiveValue != expected { + t.Errorf( + "Expected objective value to be %v; received %v", + expected, + objectiveValue, + ) + } +}