plot
MMM-related plotting class.
Examples
Quickstart with MMM:
import pandas as pd

from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation
from pymc_marketing.mmm.multidimensional import MMM

# Minimal dataset: 12 weekly rows with two media channels and a target series.
weeks = pd.date_range("2025-01-01", periods=12, freq="W-MON")
spend = {
    "C1": [100, 120, 90, 110, 105, 115, 98, 102, 108, 111, 97, 109],
    "C2": [80, 70, 95, 85, 90, 88, 92, 94, 91, 89, 93, 87],
}
X = pd.DataFrame({"date": weeks, **spend})
y = pd.Series(
    [230, 260, 220, 240, 245, 255, 235, 238, 242, 246, 233, 249],
    name="y",
)

# The channel columns are exactly the keys of the spend mapping ("C1", "C2").
mmm = MMM(
    date_column="date",
    channel_columns=list(spend),
    target_column="y",
    adstock=GeometricAdstock(l_max=10),
    saturation=LogisticSaturation(),
)
mmm.fit(X, y)
mmm.sample_posterior_predictive(X)

# Shared credible-interval width for the time-series plots below.
hdi_prob = 0.9
# Posterior predictive time series
_ = mmm.plot.posterior_predictive(var=["y"], hdi_prob=hdi_prob)
# Posterior contributions over time (e.g., channel_contribution)
_ = mmm.plot.contributions_over_time(var=["channel_contribution"], hdi_prob=hdi_prob)
# Channel saturation scatter plot (scaled space by default)
_ = mmm.plot.saturation_scatterplot(original_scale=False)
Wrap a custom PyMC model
Requirements

- posterior_predictive plots: an `az.InferenceData` with a `posterior_predictive` group containing the variable(s) you want to plot, with a `date` coordinate.
- contributions_over_time plots: a `posterior` group with time-series variables (with a `date` dimension).
- saturation plots: a `constant_data` dataset with variables:
  - `channel_data`: dims include `("date", "channel", ...)`
  - `channel_scale`: dims include `("channel", ...)`
  - `target_scale`: scalar or broadcastable to the curve dims

  and a `posterior` variable named `channel_contribution` (or `channel_contribution_original_scale` if plotting with `original_scale=True`).
import numpy as np
import pandas as pd
import pymc as pm

from pymc_marketing.mmm.plot import MMMPlotSuite

# Any PyMC model can be plotted as long as its InferenceData has a
# posterior_predictive group whose variable(s) carry a "date" coordinate.
dates = pd.date_range("2025-01-01", periods=30, freq="D")
# Seed the synthetic observations as well: every sampling call below already
# uses random_seed=1, so an unseeded draw here was the only source of
# run-to-run variation in the example.
rng = np.random.default_rng(1)
y_obs = rng.normal(size=len(dates))

with pm.Model(coords={"date": dates}):
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", 0.0, sigma, observed=y_obs, dims="date")
    idata = pm.sample_prior_predictive(random_seed=1)
    idata.extend(pm.sample(draws=200, chains=2, tune=200, random_seed=1))
    idata.extend(pm.sample_posterior_predictive(idata, random_seed=1))

plot = MMMPlotSuite(idata)
_ = plot.posterior_predictive(var=["y"], hdi_prob=0.9)
Custom contributions_over_time
import numpy as np
import pandas as pd
import pymc as pm

from pymc_marketing.mmm.plot import MMMPlotSuite

# One sine period (endpoints included) sampled over 30 daily dates.
dates = pd.date_range("2025-01-01", periods=30, freq="D")
series = np.sin(np.linspace(0, 2 * np.pi, len(dates)))

# A deterministic series with a "date" dim stands in for a model component.
# NOTE(review): this model has no free random variables; confirm pm.sample
# handles that case in the targeted PyMC version.
with pm.Model(coords={"date": dates}):
    pm.Deterministic("component", series, dims="date")
    idata = pm.sample_prior_predictive(random_seed=2)
    posterior = pm.sample(draws=50, chains=1, tune=0, random_seed=2)
    idata.extend(posterior)

plot = MMMPlotSuite(idata)
_ = plot.contributions_over_time(var=["component"], hdi_prob=0.9)
Saturation plots with a custom model
import numpy as np
import pandas as pd
import xarray as xr
import pymc as pm

from pymc_marketing.mmm.plot import MMMPlotSuite

dates = pd.date_range("2025-01-01", periods=20, freq="W-MON")
channels = ["C1", "C2"]
# Seeded generator so the synthetic channel data is reproducible, matching
# the random_seed=3 used by the sampling calls below.
rng = np.random.default_rng(3)

# Create constant_data required for saturation plots
channel_data = xr.DataArray(
    rng.random((len(dates), len(channels))),
    dims=("date", "channel"),
    coords={"date": dates, "channel": channels},
    name="channel_data",
)
channel_scale = xr.DataArray(
    np.ones(len(channels)),
    dims=("channel",),
    coords={"channel": channels},
    name="channel_scale",
)
target_scale = xr.DataArray(1.0, name="target_scale")

# Build a toy model that yields a matching posterior var
with pm.Model(coords={"date": dates, "channel": channels}):
    # A fake contribution over time per channel (dims must include date & channel)
    pm.Normal("channel_contribution", 0.0, 1.0, dims=("date", "channel"))
    idata = pm.sample_prior_predictive(random_seed=3)
    idata.extend(pm.sample(draws=50, chains=1, tune=0, random_seed=3))

# Register constant_data as a real InferenceData group. Plain attribute
# assignment (``idata.constant_data = ...``) only sets a Python attribute, so
# the group would be missing from ``idata.groups()``, ``"constant_data" in idata``
# and ``idata.to_netcdf()``.
idata.add_groups(
    constant_data=xr.Dataset(
        {
            "channel_data": channel_data,
            "channel_scale": channel_scale,
            "target_scale": target_scale,
        }
    )
)

plot = MMMPlotSuite(idata)
_ = plot.saturation_scatterplot(original_scale=False)
Notes
`MMM` exposes this suite via the `mmm.plot` property, which internally passes the model's `idata` into `MMMPlotSuite`. Any PyMC model can use
`MMMPlotSuite` directly if its `InferenceData` contains the groups/variables described above.
Classes

MMMPlotSuite: Media Mix Model Plot Suite.