Source code for pymc_marketing.mmm.constraints
#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Constraints for the BudgetOptimizer."""
from collections.abc import Callable
from typing import Any, Literal
import pytensor.tensor as pt
from pymc.pytensorf import rewrite_pregrad
from pytensor import function
[docs]
class Constraint:
    """
    Represents a constraint for the BudgetOptimizer.

    Parameters
    ----------
    key : str
        Identifier for the constraint.
    constraint_type : Literal["eq", "ineq"]
        Type of the constraint: ``"eq"`` for equality, ``"ineq"`` for
        inequality. These are the two type strings accepted by
        ``scipy.optimize.minimize`` constraint dicts (where an inequality
        constraint is satisfied when its value is non-negative).
    constraint_fun : Callable[[pt.TensorVariable, pt.TensorVariable, Any], pt.TensorVariable]
        Function that builds the symbolic constraint expression. It is called
        with ``budgets_sym``, ``total_budget_sym`` and ``optimizer``.

    Raises
    ------
    ValueError
        If ``constraint_type`` is neither ``"eq"`` nor ``"ineq"``.
    TypeError
        If ``constraint_fun`` is not callable.
    """

    def __init__(
        self,
        key: str,
        constraint_type: Literal["eq", "ineq"],
        constraint_fun: Callable[
            [pt.TensorVariable, pt.TensorVariable, Any], pt.TensorVariable
        ],
    ):
        # Validate eagerly so mistakes surface at construction time instead of
        # deep inside graph compilation or the scipy solver.
        if constraint_type not in ("eq", "ineq"):
            raise ValueError(
                f"constraint_type must be 'eq' or 'ineq', got {constraint_type!r}."
            )
        if not callable(constraint_fun):
            raise TypeError(
                f"constraint_fun must be callable, got {type(constraint_fun)}."
            )
        self.key = key
        self.constraint_type = constraint_type
        self.constraint_fun = constraint_fun

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(key={self.key!r}, "
            f"constraint_type={self.constraint_type!r}, "
            f"constraint_fun={self.constraint_fun!r})"
        )
 
[docs]
def build_default_sum_constraint(key: str = "default") -> Constraint:
    """
    Return a Constraint enforcing sum(budgets) == total_budget.

    The symbolic constraint value is ``sum(budgets) - total_budget``. It is
    zero exactly when the allocated budgets add up to the total budget.

    Parameters
    ----------
    key : str, default "default"
        Identifier assigned to the returned constraint.

    Returns
    -------
    Constraint
        An equality (``"eq"``) constraint.
    """

    def _budget_sum_residual(
        budgets_sym: pt.TensorVariable,
        total_budget_sym: pt.TensorVariable,
        optimizer,
    ) -> pt.TensorVariable:
        # ``optimizer`` is required by the constraint_fun calling convention
        # but is not needed for this constraint.
        allocated = pt.sum(budgets_sym)
        return allocated - total_budget_sym

    return Constraint(
        constraint_type="eq",
        constraint_fun=_budget_sum_residual,
        key=key,
    )
[docs]
def _constraint_output_and_jacobian(constraint_output, wrt):
    """Return the scipy-ready constraint expression and its derivative w.r.t. ``wrt``."""
    from pytensor.gradient import jacobian

    constraint_output = pt.as_tensor_variable(constraint_output)
    if constraint_output.ndim == 0:
        # Scalar constraint: the derivative is a plain gradient vector.
        return constraint_output, pt.grad(rewrite_pregrad(constraint_output), wrt)
    if constraint_output.ndim > 1:
        # scipy expects vector-valued constraint functions to return 1-D arrays.
        constraint_output = constraint_output.flatten()
    # Vector constraint: pt.grad requires a scalar cost, so build the full
    # Jacobian (one row per constraint component).
    return constraint_output, jacobian(rewrite_pregrad(constraint_output), wrt)


def compile_constraints_for_scipy(constraints: list[Constraint] | dict, optimizer):
    """Compile constraints for scipy.

    Each :class:`Constraint` is evaluated symbolically against the optimizer's
    budget variables. It is then differentiated with respect to the flat budget
    vector and compiled into Python callables, using the constraint-dict format
    of ``scipy.optimize.minimize``.

    Parameters
    ----------
    constraints : list[Constraint] | dict
        Constraints to compile. When a dict is given, only its values are used.
    optimizer : Any
        Object exposing the symbolic attributes ``_budgets``,
        ``_budgets_flat`` and ``_total_budget``. ``_budgets_flat`` is the sole
        input of every compiled function. ``_budgets`` and ``_total_budget``
        are passed to each ``constraint_fun``. They are presumably expressed in
        terms of ``_budgets_flat`` — TODO confirm against the optimizer.

    Returns
    -------
    list[dict]
        One ``{"type", "fun", "jac"}`` dict per constraint, in input order.
        ``fun`` and ``jac`` both take the flat budget vector. ``jac`` returns a
        gradient vector for scalar constraints, or a Jacobian matrix for
        vector-valued constraints.

    Raises
    ------
    ValueError
        If no constraints are provided.
    TypeError
        If any element is not a :class:`Constraint` instance.
    """
    budgets = optimizer._budgets
    budgets_flat = optimizer._budgets_flat
    total_budget = optimizer._total_budget

    if isinstance(constraints, dict):
        constraints = list(constraints.values())
    if not constraints:
        raise ValueError("No constraints provided for compilation.")

    compiled_constraints = []
    for constraint in constraints:
        if not isinstance(constraint, Constraint):
            raise TypeError(
                f"Expected an instance of Constraint, but received {type(constraint)}. "
                "Ensure all constraints are created using the Constraint class."
            )

        raw_output = constraint.constraint_fun(budgets, total_budget, optimizer)
        constraint_output, sym_jac_output = _constraint_output_and_jacobian(
            raw_output, budgets_flat
        )

        # Compile symbolic graphs into Python callables of the flat budgets.
        compiled_fun = function(
            inputs=[budgets_flat],
            outputs=constraint_output,
            on_unused_input="ignore",
        )
        compiled_jac = function(
            inputs=[budgets_flat],
            outputs=sym_jac_output,
            on_unused_input="ignore",
        )
        compiled_constraints.append(
            {
                "type": constraint.constraint_type,
                "fun": compiled_fun,
                "jac": compiled_jac,
            }
        )
    return compiled_constraints