Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions solution/solution.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ func ExtractValueOfVariable(s Solution, v symbolic.Variable) (float64, error) {

// FindValueOfExpression evaluates a symbolic expression using the values from a solution.
// It substitutes all variables in the expression with their values from the solution
// and returns the resulting scalar value.
func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error) {
// and returns the resulting symbolic expression (typically a constant).
func FindValueOfExpression(s Solution, expr symbolic.Expression) (symbolic.Expression, error) {
// Get all variables in the expression
vars := expr.Variables()

Expand All @@ -54,7 +54,7 @@ func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error
for _, v := range vars {
val, err := ExtractValueOfVariable(s, v)
if err != nil {
return 0.0, fmt.Errorf(
return nil, fmt.Errorf(
"failed to extract value for variable %v: %w",
v.ID,
err,
Expand All @@ -66,6 +66,31 @@ func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error
// Substitute all variables with their values
resultExpr := expr.SubstituteAccordingTo(subMap)

return resultExpr, nil
}

// GetOptimalObjectiveValue evaluates the objective function of an optimization problem
// at the solution point. It uses the FindValueOfExpression function to compute the value
// of the objective expression using the variable values from the solution.
func GetOptimalObjectiveValue(sol Solution) (float64, error) {
// Get the problem from the solution
prob := sol.GetProblem()
if prob == nil {
return 0.0, fmt.Errorf("solution does not have an associated problem")
}

// Get the objective expression from the problem
objectiveExpr := prob.Objective.Expression
if objectiveExpr == nil {
return 0.0, fmt.Errorf("problem does not have a defined objective")
}

// Use FindValueOfExpression to evaluate the objective at the solution point
resultExpr, err := FindValueOfExpression(sol, objectiveExpr)
if err != nil {
return 0.0, fmt.Errorf("failed to evaluate objective expression: %w", err)
}

// Type assert to K (constant) to extract the float64 value
resultK, ok := resultExpr.(symbolic.K)
if !ok {
Expand Down
195 changes: 191 additions & 4 deletions testing/solution/solution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ Description:
(This seems like it is highly representative of the Gurobi solver; is there a reason to make it this way?)
*/

// Helper function to convert a symbolic.Expression to float64
func exprToFloat64(t *testing.T, expr symbolic.Expression) float64 {
resultK, ok := expr.(symbolic.K)
if !ok {
t.Fatalf("Expected result to be a constant, got type %T", expr)
}
return float64(resultK)
}

func TestSolution_ToMessage1(t *testing.T) {
// Constants
tempSol := solution.DummySolution{
Expand Down Expand Up @@ -161,11 +170,13 @@ func TestSolution_FindValueOfExpression1(t *testing.T) {
expr := v1.Multiply(symbolic.K(2.0)).Plus(v2.Multiply(symbolic.K(3.0)))

// Algorithm
result, err := solution.FindValueOfExpression(&tempSol, expr)
resultExpr, err := solution.FindValueOfExpression(&tempSol, expr)
if err != nil {
t.Errorf("FindValueOfExpression returned an error: %v", err)
}

result := exprToFloat64(t, resultExpr)

expected := 13.0
if result != expected {
t.Errorf(
Expand Down Expand Up @@ -194,11 +205,13 @@ func TestSolution_FindValueOfExpression2(t *testing.T) {
expr := symbolic.K(42.0)

// Algorithm
result, err := solution.FindValueOfExpression(&tempSol, expr)
resultExpr, err := solution.FindValueOfExpression(&tempSol, expr)
if err != nil {
t.Errorf("FindValueOfExpression returned an error: %v", err)
}

result := exprToFloat64(t, resultExpr)

expected := 42.0
if result != expected {
t.Errorf(
Expand Down Expand Up @@ -231,11 +244,13 @@ func TestSolution_FindValueOfExpression3(t *testing.T) {
expr := v1.Plus(symbolic.K(10.0))

// Algorithm
result, err := solution.FindValueOfExpression(&tempSol, expr)
resultExpr, err := solution.FindValueOfExpression(&tempSol, expr)
if err != nil {
t.Errorf("FindValueOfExpression returned an error: %v", err)
}

result := exprToFloat64(t, resultExpr)

expected := 15.5
if result != expected {
t.Errorf(
Expand Down Expand Up @@ -304,11 +319,13 @@ func TestSolution_FindValueOfExpression5(t *testing.T) {
expr := v1.Plus(v2).Multiply(v3).Plus(symbolic.K(5.0))

// Algorithm
result, err := solution.FindValueOfExpression(&tempSol, expr)
resultExpr, err := solution.FindValueOfExpression(&tempSol, expr)
if err != nil {
t.Errorf("FindValueOfExpression returned an error: %v", err)
}

result := exprToFloat64(t, resultExpr)

expected := 14.0
if result != expected {
t.Errorf(
Expand Down Expand Up @@ -377,3 +394,173 @@ func TestSolution_GetProblem2(t *testing.T) {
t.Errorf("Expected GetProblem to return nil when no problem is set")
}
}

/*
TestSolution_GetOptimalObjectiveValue1
Description:

This function tests whether we can compute the objective value at the solution point
for a simple linear objective.
*/
func TestSolution_GetOptimalObjectiveValue1(t *testing.T) {
// Constants
p := problem.NewProblem("TestProblem")
v1 := p.AddVariable()
v2 := p.AddVariable()

// Set objective: 2*v1 + 3*v2
objectiveExpr := v1.Multiply(symbolic.K(2.0)).Plus(v2.Multiply(symbolic.K(3.0)))
err := p.SetObjective(objectiveExpr, problem.SenseMinimize)
if err != nil {
t.Errorf("Failed to set objective: %v", err)
}

// Create solution with v1=2.0, v2=3.0
// Expected objective value: 2*2.0 + 3*3.0 = 4.0 + 9.0 = 13.0
tempSol := solution.DummySolution{
Values: map[uint64]float64{
v1.ID: 2.0,
v2.ID: 3.0,
},
Objective: 13.0,
Status: solution_status.OPTIMAL,
Problem: p,
}

// Algorithm
objectiveValue, err := solution.GetOptimalObjectiveValue(&tempSol)
if err != nil {
t.Errorf("GetOptimalObjectiveValue returned an error: %v", err)
}

expected := 13.0
if objectiveValue != expected {
t.Errorf(
"Expected objective value to be %v; received %v",
expected,
objectiveValue,
)
}
}

/*
TestSolution_GetOptimalObjectiveValue2
Description:

This function tests whether GetOptimalObjectiveValue returns an error
when the solution has no associated problem.
*/
func TestSolution_GetOptimalObjectiveValue2(t *testing.T) {
// Constants
v1 := symbolic.NewVariable()

tempSol := solution.DummySolution{
Values: map[uint64]float64{
v1.ID: 2.0,
},
Objective: 2.3,
Status: solution_status.OPTIMAL,
Problem: nil,
}

// Algorithm
_, err := solution.GetOptimalObjectiveValue(&tempSol)
if err == nil {
t.Errorf("Expected GetOptimalObjectiveValue to return an error for nil problem, but got nil")
}
}

/*
TestSolution_GetOptimalObjectiveValue3
Description:

This function tests whether we can compute the objective value
for a constant objective function.
*/
func TestSolution_GetOptimalObjectiveValue3(t *testing.T) {
// Constants
p := problem.NewProblem("TestProblem")
v1 := p.AddVariable()

// Set constant objective: 42.0
objectiveExpr := symbolic.K(42.0)
err := p.SetObjective(objectiveExpr, problem.SenseMaximize)
if err != nil {
t.Errorf("Failed to set objective: %v", err)
}

// Create solution
tempSol := solution.DummySolution{
Values: map[uint64]float64{
v1.ID: 1.0,
},
Objective: 42.0,
Status: solution_status.OPTIMAL,
Problem: p,
}

// Algorithm
objectiveValue, err := solution.GetOptimalObjectiveValue(&tempSol)
if err != nil {
t.Errorf("GetOptimalObjectiveValue returned an error: %v", err)
}

expected := 42.0
if objectiveValue != expected {
t.Errorf(
"Expected objective value to be %v; received %v",
expected,
objectiveValue,
)
}
}

/*
TestSolution_GetOptimalObjectiveValue4
Description:

This function tests whether we can compute the objective value
for a more complex objective with multiple variables and operations.
*/
func TestSolution_GetOptimalObjectiveValue4(t *testing.T) {
// Constants
p := problem.NewProblem("TestProblem")
v1 := p.AddVariable()
v2 := p.AddVariable()
v3 := p.AddVariable()

// Set objective: (v1 + v2) * v3 + 5
objectiveExpr := v1.Plus(v2).Multiply(v3).Plus(symbolic.K(5.0))
err := p.SetObjective(objectiveExpr, problem.SenseMinimize)
if err != nil {
t.Errorf("Failed to set objective: %v", err)
}

// Create solution with v1=1.0, v2=2.0, v3=3.0
// Expected objective: (1.0 + 2.0) * 3.0 + 5 = 3.0 * 3.0 + 5 = 14.0
tempSol := solution.DummySolution{
Values: map[uint64]float64{
v1.ID: 1.0,
v2.ID: 2.0,
v3.ID: 3.0,
},
Objective: 14.0,
Status: solution_status.OPTIMAL,
Problem: p,
}

// Algorithm
objectiveValue, err := solution.GetOptimalObjectiveValue(&tempSol)
if err != nil {
t.Errorf("GetOptimalObjectiveValue returned an error: %v", err)
}

expected := 14.0
if objectiveValue != expected {
t.Errorf(
"Expected objective value to be %v; received %v",
expected,
objectiveValue,
)
}
}