Source code for pymc_marketing.model_graph
#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Functions to manipulate PyMC models as graphs."""
import pymc as pm
from pymc.model.fgraph import (
    extract_dims,
    fgraph_from_model,
    model_free_rv,
    model_from_fgraph,
)
from pymc.pytensorf import toposort_replace
from pytensor.graph import rewrite_graph
from pytensor.tensor.basic import infer_shape_db
[docs]
def deterministics_to_flat(model: pm.Model, names: list[str]) -> pm.Model:
    """Replace the specified Deterministic nodes in a pm.Model with Flat variables.

    This is useful to capture some state from a model and to then sample from
    the model using that state. For example, capturing the mean of a distribution
    or a value of a deterministic variable.

    See :class:`pymc_marketing.mmm.hsgp.SoftPlusHSGP` for an example of how this
    is used to keep a variable centered around 1.0 during sampling but stay continuous
    with new values.

    Parameters
    ----------
    model : pm.Model
        PyMC model to be transformed. The transformation is applied to the
        FunctionGraph returned by ``fgraph_from_model``, not to ``model`` itself.
    names : list[str]
        Names of the deterministic variables to be replaced by Flat variables.
        A single string is treated as one name. Names that do not match any
        deterministic of ``model`` are ignored.

    Returns
    -------
    new_model : pm.Model
        New model in which each selected deterministic is replaced by a free
        ``Flat`` variable with the same name, shape and dims. All other
        variables are kept as they are.

    Examples
    --------
    Replace single Deterministic with Flat and sample as if it were zeros.

    .. code-block:: python

        import pymc as pm
        import numpy as np
        import xarray as xr

        from pymc_marketing.model_graph import deterministics_to_flat

        with pm.Model() as model:
            x = pm.Normal("x", mu=0, sigma=1)
            y = pm.Deterministic("y", x**2)
            z = pm.Deterministic("z", x + y)

        new_model = deterministics_to_flat(model, ["y"])

        chains, draws = 2, 100
        mock_posterior = xr.Dataset(
            {
                "y": (("chain", "draw"), np.zeros((chains, draws))),
            },
            coords={"chain": np.arange(chains), "draw": np.arange(draws)},
        )

        x_z_given_y = pm.sample_posterior_predictive(
            mock_posterior,
            model=new_model,
            var_names=["x", "z"],
        ).posterior_predictive

        np.testing.assert_allclose(
            x_z_given_y["x"],
            x_z_given_y["z"],
        )

    """
    # A bare string would otherwise be matched by substring
    # (``"mu" in "mu_scaled"`` is True), selecting the wrong variables.
    if isinstance(names, str):
        names = [names]
    wanted = set(names)

    fg, memo = fgraph_from_model(model, inlined_views=True)

    # Walk the model's own ordered list (its entries are already unique)
    # instead of a set, so replacements are built in a reproducible order.
    model_variables = [var for var in model.deterministics if var.name in wanted]

    replacements = {}
    for variable in model_variables:
        # ``memo`` maps each variable of ``model`` to its counterpart in ``fg``.
        model_var = memo[variable]
        dims = extract_dims(model_var)
        new_rv = pm.Flat.dist(shape=model_var.shape)
        new_rv.name = model_var.name

        # Register the Flat as a free RV with a value variable named like the
        # replaced deterministic, no transform (``None``) and the same dims.
        replacements[model_var] = model_free_rv(
            new_rv,
            new_rv.type(name=model_var.name),
            None,
            *dims,
        )

    toposort_replace(fg, replacements=tuple(replacements.items()))

    # The Flat's shape was expressed via the replaced deterministic's graph;
    # presumably this shape-inference pass rewrites it into a form that no
    # longer depends on that graph — NOTE(review): confirm.
    fg = rewrite_graph(
        fg,
        include=("ShapeOpt",),
        custom_rewrite=infer_shape_db.default_query,
        clone=False,
    )
    return model_from_fgraph(fg, mutate_fgraph=True)