#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""MLflow logging utilities for PyMC models.
This module provides utilities to log various aspects of PyMC models to MLflow
which is then extended to PyMC-Marketing models.
Autologging is supported for PyMC models and PyMC-Marketing models. This including
logging of sampler diagnostics, model information, data used in the model, and
InferenceData objects.
The autologging can be enabled by calling the `autolog` function. The following functions
are patched:
- `pymc.sample`:
    - :func:`log_versions`: Log the versions of PyMC-Marketing, PyMC, and ArviZ to MLflow.
    - :func:`log_model_derived_info`: Log types of parameters, coords, model graph, etc.
    - :func:`log_sample_diagnostics`: Log information derived from the InferenceData object.
    - :func:`log_arviz_summary`: Log table of summary statistics about estimated parameters
    - :func:`log_metadata`: Log the metadata of the data used in the model.
    - :func:`log_error`: Log the traceback and exception if an error occurs during sampling.
- `pymc.find_MAP`:
    - :func:`log_model_derived_info`: Log types of parameters, coords, model graph, etc.
- `MMM.fit`:
    - All parameters, metrics, and artifacts from `pymc.sample`
    - :func:`log_mmm_configuration`: Log the configuration of the MMM model.
- `CLVModel.fit`:
    - Information dependent on fit method used (MCMC or MAP)
    - Model type and fit method
Examples
--------
Autologging for a PyMC model:
.. code-block:: python
    import mlflow
    import pymc as pm
    import pymc_marketing.mlflow
    pymc_marketing.mlflow.autolog()
    # Usual PyMC model code
    with pm.Model() as model:
        mu = pm.Normal("mu", mu=0, sigma=1)
        obs = pm.Normal("obs", mu=mu, sigma=1, observed=[1, 2, 3])
    # Incorporate into MLflow workflow
    mlflow.set_experiment("PyMC Experiment")
    with mlflow.start_run():
        idata = pm.sample(model=model)
Autologging for a PyMC-Marketing MMM:
.. code-block:: python
    import pandas as pd
    import mlflow
    from pymc_marketing.mmm import (
        GeometricAdstock,
        LogisticSaturation,
        MMM,
    )
    from pymc_marketing.paths import data_dir
    import pymc_marketing.mlflow
    pymc_marketing.mlflow.autolog(log_mmm=True)
    # Usual PyMC-Marketing model code
    file_path = data_dir / "mmm_example.csv"
    data = pd.read_csv(file_path, parse_dates=["date_week"])
    X = data.drop("y", axis=1)
    y = data["y"]
    mmm = MMM(
        adstock=GeometricAdstock(l_max=8),
        saturation=LogisticSaturation(),
        date_column="date_week",
        channel_columns=["x1", "x2"],
        control_columns=[
            "event_1",
            "event_2",
            "t",
        ],
        yearly_seasonality=2,
    )
    # Incorporate into MLflow workflow
    mlflow.set_experiment("MMM Experiment")
    with mlflow.start_run():
        idata = mmm.fit(X, y)
        # Additional specific logging
        fig = mmm.plot_components_contributions()
        mlflow.log_figure(fig, "components.png")
Autologging for a PyMC-Marketing CLV model:
.. code-block:: python
    import pandas as pd
    import mlflow
    from pymc_marketing.clv import BetaGeoModel
    from pymc_marketing.paths import data_dir
    import pymc_marketing.mlflow
    pymc_marketing.mlflow.autolog(log_clv=True)
    mlflow.set_experiment("CLV Experiment")
    file_path = data_dir / "clv_quickstart.csv"
    data = pd.read_csv(file_path)
    data["customer_id"] = data.index
    model = BetaGeoModel(data=data)
    with mlflow.start_run():
        model.fit()
"""
import logging
import os
import tempfile
import traceback
import warnings
from collections.abc import Callable
from functools import wraps
from pathlib import Path
from typing import Any, Literal
import arviz as az
import numpy.typing as npt
import pandas as pd
import pymc as pm
import xarray as xr
from pymc.model.core import Model
from pytensor.tensor import TensorVariable
try:
    import mlflow
except ImportError:  # pragma: no cover
    msg = "This module requires mlflow. Install using `pip install mlflow`"
    raise ImportError(msg)
from mlflow.utils.autologging_utils import autologging_integration
from pymc_marketing.clv.models.basic import CLVModel
from pymc_marketing.mmm import MMM
from pymc_marketing.mmm.evaluation import compute_summary_metrics
from pymc_marketing.version import __version__
FLAVOR_NAME = "pymc"
PYMC_MARKETING_ISSUE = "https://github.com/pymc-labs/pymc-marketing/issues/new"
warning_msg = (
    "This functionality is experimental and subject to change. "
    "If you encounter any issues or have suggestions, please raise them at: "
    f"{PYMC_MARKETING_ISSUE}"
)
warnings.warn(warning_msg, FutureWarning, stacklevel=1)
def _exclude_tuning(func):
    def callback(trace, draw):
        if draw.tuning:
            return
        return func(trace, draw)
    return callback
def _take_every(n: int):
    def decorator(func):
        def callback(trace, draw):
            if draw.draw_idx % n != 0:
                return
            return func(trace, draw)
        return callback
    return decorator
[docs]
def create_log_callback(
    stats: list[str] | None = None,
    parameters: list[str] | None = None,
    exclude_tuning: bool = True,
    take_every: int = 100,
):
    """Create callback function to log sample stats and parameter values to MLflow during sampling.
    This callback only works for the "pymc" sampler.
    Parameters
    ----------
    stats : list of str, optional
        List of sample statistics to log from the Draw
    parameters : list of str, optional
        List of parameters to log from the Draw
    exclude_tuning : bool, optional
        Whether to exclude tuning steps from logging. Defaults to True.
    take_every : int, optional
        Specifies the interval at which to log values. Defaults to 100.
    Returns
    -------
    callback : Callable
        The callback function to log sample stats and parameter values to MLflow during sampling
    Examples
    --------
    Create example model:
    .. code-block:: python
        import pymc as pm
        with pm.Model() as model:
            mu = pm.Normal("mu")
            sigma = pm.HalfNormal("sigma")
            obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=[1, 2, 3])
    Log off divergences and logp every 100th draw:
    .. code-block:: python
        import mlflow
        from pymc_marketing.mlflow import create_log_callback
        callback = create_log_callback(
            stats=["diverging", "model_logp"],
            take_every=100,
        )
        mlflow.set_experiment("Live Tracking Stats")
        with mlflow.start_run():
            idata = pm.sample(model=model, callback=callback)
    Log the parameters `mu` and `sigma_log__` every 100th draw:
    .. code-block:: python
        import mlflow
        from pymc_marketing.mlflow import create_log_callback
        callback = create_log_callback(
            parameters=["mu", "sigma_log__"],
            take_every=100,
        )
        mlflow.set_experiment("Live Tracking Parameters")
        with mlflow.start_run():
            idata = pm.sample(model=model, callback=callback)
    """
    if not stats and not parameters:
        raise ValueError("At least one of `stats` or `parameters` must be provided.")
    def callback(_, draw):
        prefix = f"chain_{draw.chain}"
        for stat in stats or []:
            mlflow.log_metric(
                key=f"{prefix}/{stat}",
                value=draw.stats[0][stat],
                step=draw.draw_idx,
            )
        for parameter in parameters or []:
            mlflow.log_metric(
                key=f"{prefix}/{parameter}",
                value=draw.point[parameter],
                step=draw.draw_idx,
            )
    if exclude_tuning:
        callback = _exclude_tuning(callback)
    if take_every:
        callback = _take_every(n=take_every)(callback)
    return callback 
def _log_and_remove_artifact(path: str | Path) -> None:
    """Log an artifact to MLflow and then remove the local file.
    Parameters
    ----------
    path : str | Path
        Path to the artifact file to log and remove.
    """
    mlflow.log_artifact(str(path))
    os.remove(path)
def _force_load_idata_groups(idata: az.InferenceData) -> None:
    """Force load all groups into memory since ArviZ does lazy loading.
    Parameters
    ----------
    idata : az.InferenceData
        The InferenceData object to force load.
    """
    for group in idata.groups():
        # Convert each group to an in-memory dataset
        if hasattr(idata, group):
            group_data = getattr(idata, group)
            if hasattr(group_data, "load"):
                group_data.load()
[docs]
def log_arviz_summary(
    idata: az.InferenceData,
    path: str | Path,
    var_names: list[str] | None = None,
    **summary_kwargs,
) -> None:
    """Log the ArviZ summary as an artifact on MLflow.
    Automatically removes the file after logging.
    Parameters
    ----------
    idata : az.InferenceData
        The InferenceData object returned by the sampling method.
    path : str | Path
        The path to save the summary as HTML.
    var_names : list[str], optional
        The names of the variables to include in the summary. Default is
        all the variables in the InferenceData object.
    summary_kwargs : dict
        Additional keyword arguments to pass to `az.summary`.
    """
    df_summary = az.summary(idata, var_names=var_names, **summary_kwargs)
    df_summary.to_html(path)
    mlflow.log_artifact(str(path))
    os.remove(path) 
[docs]
def log_model_graph(model: Model, path: str | Path) -> None:
    """Log the model graph PDF as artifact on MLflow.
    Automatically removes the file after logging.
    Parameters
    ----------
    model : Model
        The PyMC model object.
    path : str | Path
        The path to save the model graph
    """
    try:
        graph = pm.model_to_graphviz(model)
    except ImportError as e:
        msg = (
            "Unable to render the model graph. Please install the graphviz package. "
            f"{e}"
        )
        logging.info(msg)
        return None
    try:
        saved_path = graph.render(path)
    except Exception as e:
        msg = f"Unable to render the model graph. {e}"
        logging.info(msg)
        return None
    else:
        _log_and_remove_artifact(saved_path)
        os.remove(path) 
def _get_random_variable_name(rv) -> str:
    # Taken from new version of pymc/model_graph.py
    symbol = rv.owner.op.__class__.__name__
    if symbol.endswith("RV"):
        symbol = symbol[:-2]
    return symbol
[docs]
def log_types_of_parameters(model: Model) -> None:
    """Log the types of parameters in a PyMC model to MLflow.
    Parameters
    ----------
    model : Model
        The PyMC model object.
    """
    mlflow.log_param("n_free_RVs", len(model.free_RVs))
    mlflow.log_param("n_observed_RVs", len(model.observed_RVs))
    mlflow.log_param("n_deterministics", len(model.deterministics))
    mlflow.log_param("n_potentials", len(model.potentials)) 
[docs]
def log_likelihood_type(model: Model) -> None:
    """Save the likelihood type of the model to MLflow.
    Parameters
    ----------
    model : Model
        The PyMC model object.
    """
    observed_RVs_types = [_get_random_variable_name(rv) for rv in model.observed_RVs]
    if len(observed_RVs_types) == 1:
        mlflow.log_param("likelihood", observed_RVs_types[0])
    elif len(observed_RVs_types) > 1:
        mlflow.log_param("observed_RVs_types", observed_RVs_types) 
[docs]
def log_model_derived_info(model: Model) -> None:
    """Log various model derived information to MLflow.
    Includes:
    - The types of parameters in the model.
    - The likelihood type of the model.
    - The model representation (str).
    - The model coordinates (coords.json).
    Parameters
    ----------
    model : Model
        The PyMC model object.
    """
    log_types_of_parameters(model)
    mlflow.log_text(model.str_repr(), "model_repr.txt")
    if model.coords:
        mlflow.log_dict(model.coords, "coords.json")
    log_model_graph(model, "model_graph")
    log_likelihood_type(model) 
[docs]
def log_sample_diagnostics(
    idata: az.InferenceData,
    tune: int | None = None,
) -> None:
    """Log sample diagnostics to MLflow.
    Includes:
    - The total number of divergences
    - The total sampling time in seconds (if available)
    - The time per draw in seconds (if available)
    - The number of tuning steps (if available)
    - The number of draws
    - The number of chains
    - The inference library used
    - The version of the inference library
    - The version of ArviZ
    Parameters
    ----------
    idata : az.InferenceData
        The InferenceData object returned by the sampling method.
    tune : int, optional
        The number of tuning steps used in sampling. Derived from the
        inference data if not provided.
    """
    if "posterior" not in idata:
        raise KeyError("InferenceData object does not contain the group posterior.")
    if "sample_stats" not in idata:
        raise KeyError("InferenceData object does not contain the group sample_stats.")
    posterior = idata["posterior"]
    sample_stats = idata["sample_stats"]
    diverging = sample_stats["diverging"]
    chains = posterior.sizes["chain"]
    draws = posterior.sizes["draw"]
    posterior_samples = chains * draws
    tuning_step = sample_stats.attrs.get("tuning_steps", tune)
    if tuning_step is not None:
        tuning_samples = tuning_step * chains
        mlflow.log_param("tuning_steps", tuning_step)
        mlflow.log_param("tuning_samples", tuning_samples)
    total_divergences = diverging.sum().item()
    mlflow.log_metric("total_divergences", total_divergences)
    if sampling_time := sample_stats.attrs.get("sampling_time"):
        mlflow.log_metric("sampling_time", sampling_time)
        mlflow.log_metric(
            "time_per_draw",
            sampling_time / posterior_samples,
        )
    mlflow.log_param("draws", draws)
    mlflow.log_param("chains", chains)
    mlflow.log_param("posterior_samples", posterior_samples)
    if inference_library := posterior.attrs.get("inference_library"):
        mlflow.log_param("inference_library", inference_library)
        mlflow.log_param(
            "inference_library_version",
            posterior.attrs["inference_library_version"],
        ) 
[docs]
def log_inference_data(
    idata: az.InferenceData,
    save_file: str | Path = "idata.nc",
) -> None:
    """Log the InferenceData to MLflow.
    Parameters
    ----------
    idata : az.InferenceData
        The InferenceData object returned by the sampling method.
    save_file : str | Path
        The path to save the InferenceData object as a netCDF file.
    """
    idata.to_netcdf(str(save_file))
    _log_and_remove_artifact(save_file) 
[docs]
def log_mmm_evaluation_metrics(
    y_true: npt.NDArray | pd.Series,
    y_pred: npt.NDArray | xr.DataArray,
    metrics_to_calculate: list[str] | None = None,
    hdi_prob: float = 0.94,
    prefix: str = "",
) -> None:
    """Log evaluation metrics produced by `pymc_marketing.mmm.evaluation.compute_summary_metrics()` to MLflow.
    Parameters
    ----------
    y_true : npt.NDArray | pd.Series
        The true values of the target variable.
    y_pred : npt.NDArray | xr.DataArray
        The predicted values of the target variable.
    metrics_to_calculate : list of str or None, optional
        List of metrics to calculate. If None, all available metrics will be calculated.
        Options include:
            * `r_squared`: Bayesian R-squared.
            * `rmse`: Root Mean Squared Error.
            * `nrmse`: Normalized Root Mean Squared Error.
            * `mae`: Mean Absolute Error.
            * `nmae`: Normalized Mean Absolute Error.
            * `mape`: Mean Absolute Percentage Error.
    hdi_prob : float, optional
        The probability mass of the highest density interval. Defaults to 0.94.
    prefix : str, optional
        Prefix to add to the metric names. Defaults to "".
    Examples
    --------
    Log in-sample evaluation metrics for a PyMC-Marketing MMM model:
    .. code-block:: python
        import mlflow
        from pymc_marketing.mmm import MMM
        mmm = MMM(...)
        mmm.fit(X, y)
        predictions = mmm.sample_posterior_predictive(X)
        with mlflow.start_run():
            log_mmm_evaluation_metrics(y, predictions["y"])
    """
    metric_summaries = compute_summary_metrics(
        y_true=y_true,
        y_pred=y_pred,
        metrics_to_calculate=metrics_to_calculate,
        hdi_prob=hdi_prob,
    )
    if prefix and not prefix.endswith("_"):
        prefix = f"{prefix}_"
    for metric, stats in metric_summaries.items():
        for stat, value in stats.items():
            # mlflow doesn't support % in metric names
            mlflow.log_metric(f"{prefix}{metric}_{stat.replace('%', '')}", value) 
[docs]
class MMMWrapper(mlflow.pyfunc.PythonModel):
    """A class to prepare a PyMC-Marketing Mix Model (MMM) for logging and registering in MLflow.
    This class extends MLflow's PythonModel to handle prediction tasks using a PyMC-based MMM.
    It supports several prediction methods, including point-prediction, posterior and prior predictive sampling.
    Parameters
    ----------
    model : pymc_marketing.mmm.MMM
        The marketing mix model to be registered and used for predictions.
    predict_method : str, optional, default="predict"
        The default prediction method to use, such as "predict",
        "sample_posterior_predictive", or "sample_prior_predictive".
    extend_idata : bool, default=False
        Boolean determining whether the predictions should be added to inference data object. Defaults to False.
    combined : bool, default=True
        Combine chain and draw dims into sample. Won't work if a dim named sample already exists. Defaults to True.
    include_last_observations : bool, default=False
        Boolean determining whether to include the last observations of the training data in order to carry over
        costs with the adstock transformation. Assumes that X are the next predictions following the
        training data. Defaults to False.
    original_scale : bool, default=True
        Boolean determining whether to return the predictions in the original scale of the target variable.
    var_names : list of str, optional, default=None
        The variable names to include in the predictions.
    sample_kwargs : dict, optional
        Additional keyword arguments to pass to the selected sampling methods.
    """
[docs]
    def __init__(
        self,
        model: MMM,
        predict_method: Literal[
            "predict", "sample_posterior_predictive", "sample_prior_predictive"
        ] = "predict",
        extend_idata: bool = False,
        combined: bool = True,
        include_last_observations: bool = False,
        original_scale: bool = True,
        var_names: list[str] | None = None,
        **sample_kwargs: dict,
    ):
        self.model = model
        self.predict_method = predict_method
        self.extend_idata = extend_idata
        self.combined = combined
        self.include_last_observations = include_last_observations
        self.original_scale = original_scale
        self.var_names = (
            var_names if var_names is not None else [model.output_var]
        )  # Initialize if not provided
        self.sample_kwargs = sample_kwargs 
[docs]
    def predict(
        self, context: Any, model_input, params: dict[str, Any] | None = None
    ) -> Any:
        """Perform predictions or sampling using the specified prediction method.
        Parameters
        ----------
        context : Any
            The context in which the model is running. Isn't specified by users but is passed by MLflow.
        model_input : array, shape (n_pred, n_features)
            The input data used for prediction.
        params : dict, optional
            A dictionary of parameters to specify the prediction method.
        Returns
        -------
        ndarray or InferenceData
            The predictions or samples generated by the model.
        Raises
        ------
        ValueError
            If an unsupported prediction method is specified.
        """
        # Use the class-level predict_method if params is not provided or doesn't contain 'predict_method'
        params = params or {"predict_method": "predict"}
        predict_method = params.get("predict_method", self.predict_method)
        if predict_method == "predict":
            return self.model.predict(
                model_input,
                extend_idata=self.extend_idata,
                include_last_observations=self.include_last_observations,
                original_scale=self.original_scale,
                var_names=self.var_names,
                **self.sample_kwargs,  # type: ignore[arg-type]
            )
        elif predict_method == "sample_posterior_predictive":
            return self.model.sample_posterior_predictive(
                model_input,
                extend_idata=self.extend_idata,
                combined=self.combined,
                include_last_observations=self.include_last_observations,
                original_scale=self.original_scale,
                var_names=self.var_names,
                **self.sample_kwargs,  # type: ignore[arg-type]
            )
        elif predict_method == "sample_prior_predictive":
            return self.model.sample_prior_predictive(
                model_input,
                extend_idata=self.extend_idata,
                combined=self.combined,
                var_names=self.var_names,
                **self.sample_kwargs,  # type: ignore[arg-type]
            )
        else:
            raise ValueError(
                f"The prediction method '{predict_method}' is not supported."
            ) 
 
[docs]
def log_mmm(
    mmm: MMM,
    artifact_path: str = "model",
    registered_model_name: str | None = None,
    extend_idata: bool = False,
    combined: bool = True,
    include_last_observations: bool = False,
    original_scale: bool = True,
) -> None:
    """Log a PyMC-Marketing MMM as a native MLflow model for the current run.
    Parameters
    ----------
    mmm : MMM
        The MMM to be logged.
    artifact_path : str, optional
        The path to the artifact to be logged. Defaults to "mmm_model".
    conda_env : dict, optional
        A dictionary representation of a Conda environment. Defaults to the default conda environment.
    registered_model_name : str, optional
        The name of the registered model to be logged. Defaults to None.
        If specified, the model will be registered under this name, otherwise it will not be registered.
    extend_idata : bool, optional
        Whether to extend the inference data with predictions. Used for all prediction methods.
        Defaults to False.
    combined : bool, optional
        Whether to combine chain and draw dims into sample. Won't work if a dim named sample
        already exists. Used for posterior/prior predictive sampling. Defaults to True.
    include_last_observations : bool, optional
        Whether to include the last observations of training data for adstock transformation.
        Assumes X are next predictions following training data. Used for all prediction
        methods. Defaults to False.
    original_scale : bool, optional
        Whether to return predictions in original scale of target variable. Used for all
        prediction methods. Defaults to True.
    Notes
    -----
    This function logs the model as a native MLflow model, this is different to the full model object,
    which includes the InferenceData. Doing this allows for the model to be stored in the MLFlow registry,
    helping with model versioning and deployment.
    Examples
    --------
    MLFlow Registering for a PyMC-Marketing MMM:
    .. code-block:: python
        import pandas as pd
        import mlflow
        from pymc_marketing.mmm import (
            GeometricAdstock,
            LogisticSaturation,
            MMM,
        )
        from pymc_marketing.paths import data_dir
        import pymc_marketing.mlflow
        from pymc_marketing.mlflow import log_mmm
        pymc_marketing.mlflow.autolog(log_mmm=True)
        # Usual PyMC-Marketing model code
        file_path = data_dir / "mmm_example.csv"
        data = pd.read_csv(file_path, parse_dates=["date_week"])
        X = data.drop("y", axis=1)
        y = data["y"]
        mmm = MMM(
            adstock=GeometricAdstock(l_max=8),
            saturation=LogisticSaturation(),
            date_column="date_week",
            channel_columns=["x1", "x2"],
            control_columns=[
                "event_1",
                "event_2",
                "t",
            ],
            yearly_seasonality=2,
        )
        mlflow.set_experiment("MMM Experiment")
        with mlflow.start_run():
            idata = mmm.fit(X, y)
            # Additional specific logging
            fig = mmm.plot_components_contributions()
            mlflow.log_figure(fig, "components.png")
            model_info = log_mmm(
                mmm=mmm,
                registered_model_name="my_amazing_mmm",
                include_last_observations=True,
                original_scale=False,
            )
    """
    # Incorporate MMM into MLflow workflow
    mlflow_mmm = MMMWrapper(
        model=mmm,
        extend_idata=extend_idata,
        combined=combined,
        include_last_observations=include_last_observations,
        original_scale=original_scale,
    )
    mlflow.pyfunc.log_model(
        artifact_path=artifact_path,
        python_model=mlflow_mmm,
    )
    run_id = mlflow.active_run().info.run_id
    model_uri = f"runs:/{run_id}/{artifact_path}"
    if registered_model_name:
        mlflow.register_model(model_uri, registered_model_name) 
[docs]
def load_mmm(
    run_id: str,
    full_model: bool = False,
    keep_idata: bool = False,
    artifact_path: str = "model",
    dst_path: str | None = None,
) -> mlflow.pyfunc.PyFuncModel | MMM:
    """
    Load a PyMC-Marketing MMM model from MLflow.
    Can either load the full model including the InferenceData, or just the lighter PyFuncModel version.
    Parameters
    ----------
    run_id : str
        The MLflow run ID from which to load the model.
    full_model : bool, default=True
        If True, load the full MMM model including the InferenceData.
    keep_idata : bool, default=False
        If True, keep the downloaded InferenceData saved locally.
    artifact_path : str, default="model"
        The artifact path within the run where the model is stored.
    dst_path : str | None, default=None
        The local destination path where the InferenceData will be downloaded.
        If None, defaults to "idata_{run_id}" to avoid conflicts when loading multiple models.
    Returns
    -------
    model : mlflow.pyfunc.PyFuncModel | MMM
        The loaded MLflow PyFuncModel or MMM model.
    Examples
    --------
    .. code-block:: python
        # Load model using run_id
        model = load_mmm(run_id="your_run_id", full_model=True, keep_idata=True)
    """
    model_uri = f"runs:/{run_id}/{artifact_path}"
    if not full_model:
        model = mlflow.pyfunc.load_model(model_uri)
        return model
    # Create unique destination path if not provided
    if dst_path is None:
        dst_path = f"idata_{run_id}"
    idata_path = mlflow.artifacts.download_artifacts(
        run_id=run_id, artifact_path="idata.nc", dst_path=dst_path
    )
    model = MMM.load(idata_path)
    if not keep_idata:
        _force_load_idata_groups(model.idata)
        try:
            os.remove(idata_path)
            os.rmdir(dst_path)
        except OSError:
            warnings.warn(
                f"Could not remove temporary files at {dst_path}. You may want to remove them manually.",
                UserWarning,
                stacklevel=2,
            )
    return model 
[docs]
def log_versions() -> None:
    """Log the versions of PyMC-Marketing, PyMC, and ArviZ to MLflow."""
    mlflow.log_param("pymc_marketing_version", __version__)
    mlflow.log_param("pymc_version", pm.__version__)
    mlflow.log_param("arviz_version", az.__version__) 
[docs]
def log_mmm_configuration(mmm: MMM) -> None:
    """Log the configuration of the MMM model to MLflow."""
    attrs = mmm.create_idata_attrs()
    mlflow.log_params(attrs)
    mlflow.log_param("adstock_name", mmm.adstock.lookup_name)
    mlflow.log_param("saturation_name", mmm.saturation.lookup_name) 
[docs]
def log_error(func: Callable, file_name: str):
    """Log arbitrary caught error and traceback to MLflow.
    .. note::
        The error will still be raised with the program. It is just logged
        to MLflow
    Parameters
    ----------
    func : Callable
        Arbitrary function
    file_name : str
        The name of the MLflow artifact
    Examples
    --------
    .. code-block:: python
        import mlflow
        from pymc_marketing.mlflow import log_error
        def raising_function():
            raise NotImplementedError("Sorry. Not implemented")
        func = log_error(raising_function, file_name="raising-function")
        with mlflow.start_run():
            func()
    """
    @wraps(func)
    def wrapped(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            with tempfile.TemporaryDirectory() as tmp_dir:
                path = Path(tmp_dir) / file_name
                with path.open("w") as f:
                    traceback.print_exc(file=f)
                mlflow.log_artifact(str(path))
            raise e
    return wrapped 
[docs]
@autologging_integration(FLAVOR_NAME)
def autolog(
    log_sampler_info: bool = True,
    log_metadata_info: bool = True,
    log_model_info: bool = True,
    sample_error_file: str | None = "sample-error.txt",
    summary_var_names: list[str] | None = None,
    arviz_summary_kwargs: dict | None = None,
    log_mmm: bool = True,
    log_clv: bool = True,
    disable: bool = False,
    silent: bool = False,
) -> None:
    """Autologging support for PyMC models and PyMC-Marketing models.
    Includes logging of sampler diagnostics, model information, data used in the
    model, and InferenceData objects upon sampling the models.
    For more information about MLflow, see
    https://mlflow.org/docs/latest/python_api/mlflow.html
    Parameters
    ----------
    log_sampler_info : bool, optional
        Whether to log sampler diagnostics. Default is True.
    log_metadata_info : bool, optional
        Whether to log the metadata of inputs used in the model. Default is True.
    log_model_info : bool, optional
        Whether to log model information. Default is True.
    sample_error_file : str, optional
        The name of the file to log the error if an error occurs during sampling. If
        None, the error will not be logged. Default is "sample-error.txt".
    summary_var_names : list[str], optional
        The names of the variables to include in the ArviZ summary. Default is
        all the variables in the InferenceData object.
    arviz_summary_kwargs : dict, optional
        Additional keyword arguments to pass to `az.summary`.
    log_mmm : bool, optional
        Whether to log PyMC-Marketing MMM models. Default is True.
    log_clv : bool, optional
        Whether to log PyMC-Marketing CLV models. Default is True.
    disable : bool, optional
        Whether to disable autologging. Default is False.
    silent : bool, optional
        Whether to suppress all warnings. Default is False.
    Examples
    --------
    Autologging for a PyMC model:
    .. code-block:: python
        import mlflow
        import pymc as pm
        import pymc_marketing.mlflow
        pymc_marketing.mlflow.autolog()
        # Usual PyMC model code
        with pm.Model() as model:
            mu = pm.Normal("mu", mu=0, sigma=1)
            obs = pm.Normal("obs", mu=mu, sigma=1, observed=[1, 2, 3])
        # Incorporate into MLflow workflow
        mlflow.set_experiment("PyMC Experiment")
        with mlflow.start_run():
            idata = pm.sample(model=model)
    Autologging for a PyMC-Marketing MMM:
    .. code-block:: python
        import pandas as pd
        import mlflow
        from pymc_marketing.mmm import (
            GeometricAdstock,
            LogisticSaturation,
            MMM,
        )
        from pymc_marketing.paths import data_dir
        import pymc_marketing.mlflow
        pymc_marketing.mlflow.autolog(log_mmm=True)
        # Usual PyMC-Marketing model code
        file_path = data_dir / "mmm_example.csv"
        data = pd.read_csv(file_path, parse_dates=["date_week"])
        X = data.drop("y", axis=1)
        y = data["y"]
        mmm = MMM(
            adstock=GeometricAdstock(l_max=8),
            saturation=LogisticSaturation(),
            date_column="date_week",
            channel_columns=["x1", "x2"],
            control_columns=[
                "event_1",
                "event_2",
                "t",
            ],
            yearly_seasonality=2,
        )
        # Incorporate into MLflow workflow
        mlflow.set_experiment("MMM Experiment")
        with mlflow.start_run():
            idata = mmm.fit(X, y)
            posterior_preds = mmm.sample_posterior_predictive(X)
            # Additional specific logging
            fig = mmm.plot_components_contributions()
            mlflow.log_figure(fig, "components.png")
    Autologging for a PyMC-Marketing CLV model:
    .. code-block:: python
        import pandas as pd
        import mlflow
        from pymc_marketing.clv import BetaGeoModel
        from pymc_marketing.paths import data_dir
        import pymc_marketing.mlflow
        pymc_marketing.mlflow.autolog(log_clv=True)
        mlflow.set_experiment("CLV Experiment")
        file_path = data_dir / "clv_quickstart.csv"
        data = pd.read_csv(file_path)
        data["customer_id"] = data.index
        model = BetaGeoModel(data=data)
        with mlflow.start_run():
            model.fit()
        with mlflow.start_run():
            model.fit(fit_method="map")
    """
    arviz_summary_kwargs = arviz_summary_kwargs or {}
    def patch_sample(sample: Callable) -> Callable:
        @wraps(sample)
        def new_sample(*args, **kwargs):
            log_versions()
            model = pm.modelcontext(kwargs.get("model"))
            mlflow.log_param("nuts_sampler", kwargs.get("nuts_sampler", "pymc"))
            if log_model_info:
                log_model_derived_info(model)
            idata = sample(*args, **kwargs)
            # Align with the default values in pymc.sample
            tune = kwargs.get("tune", 1000)
            if log_sampler_info:
                log_sample_diagnostics(idata, tune=tune)
                log_arviz_summary(
                    idata,
                    "summary.html",
                    var_names=summary_var_names,
                    **arviz_summary_kwargs,
                )
            if log_metadata_info:
                log_metadata(model=model, idata=idata)
            return idata
        if sample_error_file:
            new_sample = log_error(new_sample, sample_error_file)
        return new_sample
    pm.sample = patch_sample(pm.sample)
    def patch_find_MAP(find_MAP):
        @wraps(find_MAP)
        def new_find_MAP(*args, **kwargs):
            model = pm.modelcontext(kwargs.get("model"))
            if log_model_info:
                log_model_derived_info(model)
            return find_MAP(*args, **kwargs)
        return new_find_MAP
    pm.find_MAP = patch_find_MAP(pm.find_MAP)
    def patch_mmm_fit(fit: Callable) -> Callable:
        @wraps(fit)
        def new_fit(self, *args, **kwargs):
            log_mmm_configuration(self)
            idata = fit(self, *args, **kwargs)
            log_inference_data(idata, save_file="idata.nc")
            return idata
        return new_fit
    if log_mmm:
        MMM.fit = patch_mmm_fit(MMM.fit)
    def patch_clv_fit(fit):
        @wraps(fit)
        def new_fit(self, fit_method: str = "mcmc", **kwargs):
            mlflow.log_param("model_type", self._model_type)
            mlflow.log_param("fit_method", fit_method)
            idata = fit(self, fit_method, **kwargs)
            mlflow.log_params(
                idata.attrs,
            )
            log_inference_data(idata, save_file="idata.nc")
            return idata
        return new_fit
    if log_clv:
        CLVModel.fit = patch_clv_fit(CLVModel.fit)