#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""
Specialized priors that behave like the Prior class.
The Prior class has certain design constraints that prevent it from
covering all cases. So this module contains a collection of
priors that do not inherit from the Prior class but have many
of the same methods.
"""
import warnings
from typing import Any
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from pymc_extras.deserialize import deserialize, register_deserialization
from pymc_extras.prior import Prior, VariableFactory, create_dim_handler, sample_prior
from pytensor.tensor import TensorVariable
class LogNormalPrior:
    r"""Lognormal prior parameterized by positive-scale mean and std.
    A lognormal prior parameterized by mean and standard deviation
    on the positive domain, with optional centered or non-centered
    parameterization.
    This prior differs from the standard ``LogNormal`` distribution,
    which takes log-scale parameters (``mu_log``, ``sigma_log``).
    Instead, it is parameterized directly in terms of the mean and
    standard deviation (``mean``, ``std``) on the positive scale, making it more intuitive
    and suitable for hierarchical modeling.
    To achieve this, the lognormal parameters are computed internally
    from the positive-domain parameters:
    .. math::
        \mu_{\log} &= \ln \left( \frac{\mean^2}{\sqrt{\mean^2 + \std^2}} \right) \\
        \sigma_{\log} &= \sqrt{ \ln \left( 1 + \frac{\std^2}{\mean^2} \right) }
    where :math:`\\mean > 0` and :math:`\\std > 0`.
    The prior is then defined as:
    .. math::
        \\phi &\\sim \text{LogNormal}(\\mu_{\\log}, \\sigma_{\\log})
    This construction ensures that the resulting random variable
    has approximately the intended mean and variance on the positive scale,
    even when :math:`\\mean` and :math:`\\std` are themselves random variables.
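
    As a quick numerical check of the moment-matching formulas (a sketch using
    NumPy only; the chosen values are illustrative):

    .. code-block:: python

        import numpy as np

        mean, std = 2.0, 0.5
        mu_log = np.log(mean**2 / np.sqrt(mean**2 + std**2))
        sigma_log = np.sqrt(np.log(1 + std**2 / mean**2))

        # Lognormal draws built from the matched log-scale parameters
        draws = np.exp(np.random.default_rng(0).normal(mu_log, sigma_log, 100_000))
        draws.mean(), draws.std()  # approximately (2.0, 0.5)
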
    Parameters
    ----------
    mean : Prior, float, int, array-like
        The mean of the distribution on the positive scale.
    std : Prior, float, int, array-like
        The standard deviation of the distribution on the positive scale.
    dims : tuple[str, ...], optional
        The dimensions of the distribution, by default None.
    centered : bool, optional
        Whether to use the centered parameterization, by default True.
    Examples
    --------
    Build a non-centered hierarchical model where information is shared across groups:

    .. code-block:: python

        from pymc_extras.prior import Prior

        from pymc_marketing.special_priors import LogNormalPrior

        prior = LogNormalPrior(
            mean=Prior("Gamma", mu=1.0, sigma=1.0),
            std=Prior("HalfNormal", sigma=1.0),
            dims=("geo",),
            centered=False,
        )
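
    The resulting object can be sampled and serialized like a ``Prior`` (a
    minimal sketch; the ``geo`` coordinate values are illustrative):

    .. code-block:: python

        samples = prior.sample_prior(coords={"geo": ["east", "west"]})
        restored = LogNormalPrior.from_dict(prior.to_dict())
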
    References
    ----------
    - D. Saunders, *A positive constrained non-centered prior that sparks joy*.
    - Wikipedia, *Log-normal distribution — Definitions*.
    """
    def __init__(self, dims: tuple | None = None, centered: bool = True, **parameters):
        # Accept aliases mu->mean and sigma->std for convenience/compatibility
        if "mean" not in parameters and "mu" in parameters:
            parameters["mean"] = parameters.pop("mu")
        if "std" not in parameters and "sigma" in parameters:
            parameters["std"] = parameters.pop("sigma")
        self.parameters = parameters
        self.dims = dims
        self.centered = centered
        self._checks() 
    def _checks(self) -> None:
        self._parameters_are_correct_set()
    def _parameters_are_correct_set(self) -> None:
        # Only allow exactly these keys after alias normalization
        if set(self.parameters.keys()) != {"mean", "std"}:
            raise ValueError("Parameters must be mean and std")
    def _create_parameter(self, param, value, name):
        if not hasattr(value, "create_variable"):
            return value
        child_name = f"{name}_{param}"
        return self.dim_handler(value.create_variable(child_name), value.dims)
    def create_variable(self, name: str) -> TensorVariable:
        """Create a variable from the prior distribution."""
        self.dim_handler = create_dim_handler(self.dims)
        parameters = {
            param: self._create_parameter(param, value, name)
            for param, value in self.parameters.items()
        }
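        # Moment-matching: convert the positive-scale (mean, std) parameters
        # into the log-scale parameters of the underlying Normal.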
        mu_log = pt.log(
            parameters["mean"] ** 2
            / pt.sqrt(parameters["mean"] ** 2 + parameters["std"] ** 2)
        )
        sigma_log = pt.sqrt(
            pt.log(1 + (parameters["std"] ** 2 / parameters["mean"] ** 2))
        )
        if self.centered:
            log_phi = pm.Normal(
                name + "_log", mu=mu_log, sigma=sigma_log, dims=self.dims
            )
        else:
            log_phi_z = pm.Normal(
                name + "_log" + "_offset", mu=0, sigma=1, dims=self.dims
            )
            log_phi = mu_log + log_phi_z * sigma_log
        phi = pm.math.exp(log_phi)
        phi = pm.Deterministic(name, phi, dims=self.dims)
        return phi 
    def to_dict(self):
        """Convert the prior distribution to a dictionary."""
        data = {
            "special_prior": "LogNormalPrior",
        }
        if self.parameters:
            def handle_value(value):
                if isinstance(value, Prior):
                    return value.to_dict()
                if isinstance(value, pt.TensorVariable):
                    value = value.eval()
                if isinstance(value, np.ndarray):
                    return value.tolist()
                if hasattr(value, "to_dict"):
                    return value.to_dict()
                return value
            data["kwargs"] = {
                param: handle_value(value) for param, value in self.parameters.items()
            }
        if not self.centered:
            data["centered"] = False
        if self.dims:
            data["dims"] = self.dims
        return data 
    @classmethod
    def from_dict(cls, data) -> "LogNormalPrior":
        """Create a LogNormalPrior from a dictionary."""
        if not isinstance(data, dict):
            msg = (
                "Must be a dictionary representation of a prior distribution. "
                f"Not of type: {type(data)}"
            )
            raise ValueError(msg)
        kwargs = data.get("kwargs", {})
        def handle_value(value):
            if isinstance(value, dict):
                return deserialize(value)
            if isinstance(value, list):
                return np.array(value)
            return value
        kwargs = {param: handle_value(value) for param, value in kwargs.items()}
        centered = data.get("centered", True)
        dims = data.get("dims")
        if isinstance(dims, list):
            dims = tuple(dims)
        return cls(dims=dims, centered=centered, **kwargs) 
    def sample_prior(
        self,
        coords=None,
        name: str = "variable",
        **sample_prior_predictive_kwargs,
    ) -> xr.Dataset:
        """Sample from the prior distribution."""
        return sample_prior(
            factory=self,
            coords=coords,
            name=name,
            **sample_prior_predictive_kwargs,
        ) 
 
def _is_LogNormalPrior_type(data: dict) -> bool:
    return data.get("special_prior") == "LogNormalPrior"
register_deserialization(
    is_type=_is_LogNormalPrior_type,
    deserialize=LogNormalPrior.from_dict,
)
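

# Example (a sketch): once the hook is registered, ``deserialize`` resolves the
# dictionary form produced by ``LogNormalPrior.to_dict`` back to an instance:
#
#     deserialize({"special_prior": "LogNormalPrior", "kwargs": {"mean": 1.0, "std": 0.5}})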
class MaskedPrior:
    """Create variables from a prior over only the active entries of a boolean mask.
    .. warning::
        This class is experimental and its API may change in future versions.
    Parameters
    ----------
    prior : Prior
        Base prior whose variable is defined over `prior.dims`. Internally, the
        variable is created only for the active entries given by `mask` and
        then expanded back to the full shape with zeros at inactive positions.
    mask : xarray.DataArray
        Boolean array with the same dims and shape as `prior.dims` marking active
        (True) and inactive (False) entries.
    active_dim : str, optional
        Name of the coordinate indexing the active subset. If not provided, a
        name is generated as ``"non_null_dims:<dim1>_<dim2>_..."``. If an existing
        coordinate with the same name has a different length, a suffix with the
        active length is appended.
    Examples
    --------
    Simple 1D masking:

    .. code-block:: python

        import xarray as xr

        import pymc as pm
        from pymc_extras.prior import Prior

        from pymc_marketing.special_priors import MaskedPrior

        coords = {"country": ["Venezuela", "Colombia"]}
        mask = xr.DataArray(
            [True, False],
            dims=["country"],
            coords={"country": coords["country"]},
        )
        intercept = Prior("Normal", mu=0, sigma=10, dims=("country",))
        with pm.Model(coords=coords):
            masked = MaskedPrior(intercept, mask)
            intercept_full = masked.create_variable("intercept")
    Nested parameter priors with dims remapped to the active subset:

    .. code-block:: python

        import xarray as xr

        import pymc as pm
        from pymc_extras.prior import Prior

        from pymc_marketing.special_priors import MaskedPrior

        coords = {"country": ["Venezuela", "Colombia"]}
        mask = xr.DataArray(
            [True, False],
            dims=["country"],
            coords={"country": coords["country"]},
        )
        intercept = Prior(
            "Normal",
            mu=Prior("HalfNormal", sigma=1, dims=("country",)),
            sigma=10,
            dims=("country",),
        )
        with pm.Model(coords=coords):
            masked = MaskedPrior(intercept, mask)
            intercept_full = masked.create_variable("intercept")
    All entries masked (returns deterministic zeros with the original dims):

    .. code-block:: python

        import xarray as xr

        import pymc as pm
        from pymc_extras.prior import Prior

        from pymc_marketing.special_priors import MaskedPrior

        coords = {"country": ["Venezuela", "Colombia"]}
        mask = xr.DataArray(
            [False, False],
            dims=["country"],
            coords={"country": coords["country"]},
        )
        prior = Prior("Normal", mu=0, sigma=10, dims=("country",))
        with pm.Model(coords=coords):
            masked = MaskedPrior(prior, mask)
            zeros = masked.create_variable("intercept")
    Apply to a saturation function's priors:

    .. code-block:: python

        import xarray as xr

        from pymc_extras.prior import Prior

        from pymc_marketing.mmm import LogisticSaturation
        from pymc_marketing.special_priors import MaskedPrior

        coords = {
            "country": ["Colombia", "Venezuela"],
            "channel": ["x1", "x2", "x3", "x4"],
        }
        mask_excluded_x4_colombia = xr.DataArray(
            [[True, False, True, False], [True, True, True, True]],
            dims=["country", "channel"],
            coords=coords,
        )
        saturation = LogisticSaturation(
            priors={
                "lam": MaskedPrior(
                    Prior(
                        "Gamma",
                        mu=2,
                        sigma=0.5,
                        dims=("country", "channel"),
                    ),
                    mask=mask_excluded_x4_colombia,
                ),
                "beta": Prior(
                    "Gamma",
                    mu=3,
                    sigma=0.5,
                    dims=("country", "channel"),
                ),
            }
        )
        prior = saturation.sample_prior(coords=coords, random_seed=10)
        curve = saturation.sample_curve(prior)
        saturation.plot_curve(
            curve,
            subplot_kwargs={
                "ncols": 4,
                "figsize": (12, 18),
            },
        )
    Masked likelihood over an arbitrary subset of entries (2D example over ``(date, country)``):

    .. code-block:: python

        import numpy as np
        import xarray as xr

        import pymc as pm
        from pymc_extras.prior import Prior

        from pymc_marketing.special_priors import MaskedPrior

        coords = {
            "date": np.array(["2021-01-01", "2021-01-02"], dtype="datetime64[D]"),
            "country": ["Venezuela", "Colombia"],
        }
        mask = xr.DataArray(
            [[True, False], [True, False]],
            dims=["date", "country"],
            coords={"date": coords["date"], "country": coords["country"]},
        )
        intercept = Prior("Normal", mu=0, sigma=10, dims=("country",))
        likelihood = Prior(
            "Normal", sigma=Prior("HalfNormal", sigma=1), dims=("date", "country")
        )
        observed = np.random.normal(0, 1, size=(2, 2))
        with pm.Model(coords=coords):
            mu = intercept.create_variable("intercept")
            masked = MaskedPrior(likelihood, mask)
            y = masked.create_likelihood_variable("y", mu=mu, observed=observed)
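
    Serialization round-trips through ``to_dict`` and ``from_dict`` (a minimal
    sketch; note that only the mask values and dim names are stored, not the
    coordinate values):

    .. code-block:: python

        restored = MaskedPrior.from_dict(masked.to_dict())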
    """
    def __init__(
        self, prior: Prior, mask: xr.DataArray, active_dim: str | None = None
    ) -> None:
        self.prior = prior
        self.mask = mask
        self.dims = prior.dims
        self.active_dim = active_dim or f"non_null_dims:{'_'.join(self.dims)}"
        self._validate_mask()
        warnings.warn(
            "This class is experimental and its API may change in future versions.",
            stacklevel=2,
        ) 
    def _validate_mask(self) -> None:
        if tuple(self.mask.dims) != tuple(self.dims):
            raise ValueError("mask dims must match prior.dims order")
    def _remap_dims(self, factory: VariableFactory) -> VariableFactory:
        # Depth-first remap of any nested VariableFactory with dims == parent dims
        # This keeps internal subset checks (_param_dims_work) satisfied.
        if hasattr(factory, "parameters"):
            # Recurse on child parameters first
            for key, value in list(factory.parameters.items()):
                if hasattr(value, "create_variable") and hasattr(value, "dims"):
                    factory.parameters[key] = self._remap_dims(value)  # type: ignore[arg-type]
        # Now remap this object's dims if they exactly match the masked dims
        if hasattr(factory, "dims"):
            dims = factory.dims
            if isinstance(dims, str):
                dims = (dims,)
            if tuple(dims) == tuple(self.dims):
                factory.dims = (self.active_dim,)
        return factory
    def create_variable(self, name: str) -> TensorVariable:
        """Create a deterministic variable with full dims using the active subset.

        Creates an underlying variable over the active entries only and expands
        it back to the full masked shape, filling inactive entries with zeros.

        Parameters
        ----------
        name : str
            Base name for the created variables.

        Returns
        -------
        pt.TensorVariable
            Deterministic variable with the original dims, zeros on inactive entries.
        """
        model = pm.modelcontext(None)
        flat_mask = self.mask.values.ravel().astype(bool)
        n_active = int(flat_mask.sum())
        if n_active == 0:
            return pm.Deterministic(name, pt.zeros(self.mask.shape), dims=self.dims)
        # Ensure the coord exists and has the right length
        if (
            self.active_dim in model.coords
            and len(model.coords[self.active_dim]) != n_active
        ):
            self.active_dim = f"{self.active_dim}__{n_active}"
        model.add_coords({self.active_dim: np.arange(n_active)})
        # Make a deep copy and remap dims depth-first before creating the RV
        reduced = self._remap_dims(self.prior.deepcopy())
        active_rv = reduced.create_variable(f"{name}_active")  # shape: (active_dim,)
        flat_full = pt.zeros((self.mask.size,), dtype=active_rv.dtype)
        full = flat_full[flat_mask].set(active_rv).reshape(self.mask.shape)
        return pm.Deterministic(name, full, dims=self.dims) 
    def to_dict(self) -> dict[str, Any]:
        """Serialize MaskedPrior to a JSON-serializable dictionary.
        Returns
        -------
        dict
            Dictionary containing the prior, mask, and active_dim.
        """
        # Store mask as a plain nested list of bools to avoid datetime coords serialization
        mask_list = (
            self.mask.values.astype(bool).tolist()
            if hasattr(self.mask, "values")
            else np.asarray(self.mask, dtype=bool).tolist()
        )
        return {
            "class": "MaskedPrior",
            "data": {
                "prior": self.prior.to_dict()
                if hasattr(self.prior, "to_dict")
                else None,
                "mask": mask_list,
                "mask_dims": list(self.dims),
                "active_dim": self.active_dim,
            },
        } 
    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MaskedPrior":
        """Deserialize MaskedPrior from dictionary created by ``to_dict``.
        Parameters
        ----------
        data : dict
            Dictionary produced by :meth:`to_dict`.
        Returns
        -------
        MaskedPrior
            Reconstructed instance.
        """
        payload = data["data"] if "data" in data else data
        prior = (
            deserialize(payload["prior"])
            if isinstance(payload.get("prior"), dict)
            else payload.get("prior")
        )
        mask_vals = payload.get("mask")
        # Fallback to provided dims or infer from prior if available
        mask_dims = payload.get("mask_dims") or (getattr(prior, "dims", None) or ())
        mask_da = xr.DataArray(np.asarray(mask_vals, dtype=bool), dims=tuple(mask_dims))
        active_dim = payload.get("active_dim")
        return cls(prior=prior, mask=mask_da, active_dim=active_dim) 
    def create_likelihood_variable(
        self, name: str, *, mu: pt.TensorLike, observed: pt.TensorLike
    ) -> TensorVariable:
        """Create an observed variable over the active subset and expand to full dims.
        Parameters
        ----------
        name : str
            Base name for the created variables.
        mu : pt.TensorLike
            Mean/location parameter broadcastable to the masked shape.
        observed : pt.TensorLike
            Observations broadcastable to the masked shape.
        Returns
        -------
        pt.TensorVariable
            Deterministic variable over the full dims with observed RV on active entries.
        """
        model = pm.modelcontext(None)
        flat_mask = self.mask.values.ravel().astype(bool)
        n_active = int(flat_mask.sum())
        if n_active == 0:
            return pm.Deterministic(name, pt.zeros(self.mask.shape), dims=self.dims)
        # Ensure the coord exists and has the right length
        if (
            self.active_dim in model.coords
            and len(model.coords[self.active_dim]) != n_active
        ):
            self.active_dim = f"{self.active_dim}__{n_active}"
        model.add_coords({self.active_dim: np.arange(n_active)})
        # Remap dims on a deep copy so nested parameter priors match the active subset
        reduced = self._remap_dims(self.prior.deepcopy())
        # Broadcast mu/observed to full mask shape via arithmetic broadcasting, then select active entries
        mu_tensor = pt.as_tensor_variable(mu)
        mu_full = mu_tensor + pt.zeros(self.mask.shape, dtype=mu_tensor.dtype)
        mu_active = mu_full.reshape((self.mask.size,))[flat_mask]
        obs = observed.values if hasattr(observed, "values") else observed
        obs_tensor = pt.as_tensor_variable(obs)
        obs_full = obs_tensor + pt.zeros(self.mask.shape, dtype=obs_tensor.dtype)
        obs_active = obs_full.reshape((self.mask.size,))[flat_mask]
        # Create the masked observed RV over the active subset
        active_name = f"{name}_active"
        active_rv = reduced.create_likelihood_variable(
            active_name, mu=mu_active, observed=obs_active
        )
        # Expand back to full shape for user-friendly access
        flat_full = pt.zeros((self.mask.size,), dtype=active_rv.dtype)
        full = flat_full[flat_mask].set(active_rv).reshape(self.mask.shape)
        return pm.Deterministic(name, full, dims=self.dims) 
 
def _is_masked_prior_type(data: dict) -> bool:
    return data.keys() == {"class", "data"} and data.get("class") == "MaskedPrior"
register_deserialization(
    is_type=_is_masked_prior_type, deserialize=MaskedPrior.from_dict
)
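

# Example (a sketch): with the hook registered, ``deserialize`` dispatches on
# the ``{"class": "MaskedPrior", "data": {...}}`` shape produced by
# ``MaskedPrior.to_dict``, where ``masked_prior`` is any existing instance:
#
#     restored = deserialize(masked_prior.to_dict())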