# Source code for pymc_marketing.mmm.media_transformation
#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Module for applying media transformations to media data.
Examples
--------
Create a media transformation for online and offline media channels:
.. code-block:: python
    from pymc_marketing.mmm import (
        GeometricAdstock,
        HillSaturation,
        MediaTransformation,
        MichaelisMentenSaturation,
    )
    # Shared media transformation for all offline media channels
    offline_media_transform = MediaTransformation(
        adstock=GeometricAdstock(l_max=15),
        saturation=HillSaturation(),
        adstock_first=True,
    )
    # Shared media transformation for all online media channels
    online_media_transform = MediaTransformation(
        adstock=GeometricAdstock(l_max=10),
        saturation=MichaelisMentenSaturation(),
        adstock_first=False,
    )
Create a combined media configuration for offline and online media channels:
.. code-block:: python
    from pymc_marketing.mmm import (
        MediaConfig,
        MediaConfigList,
    )
    media_configs = MediaConfigList(
        [
            MediaConfig(
                name="offline",
                columns=["TV", "Radio"],
                media_transformation=offline_media_transform,
            ),
            MediaConfig(
                name="online",
                columns=["Facebook", "Instagram", "YouTube", "TikTok"],
                media_transformation=online_media_transform,
            ),
        ]
    )
Apply the media transformation to media data in PyMC model:
.. code-block:: python
    import pymc as pm
    import pandas as pd
    df: pd.DataFrame = ...
    media_columns = media_configs.media_values
    coords = {
        "date": df["week"],
        "media": media_columns,
    }
    with pm.Model(coords=coords) as model:
        media_data = pm.Data(
            "media_data", df.loc[:, media_columns].to_numpy(), dims=("date", "media")
        )
        transformed_media_data = media_configs(media_data)
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import cast
import pymc as pm
import pytensor.tensor as pt
from pymc.distributions.shape_utils import Dims
from pymc_extras.deserialize import register_deserialization
from pymc_marketing.mmm.components.adstock import (
    AdstockTransformation,
    adstock_from_dict,
)
from pymc_marketing.mmm.components.saturation import (
    SaturationTransformation,
    saturation_from_dict,
)
@dataclass
class MediaTransformation:
    """Wrapper for applying adstock and saturation transformation to media data.

    Parameters
    ----------
    adstock : AdstockTransformation
        The adstock transformation to apply.
    saturation : SaturationTransformation
        The saturation transformation to apply.
    adstock_first : bool
        Flag to apply the adstock transformation first.
    dims : Dims
        The dimensions of the parameters. A single string or a list is
        normalized to a tuple and ``None`` becomes an empty tuple.

    Attributes
    ----------
    first : AdstockTransformation | SaturationTransformation
        The first transformation to apply.
    second : AdstockTransformation | SaturationTransformation
        The second transformation to apply.

    Raises
    ------
    ValueError
        If the ``combined_dims`` of the adstock or of the saturation
        transformation are not a subset of ``dims``.

    """

    adstock: AdstockTransformation
    saturation: SaturationTransformation
    adstock_first: bool
    dims: Dims | None = None

    def __post_init__(self):
        """Set the first and second transformations based on the adstock_first flag."""
        self.first, self.second = (
            (self.adstock, self.saturation)
            if self.adstock_first
            else (self.saturation, self.adstock)
        )

        # Normalize ``dims`` to a tuple so that comparisons and the subset
        # check below behave the same for "media", ["media"] and ("media",).
        # A list typically shows up after ``to_dict`` output went through JSON.
        if isinstance(self.dims, str):
            self.dims = (self.dims,)
        elif isinstance(self.dims, list):
            self.dims = tuple(self.dims)

        self.dims = self.dims or ()

        self._check_compatible_dims()

    def _check_compatible_dims(self):
        """Raise ``ValueError`` if a transformation uses dims outside ``self.dims``."""
        # The string form of ``cast`` keeps this a pure type-checker hint and
        # avoids evaluating ``Dims`` at runtime.
        self.dims = cast("Dims", self.dims)

        if not set(self.adstock.combined_dims).issubset(self.dims):
            raise ValueError(
                f"Adstock dimensions {self.adstock.combined_dims} are not a subset of {self.dims}"
            )

        if not set(self.saturation.combined_dims).issubset(self.dims):
            raise ValueError(
                f"Saturation dimensions {self.saturation.combined_dims} are not a subset of {self.dims}"
            )

    def __call__(self, x):
        """Apply adstock and saturation transformation to media data.

        The transformation stored in ``first`` is applied to ``x`` and the one
        stored in ``second`` is applied to that result; both receive
        ``self.dims``.

        Parameters
        ----------
        x : pt.TensorLike
            The media data to transform.

        Returns
        -------
        pt.TensorVariable
            The transformed media data.

        Examples
        --------
        Apply the media transformation to media data:

        .. code-block:: python

            import pymc as pm

            from pymc_marketing.mmm import (
                GeometricAdstock,
                HillSaturation,
                MediaTransformation,
            )

            media_data = ...

            media_transformation = MediaTransformation(
                adstock=GeometricAdstock(l_max=15),
                saturation=HillSaturation(),
                adstock_first=True,
                dims="media",
            )

            coords = {
                "date": ...,
                "media": ...,
            }
            with pm.Model(coords=coords) as model:
                transformed_media_data = media_transformation(media_data)

        """
        return self.second.apply(self.first.apply(x, self.dims), self.dims)

    def to_dict(self) -> dict:
        """Convert the media transformation to a dictionary.

        Returns
        -------
        dict
            The media transformation as a dictionary with the keys
            ``"adstock"``, ``"saturation"``, ``"adstock_first"`` and ``"dims"``.

        """
        return {
            "adstock": self.adstock.to_dict(),
            "saturation": self.saturation.to_dict(),
            "adstock_first": self.adstock_first,
            "dims": self.dims,
        }

    @classmethod
    def from_dict(cls, data) -> MediaTransformation:
        """Create a media transformation from a dictionary.

        Parameters
        ----------
        data : dict
            The data to create the media transformation from. The keys
            ``"adstock"``, ``"saturation"`` and ``"adstock_first"`` are
            required; ``"dims"`` is optional.

        Returns
        -------
        MediaTransformation
            The media transformation created from the dictionary.

        """
        return cls(
            adstock=adstock_from_dict(data["adstock"]),
            saturation=saturation_from_dict(data["saturation"]),
            adstock_first=data["adstock_first"],
            dims=data.get("dims"),
        )
def _is_media_transformation(data):
    return (
        isinstance(data, dict)
        and "adstock" in data
        and "saturation" in data
        and "adstock_first" in data
    )
# Hook the dictionary form of a media transformation into the deserialization
# registry imported from ``pymc_extras.deserialize``:
# ``_is_media_transformation`` is passed as the type check and
# ``MediaTransformation.from_dict`` as the deserializer.
register_deserialization(
    is_type=_is_media_transformation,
    deserialize=MediaTransformation.from_dict,
)
@dataclass
class MediaConfig:
    """Configuration for a media transformation to certain media channels.

    Parameters
    ----------
    name : str
        The name of the media transformation and prefix of all media variables.
    columns : list[str]
        The media channels to apply the transformation to.
    media_transformation : MediaTransformation
        The media transformation to apply to the media channels.

    """

    name: str
    columns: list[str]
    media_transformation: MediaTransformation

    def to_dict(self) -> dict:
        """Convert the media configuration to a dictionary.

        Returns
        -------
        dict
            The media configuration as a dictionary with the keys ``"name"``,
            ``"columns"`` and ``"media_transformation"``; the last one holds
            the result of ``media_transformation.to_dict()``.

        """
        return {
            "name": self.name,
            "columns": self.columns,
            "media_transformation": self.media_transformation.to_dict(),
        }

    @classmethod
    def from_dict(cls, data) -> MediaConfig:
        """Create a media configuration from a dictionary.

        Parameters
        ----------
        data : dict
            The data to create the media configuration from. The keys
            ``"name"``, ``"columns"`` and ``"media_transformation"`` are
            required.

        Returns
        -------
        MediaConfig
            The media configuration created from the dictionary.

        """
        media_transformation = MediaTransformation.from_dict(
            data["media_transformation"]
        )
        return cls(
            name=data["name"],
            columns=data["columns"],
            media_transformation=media_transformation,
        )
def _is_media_config(data):
    return (
        isinstance(data, dict)
        and "name" in data
        and "columns" in data
        and "media_transformation" in data
        and _is_media_transformation(data["media_transformation"])
    )
class MediaConfigList:
    """Wrapper for a list of media configurations to apply to media data.

    Parameters
    ----------
    media_configs : list[MediaConfig]
        The media configurations to apply to the media data.

    Examples
    --------
    Different order of media transformations for online and offline media channels:

    .. code-block:: python

        from pymc_marketing.mmm import (
            GeometricAdstock,
            LogisticSaturation,
            MediaTransformation,
            MediaConfig,
            MediaConfigList,
        )

        online = MediaConfig(
            name="online",
            columns=["Facebook", "Instagram", "YouTube", "TikTok"],
            media_transformation=MediaTransformation(
                adstock=GeometricAdstock(l_max=10).set_dims_for_all_priors("online"),
                saturation=LogisticSaturation().set_dims_for_all_priors("online"),
                adstock_first=True,
            ),
        )

        offline = MediaConfig(
            name="offline",
            columns=["TV", "Radio"],
            media_transformation=MediaTransformation(
                adstock=GeometricAdstock(
                    l_max=10,
                ).set_dims_for_all_priors("offline"),
                saturation=LogisticSaturation().set_dims_for_all_priors("offline"),
                adstock_first=False,
            ),
        )

        media_configs = MediaConfigList([online, offline])

    """

    def __init__(self, media_configs: list[MediaConfig]) -> None:
        self.media_configs = media_configs

    def __eq__(self, other) -> bool:
        """Check if the media configuration lists are equal.

        Parameters
        ----------
        other : MediaConfigList
            The other media configuration list to compare.

        Returns
        -------
        bool
            True if the media configuration lists are equal, False otherwise.
            ``NotImplemented`` is returned for objects that are not a
            ``MediaConfigList`` so that Python falls back to its default
            comparison instead of raising ``AttributeError``.

        """
        if not isinstance(other, MediaConfigList):
            return NotImplemented

        return self.media_configs == other.media_configs

    def __getitem__(self, key: int) -> MediaConfig:
        """Get the media configuration at the specified index.

        Parameters
        ----------
        key : int
            The index of the media configuration to get.

        Returns
        -------
        MediaConfig
            The media configuration at the specified index.

        """
        return self.media_configs[key]

    @property
    def media_values(self) -> list[str]:
        """Get the media values from the media configurations.

        Returns
        -------
        list[str]
            The media values from the media configurations in the order they appear.

        """
        return [
            column for config in self.media_configs for column in config.columns
        ]

    def to_dict(self) -> list[dict]:
        """Convert the media configuration list to a list of dictionaries.

        Returns
        -------
        list[dict]
            One dictionary per media configuration, in order.

        """
        return [config.to_dict() for config in self.media_configs]

    @classmethod
    def from_dict(cls, data: list[dict]) -> MediaConfigList:
        """Create a media configuration list from a list of dictionaries.

        Parameters
        ----------
        data : list[dict]
            The data to create the media configuration list from.

        Returns
        -------
        MediaConfigList
            The media configuration list created from the dictionaries.

        """
        return cls([MediaConfig.from_dict(config) for config in data])

    def __call__(self, x) -> pt.TensorVariable:
        """Apply media transformation to media data.

        Assumes that the columns in the data correspond to the media channels
        in the media_configs, in the same order as ``media_values``.

        Must be called inside a PyMC model context. For every configuration
        this method has the following side effects:

        - the configuration's columns are added to the model as a coordinate
          named after the configuration,
        - the ``dims`` of its media transformation are set to that name,
        - the ``prefix`` of its adstock and saturation transformations is
          namespaced with ``"<name>_"``.

        Parameters
        ----------
        x : pt.TensorLike
            The media data to transform. It is sliced along the second axis,
            one contiguous slice per configuration.

        Returns
        -------
        pt.TensorVariable
            The transformed media data, concatenated along the second axis.

        """
        model = pm.modelcontext(None)

        transformed_data = []
        start_idx = 0
        for config in self.media_configs:
            media_transformation = config.media_transformation

            # Assigning the attribute bypasses ``__post_init__``, so store the
            # tuple form that MediaTransformation normalizes ``dims`` to.
            media_transformation.dims = (config.name,)
            model.add_coord(config.name, config.columns)

            end_idx = start_idx + len(config.columns)
            media_data = x[:, start_idx:end_idx]

            # Namespace the variable names only once. The prefix is stored on
            # the transformation object, so prefixing unconditionally would
            # yield "<name>_<name>_..." when the list is applied a second time.
            namespace = f"{config.name}_"
            for transformation in (
                media_transformation.adstock,
                media_transformation.saturation,
            ):
                if not transformation.prefix.startswith(namespace):
                    transformation.prefix = f"{namespace}{transformation.prefix}"

            transformed_data.append(media_transformation(media_data))

            start_idx = end_idx

        return pt.concatenate(transformed_data, axis=1)
def _is_media_config_list(data):
    return isinstance(data, list) and all(_is_media_config(config) for config in data)
# Hook the list-of-dictionaries form of a media configuration list into the
# deserialization registry imported from ``pymc_extras.deserialize``:
# ``_is_media_config_list`` is passed as the type check and
# ``MediaConfigList.from_dict`` as the deserializer.
register_deserialization(
    is_type=_is_media_config_list,
    deserialize=MediaConfigList.from_dict,
)