#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Class that represents a prior distribution.
The `Prior` class is a wrapper around PyMC distributions that allows the user
to create priors outside of a PyMC model.
.. note::
    This module is deprecated; its contents have moved to `pymc_extras.prior`.
This is the alternative to using the dictionaries in PyMC-Marketing models.
Examples
--------
Create a normal prior.
.. code-block:: python
    from pymc_extras.prior import Prior
    normal = Prior("Normal")
Create a hierarchical normal prior by using distributions for the parameters
and specifying the dims.
.. code-block:: python
    hierarchical_normal = Prior(
        "Normal",
        mu=Prior("Normal"),
        sigma=Prior("HalfNormal"),
        dims="channel",
    )
Create a non-centered hierarchical normal prior with the `centered` parameter.
.. code-block:: python
    non_centered_hierarchical_normal = Prior(
        "Normal",
        mu=Prior("Normal"),
        sigma=Prior("HalfNormal"),
        dims="channel",
        # Only change needed to make it non-centered
        centered=False,
    )
Create a hierarchical beta prior by using Beta distribution, distributions for
the parameters, and specifying the dims.
.. code-block:: python
    hierarchical_beta = Prior(
        "Beta",
        alpha=Prior("HalfNormal"),
        beta=Prior("HalfNormal"),
        dims="channel",
    )
Create a transformed hierarchical normal prior by using the `transform`
parameter. Here the "sigmoid" transformation comes from `pm.math`.
.. code-block:: python
    transformed_hierarchical_normal = Prior(
        "Normal",
        mu=Prior("Normal"),
        sigma=Prior("HalfNormal"),
        transform="sigmoid",
        dims="channel",
    )
Create a prior with a custom transform function by registering it with
`register_tensor_transform`.
.. code-block:: python
    from pymc_extras.prior import register_tensor_transform
    def custom_transform(x):
        return x**2
    register_tensor_transform("square", custom_transform)
    custom_distribution = Prior("Normal", transform="square")
"""
from __future__ import annotations
import copy
import functools
import warnings
from typing import Any
from pymc_extras import prior
from pymc_extras.deserialize import deserialize, register_deserialization
[docs]
def is_alternative_prior(data: Any) -> bool:
    """Tell whether ``data`` looks like a dictionary describing a Prior.

    Parameters
    ----------
    data : Any
        Arbitrary object to inspect.

    Returns
    -------
    bool
        ``True`` only when ``data`` is a ``dict`` that has a
        ``"distribution"`` key; ``False`` for everything else.
    """
    # Anything that is not a dict can never be a serialized prior.
    if not isinstance(data, dict):
        return False
    return "distribution" in data
[docs]
def deserialize_alternative_prior(data: dict[str, Any]) -> prior.Prior:
    """Build a ``Prior`` from its dictionary form, resolving nested entries.

    The keys ``distribution``, ``dims``, ``centered`` and ``transform`` are
    pulled out of the dictionary; every remaining key is treated as a
    distribution parameter. Any parameter whose value is itself a dictionary
    is passed through ``deserialize`` before the ``Prior`` is constructed.
    The input dictionary is deep-copied first, so it is never modified.

    Parameters
    ----------
    data : dict[str, Any]
        Dictionary description of the prior. Must contain ``"distribution"``.

    Returns
    -------
    prior.Prior
        The prior built from ``data``.

    Raises
    ------
    KeyError
        If ``data`` has no ``"distribution"`` key.

    Examples
    --------
    This handles cases like:

    .. code-block:: yaml

        distribution: Gamma
        alpha: 1
        beta:
            distribution: HalfNormal
            sigma: 1
            dims: channel
        dims: [brand, channel]

    """
    # Work on a private copy so the pops below never touch the caller's dict.
    remaining = copy.deepcopy(data)
    dist_name = remaining.pop("distribution")
    options = {
        "transform": remaining.pop("transform", None),
        "centered": remaining.pop("centered", True),
        "dims": remaining.pop("dims", None),
    }

    # Whatever is left are distribution parameters; dict values are
    # deserialized recursively, everything else is passed through as-is.
    resolved = {}
    for param_name, spec in remaining.items():
        resolved[param_name] = deserialize(spec) if isinstance(spec, dict) else spec

    return prior.Prior(dist_name, **options, **resolved)
# Register the alternative prior deserializer for more complex nested cases.
# The call runs at import time, so importing this module is what installs the
# (`is_alternative_prior`, `deserialize_alternative_prior`) pair.
# NOTE(review): presumably `deserialize` then routes any dict with a
# "distribution" key to this deserializer -- confirm against
# `pymc_extras.deserialize`.
register_deserialization(is_alternative_prior, deserialize_alternative_prior)
[docs]
def warn_class_deprecation(func):
    """Warn about the deprecation of this module.

    Decorator for a method (in this module, ``__init__``) of a
    backwards-compatible class. Every call emits a ``DeprecationWarning``
    pointing the user at ``pymc_extras.prior`` and then delegates to the
    wrapped method.

    The class named in the warning is the class that *defines* the wrapped
    method, not the runtime class of ``self``. This keeps the import hint
    valid when a user subclasses one of the deprecated classes: the subclass
    name does not exist in ``pymc_extras.prior``, the base-class name does.

    Parameters
    ----------
    func : Callable
        Method to wrap; it is invoked as ``func(self, *args, **kwargs)``.

    Returns
    -------
    Callable
        Wrapper with the metadata of ``func`` (via ``functools.wraps``) that
        warns and then returns whatever ``func`` returns.
    """
    # A method's ``__qualname__`` looks like "Prior.__init__" (or
    # "f.<locals>.Prior.__init__"), so its second-to-last dotted component is
    # the name of the defining class.
    qualname = getattr(func, "__qualname__", "")
    owner = qualname.rpartition(".")[0].rpartition(".")[2]
    if owner == "<locals>":
        # ``func`` is a nested plain function, not defined in a class body.
        owner = ""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Fall back to the runtime class when the defining class is unknown.
        name = owner or self.__class__.__name__
        warnings.warn(
            f"The {name} class has moved to pymc_extras.prior module and will be removed in a future release. "
            f"Import it from `from pymc_extras.prior import {name}`. ",
            DeprecationWarning,
            # stacklevel=2 attributes the warning to the caller of the wrapper.
            stacklevel=2,
        )
        return func(self, *args, **kwargs)

    return wrapper
[docs]
def warn_function_deprecation(func):
    """Wrap ``func`` so that each call announces its deprecation.

    Parameters
    ----------
    func : Callable
        Function to wrap; it is invoked as ``func(*args, **kwargs)``.

    Returns
    -------
    Callable
        Wrapper carrying the metadata of ``func`` (via ``functools.wraps``).
        Each call emits a ``DeprecationWarning`` naming ``func.__name__`` and
        then returns the result of ``func``.
    """

    def wrapper(*args, **kwargs):
        target = func.__name__
        notice = (
            f"The {target} function has moved to pymc_extras.prior module and will be removed in a future release. "
            f"Import it from `from pymc_extras.prior import {target}`."
        )
        # stacklevel=2 attributes the warning to the caller of the wrapper.
        warnings.warn(notice, DeprecationWarning, stacklevel=2)
        return func(*args, **kwargs)

    return functools.wraps(func)(wrapper)
[docs]
class Prior(prior.Prior):
    """Backwards-compatible wrapper for the Prior class.

    Subclass of ``pymc_extras.prior.Prior`` whose only addition is the
    ``DeprecationWarning`` emitted by ``warn_class_deprecation`` each time an
    instance is initialized.
    """

    @warn_class_deprecation
    def __init__(self, *args, **kwargs):
        """Initialize the Prior class with the given arguments.

        All positional and keyword arguments are forwarded unchanged to the
        parent ``__init__``.
        """
        super().__init__(*args, **kwargs)
 
[docs]
class Censored(prior.Censored):
    """Backwards-compatible wrapper for the Censored class.

    Subclass of ``pymc_extras.prior.Censored`` whose only addition is the
    ``DeprecationWarning`` emitted by ``warn_class_deprecation`` each time an
    instance is initialized.
    """

    @warn_class_deprecation
    def __init__(self, *args, **kwargs):
        """Initialize the Censored class with the given arguments.

        All positional and keyword arguments are forwarded unchanged to the
        parent ``__init__``.
        """
        super().__init__(*args, **kwargs)
 
[docs]
class Scaled(prior.Scaled):
    """Backwards-compatible wrapper for the Scaled class.

    Subclass of ``pymc_extras.prior.Scaled`` whose only addition is the
    ``DeprecationWarning`` emitted by ``warn_class_deprecation`` each time an
    instance is initialized.
    """

    @warn_class_deprecation
    def __init__(self, *args, **kwargs):
        """Initialize the Scaled class with the given arguments.

        All positional and keyword arguments are forwarded unchanged to the
        parent ``__init__``.
        """
        super().__init__(*args, **kwargs)
 
# Deprecated aliases of the `pymc_extras.prior` helpers. Each name is the
# corresponding function wrapped by `warn_function_deprecation`, so calling it
# emits a DeprecationWarning and then delegates to the `pymc_extras.prior`
# implementation with the same arguments.
sample_prior = warn_function_deprecation(prior.sample_prior)
create_dim_handler = warn_function_deprecation(prior.create_dim_handler)
handle_dims = warn_function_deprecation(prior.handle_dims)
register_tensor_transform = warn_function_deprecation(prior.register_tensor_transform)