# Source code for pymc_marketing.mmm.sensitivity_analysis
#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Counterfactual sweeps for Marketing Mix Models (MMM)."""
from typing import Literal
import numpy as np
import pandas as pd
import xarray as xr
class SensitivityAnalysis:
    """Perform counterfactual sweeps (sensitivity analysis) on a fitted MMM.

    For each value of a sweep grid, selected predictor columns of the model's
    data ``mmm.X`` are perturbed, the posterior predictive is re-sampled on the
    perturbed data, and the difference with respect to the posterior
    predictive already stored in ``mmm.idata`` is recorded.
    """

    # Interventions understood by ``create_intervention``.
    _SWEEP_TYPES = ("multiplicative", "additive", "absolute")

    def __init__(self, mmm) -> None:
        """Initialize the SensitivityAnalysis with a reference to the MMM instance.

        Parameters
        ----------
        mmm : MMM
            The marketing mix model instance used for predictions. It must
            provide ``X`` (the predictor ``DataFrame``), ``idata`` and
            ``sample_posterior_predictive``.
        """
        self.mmm = mmm

    def run_sweep(
        self,
        var_names: list[str],
        sweep_values: np.ndarray,
        sweep_type: Literal[
            "multiplicative", "additive", "absolute"
        ] = "multiplicative",
    ) -> xr.Dataset:
        """Run the model's predict function over the sweep grid and store results.

        Parameters
        ----------
        var_names : list[str]
            Names of the columns of ``mmm.X`` to intervene on. A single string
            is treated as a one-element list.
        sweep_values : np.ndarray
            Non-empty, one-dimensional array of sweep values. A scalar is
            treated as a one-element array.
        sweep_type : Literal["multiplicative", "additive", "absolute"], optional
            Type of intervention to apply, by default "multiplicative".

            - 'multiplicative': Multiply the original predictor values by each
              sweep value.
            - 'additive': Add each sweep value to the original predictor values.
            - 'absolute': Set the predictor values directly to each sweep value
              (ignoring original values).

        Returns
        -------
        xr.Dataset
            For every sweep value, the uplift (counterfactual posterior
            predictive minus the posterior predictive stored in ``mmm.idata``)
            and ``marginal_effects`` (finite-difference derivative of the
            uplift with respect to the sweep value; NaN when only one sweep
            value is given), with ``sweep`` as the last dimension. The dataset
            is also stored as the ``sensitivity_analysis`` group of
            ``mmm.idata``, replacing any previous result.

        Raises
        ------
        ValueError
            If ``mmm.idata`` is missing or ``None``, or lacks a
            ``posterior_predictive`` group. Also raised if ``sweep_type`` is
            unsupported, if ``var_names`` is empty, or if ``sweep_values`` is
            empty or not one-dimensional.
        KeyError
            If any of ``var_names`` is not a column of ``mmm.X``.
        """
        # Validate everything up front, so a bad call neither overwrites the
        # parameters of a previous run nor wastes time sampling.
        self._check_idata()
        if sweep_type not in self._SWEEP_TYPES:
            raise ValueError(f"Unsupported sweep_type: {sweep_type}")
        var_names = self._validate_var_names(var_names)
        sweep_values = np.atleast_1d(sweep_values)
        if sweep_values.ndim != 1 or sweep_values.size == 0:
            raise ValueError(
                "sweep_values must be a non-empty one-dimensional array."
            )

        # Store parameters for this run; ``create_intervention`` reads them.
        self.var_names = var_names
        self.sweep_values = sweep_values
        self.sweep_type = sweep_type

        # TODO: ideally read the baseline via
        # ``self.mmm._get_group_predictive_data(group="posterior_predictive",
        # original_scale=True)["y"]`` instead of accessing idata directly.
        actual = self.mmm.idata["posterior_predictive"]["y"]

        uplifts = []
        for sweep_value in self.sweep_values:
            X_new = self.create_intervention(sweep_value)
            counterfac = self.mmm.sample_posterior_predictive(
                X_new, extend_idata=False, combined=False, progressbar=False
            )
            uplifts.append(counterfac - actual)

        results = (
            xr.concat(uplifts, dim="sweep")
            .assign_coords(sweep=self.sweep_values)
            .transpose(..., "sweep")
        )
        marginal_effects = self.compute_marginal_effects(results, self.sweep_values)
        results = xr.merge(
            [
                results,
                marginal_effects.rename({"y": "marginal_effects"}),
            ]
        ).transpose(..., "sweep")

        # Record how the sweep was produced.
        results.attrs["sweep_type"] = self.sweep_type
        results.attrs["var_names"] = self.var_names

        # Replace any previous sensitivity analysis stored on the model.
        if hasattr(self.mmm.idata, "sensitivity_analysis"):
            delattr(self.mmm.idata, "sensitivity_analysis")
        self.mmm.idata.add_groups({"sensitivity_analysis": results})  # type: ignore
        return results

    def _check_idata(self) -> None:
        """Raise ``ValueError`` unless ``mmm.idata`` holds a posterior predictive."""
        # ``idata`` may be present but ``None`` rather than absent altogether,
        # so test for both instead of relying on ``hasattr`` alone.
        idata = getattr(self.mmm, "idata", None)
        if idata is None:
            raise ValueError("idata does not exist. Build the model first and fit.")
        if not hasattr(idata, "posterior_predictive"):
            raise ValueError(
                "idata has no 'posterior_predictive' group. Run "
                "`sample_posterior_predictive` on the fitted model first."
            )

    def _validate_var_names(self, var_names) -> list[str]:
        """Normalize ``var_names`` to a list and check each is a column of ``mmm.X``."""
        # A bare string would otherwise be iterated character by character.
        if isinstance(var_names, str):
            var_names = [var_names]
        var_names = list(var_names)
        if not var_names:
            raise ValueError("var_names must contain at least one variable name.")
        # Without this check an 'absolute' sweep would silently add a new
        # column for a misspelled name instead of failing.
        missing = [name for name in var_names if name not in self.mmm.X.columns]
        if missing:
            raise KeyError(f"Variables not found in the model's predictors: {missing}")
        return var_names

    def create_intervention(self, sweep_value: float) -> pd.DataFrame:
        """Apply the intervention to the predictors.

        Parameters
        ----------
        sweep_value : float
            Value used for the intervention, interpreted according to
            ``self.sweep_type``.

        Returns
        -------
        pd.DataFrame
            A copy of ``mmm.X`` in which the columns listed in
            ``self.var_names`` have been multiplied by, increased by, or set
            to ``sweep_value``. ``mmm.X`` itself is left unchanged.

        Raises
        ------
        ValueError
            If ``self.sweep_type`` is not supported or ``self.var_names`` is
            empty.
        KeyError
            If a name in ``self.var_names`` is not a column of ``mmm.X``.
        """
        if self.sweep_type not in self._SWEEP_TYPES:
            raise ValueError(f"Unsupported sweep_type: {self.sweep_type}")
        var_names = self._validate_var_names(self.var_names)
        X_new = self.mmm.X.copy()
        for var_name in var_names:
            if self.sweep_type == "multiplicative":
                X_new[var_name] *= sweep_value
            elif self.sweep_type == "additive":
                X_new[var_name] += sweep_value
            else:  # "absolute"
                X_new[var_name] = sweep_value
        return X_new

    @staticmethod
    def compute_marginal_effects(results, sweep_values) -> xr.DataArray:
        """Compute marginal effects via finite differences from the sweep results.

        Parameters
        ----------
        results : xr.Dataset or xr.DataArray
            Sweep results carrying a ``sweep`` dimension and coordinate.
        sweep_values : array-like
            Not used: the spacing is taken from the ``sweep`` coordinate of
            ``results``. Kept for backward compatibility.

        Returns
        -------
        Same type as ``results``
            Derivative of ``results`` with respect to the ``sweep``
            coordinate. With a single sweep point the derivative is undefined
            and the result is filled with NaN.
        """
        # ``differentiate`` uses a numerical gradient, which needs at least two
        # points along the dimension and raises otherwise.
        if results.sizes.get("sweep") == 1:
            return xr.full_like(results, np.nan, dtype=float)
        return results.differentiate(coord="sweep")