diff --git a/doc/release_notes.rst b/doc/release_notes.rst index 730d3e04..1583a4fd 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -3,6 +3,7 @@ Release Notes .. Upcoming Version +* Allow constant values in objective cost function * Add support for SOS1 and SOS2 (Special Ordered Sets) constraints via ``Model.add_sos_constraints()`` and ``Model.remove_sos_constraints()`` * Add simplify method to LinearExpression to combine duplicate terms * Add convenience function to create LinearExpression from constant diff --git a/linopy/expressions.py b/linopy/expressions.py index 10e243de..008f5b82 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -1097,6 +1097,9 @@ def empty(self) -> EmptyDeprecationWrapper: """ return EmptyDeprecationWrapper(not self.size) + def drop_constant(self: GenericExpression) -> GenericExpression: + return self - self.const # type: ignore + def densify_terms(self: GenericExpression) -> GenericExpression: """ Move all non-zero term entries to the front and cut off all-zero diff --git a/linopy/model.py b/linopy/model.py index 81c069ab..fec3855f 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -10,9 +10,11 @@ import os import re from collections.abc import Callable, Mapping, Sequence +from functools import wraps from pathlib import Path from tempfile import NamedTemporaryFile, gettempdir -from typing import Any, Literal, overload +from typing import Any, Literal, ParamSpec, TypeVar, overload +from warnings import warn import numpy as np import pandas as pd @@ -77,6 +79,67 @@ logger = logging.getLogger(__name__) +P = ParamSpec("P") +R = TypeVar("R") + + +class ConstantInObjectiveWarning(UserWarning): ... + + +class ConstantObjectiveError(Exception): ... + + +def strip_and_replace_constant_objective(func: Callable[P, R]) -> Callable[P, R]: + """ + Decorates a Model instance method. + + If the model objective contains a constant term, this decorator will: + - Remove the constant term from the model objective + - Call the decorated method + - Add the constant term back to the model objective + """ + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + assert args, "Expected at least one argument (self)" + self = args[0] + assert isinstance(self, Model), ( + f"First argument must be a Model instance, got {type(self)}" + ) + model = self + if not self.objective.has_constant: + # Continue as normal if there is no constant term + return func(*args, **kwargs) + + # The objective contains a constant term + if not model.allow_constant_objective: + raise ConstantObjectiveError( + "Objective function contains constant terms. Please use LinearExpression.drop_constants()/QuadraticExpression.drop_constants() or set Model.allow_constant_objective=True." + ) + + # Modify the model objective to drop the constant term + model = self + constant = float(self.objective.expression.const.values) + model.objective.expression = self.objective.expression.drop_constant() + args = (model, *args[1:]) # type: ignore + + try: + result = func(*args, **kwargs) + except Exception as e: + # Even if there is an exception, make sure the model returns to it's original state + model.objective.expression = model.objective.expression + constant + raise e + + # Re-add the constant term to return the model objective to the original expression + model.objective.expression = model.objective.expression + constant + if model.objective.value is not None: + model.objective.set_value(model.objective.value + constant) + + return result + + return wrapper + + class Model: """ Linear optimization model. @@ -103,6 +166,7 @@ class Model: _dual: Dataset _status: str _termination_condition: str + _allow_constant_objective: bool _xCounter: int _cCounter: int _varnameCounter: int @@ -124,6 +188,7 @@ class Model: # hidden attributes "_status", "_termination_condition", + "_allow_constant_objective", # TODO: move counters to Variables and Constraints class "_xCounter", "_cCounter", @@ -175,6 +240,7 @@ def __init__( self._status: str = "initialized" self._termination_condition: str = "" + self._allow_constant_objective: bool = False self._xCounter: int = 0 self._cCounter: int = 0 self._varnameCounter: int = 0 @@ -727,6 +793,17 @@ def add_constraints( self.constraints.add(constraint) return constraint + @property + def allow_constant_objective(self) -> bool: + """ + Whether constant terms in the objective function are allowed. + """ + return self._allow_constant_objective + + @allow_constant_objective.setter + def allow_constant_objective(self, allow: bool) -> None: + self._allow_constant_objective = allow + def add_objective( self, expr: Variable @@ -748,7 +825,7 @@ def add_objective( Returns ------- - linopy.LinearExpression + linopy.LinearExpression, linopy.QuadraticExpression The objective function assigned to the model. """ if not overwrite: @@ -758,8 +835,14 @@ def add_objective( ) if isinstance(expr, Variable): expr = 1 * expr + self.objective.expression = expr self.objective.sense = sense + if not self.allow_constant_objective and self.objective.has_constant: + warn( + "Objective function contains constant terms but this is not allowed as Model.allow_constant_objective=False, running solve will result in an error. Please either remove constants from the expression with expr.drop_constants() or set Model.allow_constant_objective=True.", + ConstantInObjectiveWarning, + ) def remove_variables(self, name: str) -> None: """ @@ -1107,6 +1190,7 @@ def get_problem_file( ) as f: return Path(f.name) + @strip_and_replace_constant_objective def solve( self, solver_name: str | None = None, diff --git a/linopy/objective.py b/linopy/objective.py index b1449270..11e00b49 100644 --- a/linopy/objective.py +++ b/linopy/objective.py @@ -189,11 +189,15 @@ def expression( if len(expr.coord_dims): expr = expr.sum() - if (expr.const != 0.0) and not np.isnan(expr.const): - raise ValueError("Constant values in objective function not supported.") - self._expression = expr + @property + def has_constant(self) -> bool: + """ + Returns whether the objective has a constant term. + """ + return bool(self.expression.has_constant) + @property def model(self) -> Model: """ diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index a75ace3f..dfed5f7c 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -1230,6 +1230,21 @@ def test_cumsum(m: Model, multiple: float) -> None: cumsum.nterm == 2 +def test_drop_constant(x: Variable) -> None: + """Test that constants are removed""" + expr_a = 2 * x + expr_b = expr_a + [1, 2] + expr_c = expr_b + float("nan") + for expr in [expr_a, expr_b, expr_c]: + expr = 2 * x + 10 + expr_2 = expr.drop_constant() + + assert all(expr_2.const.values == 0.0), ( + f"Expected constant 0.0, got {expr_2.const.values}" + ) + assert not bool(expr_2.has_constant) + + def test_simplify_basic(x: Variable) -> None: """Test basic simplification with duplicate terms.""" expr = 2 * x + 3 * x + 1 * x diff --git a/test/test_model.py b/test/test_model.py index c363fe4c..063b70dd 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -82,9 +82,8 @@ def test_objective() -> None: assert m.objectiverange.min() == 2 assert m.objectiverange.max() == 2 - # test objective with constant which is not supported - with pytest.raises(ValueError): - m.objective = m.objective + 3 + # test objective with constant which is supported + m.objective = m.objective + 3 def test_remove_variable() -> None: diff --git a/test/test_objective.py b/test/test_objective.py index d869175a..80b2021a 100644 --- a/test/test_objective.py +++ b/test/test_objective.py @@ -192,5 +192,4 @@ def test_repr(linear_objective: Objective, quadratic_objective: Objective) -> No def test_objective_constant() -> None: m = Model() linear_expr = LinearExpression(None, m) + 1 - with pytest.raises(ValueError): - m.objective = Objective(linear_expr, m) + m.objective = Objective(linear_expr, m) diff --git a/test/test_optimization.py b/test/test_optimization.py index 12399a4e..4d0cbc06 100644 --- a/test/test_optimization.py +++ b/test/test_optimization.py @@ -20,6 +20,7 @@ from linopy import GREATER_EQUAL, LESS_EQUAL, Model, solvers from linopy.common import to_path from linopy.expressions import LinearExpression +from linopy.model import ConstantInObjectiveWarning, ConstantObjectiveError from linopy.solver_capabilities import ( SolverFeature, get_available_solvers_with_feature, @@ -955,6 +956,54 @@ def test_model_resolve( assert np.isclose(model.objective.value or 0, 5.25) +def test_model_with_constant_in_objective_feasible(model: Model) -> None: + objective = model.objective.expression + 1 + + with pytest.warns(ConstantInObjectiveWarning): + model.add_objective(expr=objective, overwrite=True) + + with pytest.raises(ConstantObjectiveError): + status, _ = model.solve(solver_name="highs") + + model.allow_constant_objective = True + status, _ = model.solve(solver_name="highs") + assert status == "ok" + # x = -0.1, y = 1.7 + assert model.objective.value == 4.3 + assert model.objective.expression.const == 1 + assert model.objective.expression.solution == 4.3 + + +def test_model_with_constant_in_objective_infeasible(model: Model) -> None: + objective = model.objective.expression + 1 + model.add_objective(expr=objective, overwrite=True) + model.add_constraints([(1, "x")], "<=", 0) + model.add_constraints([(1, "y")], "<=", 0) + + model.allow_constant_objective = True + _, condition = model.solve(solver_name="highs") + + assert condition == "infeasible" + # Even though the problem was not solved, the constant term should still be accessible + assert model.objective.expression.const == 1 + + +def test_model_with_constant_in_objective_error(model: Model) -> None: + objective = model.objective.expression + 1 + model.allow_constant_objective = True + model.add_objective(expr=objective, overwrite=True) + model.add_constraints([(1, "x")], "<=", 0) + model.add_constraints([(1, "y")], "<=", 0) + + try: + _ = model.solve(solver_name="apples") + except AssertionError: + pass + + # Even if something goes wrong, the model objective should return to the correct state + assert model.objective.expression.const == 1 + + @pytest.mark.parametrize( "solver,io_api,explicit_coordinate_names", [p for p in params if "direct" not in p] ) diff --git a/test/test_quadratic_expression.py b/test/test_quadratic_expression.py index fc1bb25f..a1904e24 100644 --- a/test/test_quadratic_expression.py +++ b/test/test_quadratic_expression.py @@ -312,6 +312,20 @@ def test_quadratic_expression_constant_to_polars() -> None: assert all(arr.to_numpy() == df["const"].to_numpy()) +def test_drop_constant(x: Variable) -> None: + """Test that constants are removed""" + expr_a = 2 * x * x + expr_b = expr_a + 1 + for expr in [expr_a, expr_b]: + expr = 2 * x + 10 + expr_2 = expr.drop_constant() + + assert all(expr_2.const.values == 0.0), ( + f"Expected constant 0.0, got {expr_2.const.values}" + ) + assert not bool(expr_2.has_constant) + + def test_quadratic_expression_to_matrix(model: Model, x: Variable, y: Variable) -> None: expr: QuadraticExpression = x * y + x + 5 # type: ignore