#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Saturation transformations for the MMM model.
Each of these transformations is a subclass of
:class:`pymc_marketing.mmm.components.saturation.SaturationTransformation` and defines a function
that takes media and return the saturated media. The parameters of the function
are the parameters of the saturation transformation.
Examples
--------
Create a new saturation transformation:
.. code-block:: python
    from pymc_marketing.mmm import SaturationTransformation
    from pymc_extras.prior import Prior
    class InfiniteReturns(SaturationTransformation):
        lookup_name: str = "infinite_returns"
        def function(self, x, b):
            return b * x
        default_priors = {"b": Prior("HalfNormal", sigma=1)}
Plot the default priors for a saturation transformation:
.. code-block:: python
    from pymc_marketing.mmm import HillSaturation
    import matplotlib.pyplot as plt
    saturation = HillSaturation()
    prior = saturation.sample_prior()
    curve = saturation.sample_curve(prior)
    saturation.plot_curve(curve)
    plt.show()
Define a hierarchical saturation function with only hierarchical parameters
for saturation parameter of logistic saturation.
.. code-block:: python
    from pymc_extras.prior import Prior
    from pymc_marketing.mmm import LogisticSaturation
    hierarchical_lam = Prior(
        "Gamma",
        alpha=Prior("HalfNormal"),
        beta=Prior("HalfNormal"),
        dims="channel",
    )
    priors = {
        "lam": hierarchical_lam,
        "beta": Prior("HalfNormal", dims="channel"),
    }
    saturation = LogisticSaturation(priors=priors)
"""
from __future__ import annotations
import numpy as np
import pytensor.tensor as pt
import xarray as xr
from pydantic import Field, InstanceOf, validate_call
from pymc_extras.deserialize import deserialize, register_deserialization
from pymc_extras.prior import Prior
from pymc_marketing.mmm.components.base import (
    Transformation,
    create_registration_meta,
)
from pymc_marketing.mmm.transformers import (
    hill_function,
    hill_saturation_sigmoid,
    inverse_scaled_logistic_saturation,
    logistic_saturation,
    michaelis_menten,
    root_saturation,
    tanh_saturation,
    tanh_saturation_baselined,
)
SATURATION_TRANSFORMATIONS: dict[str, type[SaturationTransformation]] = {}
SaturationRegistrationMeta = create_registration_meta(SATURATION_TRANSFORMATIONS)
[docs]
class LogisticSaturation(SaturationTransformation):
    """Wrapper around logistic saturation function.
    For more information, see :func:`pymc_marketing.mmm.transformers.logistic_saturation`.
    .. plot::
        :context: close-figs
        import matplotlib.pyplot as plt
        import numpy as np
        from pymc_marketing.mmm import LogisticSaturation
        rng = np.random.default_rng(0)
        adstock = LogisticSaturation()
        prior = adstock.sample_prior(random_seed=rng)
        curve = adstock.sample_curve(prior)
        adstock.plot_curve(curve, random_seed=rng)
        plt.show()
    """
    lookup_name = "logistic"
[docs]
    def function(self, x, lam, beta):
        """Logistic saturation function."""
        return beta * logistic_saturation(x, lam) 
    default_priors = {
        "lam": Prior("Gamma", alpha=3, beta=1),
        "beta": Prior("HalfNormal", sigma=2),
    } 
[docs]
class InverseScaledLogisticSaturation(SaturationTransformation):
    """Wrapper around inverse scaled logistic saturation function.
    For more information, see :func:`pymc_marketing.mmm.transformers.inverse_scaled_logistic_saturation`.
    .. plot::
        :context: close-figs
        import matplotlib.pyplot as plt
        import numpy as np
        from pymc_marketing.mmm import InverseScaledLogisticSaturation
        rng = np.random.default_rng(0)
        adstock = InverseScaledLogisticSaturation()
        prior = adstock.sample_prior(random_seed=rng)
        curve = adstock.sample_curve(prior)
        adstock.plot_curve(curve, random_seed=rng)
        plt.show()
    """
    lookup_name = "inverse_scaled_logistic"
[docs]
    def function(self, x, lam, beta):
        """Inverse scaled logistic saturation function."""
        return beta * inverse_scaled_logistic_saturation(x, lam) 
    default_priors = {
        "lam": Prior("Gamma", alpha=0.5, beta=1),
        "beta": Prior("HalfNormal", sigma=2),
    } 
[docs]
class TanhSaturation(SaturationTransformation):
    """Wrapper around tanh saturation function.
    For more information, see :func:`pymc_marketing.mmm.transformers.tanh_saturation`.
    .. plot::
        :context: close-figs
        import matplotlib.pyplot as plt
        import numpy as np
        from pymc_marketing.mmm import TanhSaturation
        rng = np.random.default_rng(0)
        adstock = TanhSaturation()
        prior = adstock.sample_prior(random_seed=rng)
        curve = adstock.sample_curve(prior)
        adstock.plot_curve(curve, random_seed=rng)
        plt.show()
    """
    lookup_name = "tanh"
[docs]
    def function(self, x, b, c):
        """Tanh saturation function."""
        return tanh_saturation(x, b, c) 
    default_priors = {
        "b": Prior("HalfNormal", sigma=1),
        "c": Prior("HalfNormal", sigma=1),
    } 
[docs]
class TanhSaturationBaselined(SaturationTransformation):
    """Wrapper around tanh saturation function.
    For more information, see :func:`pymc_marketing.mmm.transformers.tanh_saturation_baselined`.
    .. plot::
        :context: close-figs
        import matplotlib.pyplot as plt
        import numpy as np
        from pymc_marketing.mmm import TanhSaturationBaselined
        rng = np.random.default_rng(0)
        adstock = TanhSaturationBaselined()
        prior = adstock.sample_prior(random_seed=rng)
        curve = adstock.sample_curve(prior)
        adstock.plot_curve(curve, random_seed=rng)
        plt.show()
    """
    lookup_name = "tanh_baselined"
[docs]
    def function(self, x, x0, gain, r, beta):
        """Tanh saturation function."""
        return beta * tanh_saturation_baselined(x, x0, gain, r) 
    default_priors = {
        "x0": Prior("HalfNormal", sigma=1),
        "gain": Prior("HalfNormal", sigma=1),
        "r": Prior("HalfNormal", sigma=1),
        "beta": Prior("HalfNormal", sigma=1),
    } 
[docs]
class MichaelisMentenSaturation(SaturationTransformation):
    """Wrapper around Michaelis-Menten saturation function.
    For more information, see :func:`pymc_marketing.mmm.transformers.michaelis_menten`.
    .. plot::
        :context: close-figs
        import matplotlib.pyplot as plt
        import numpy as np
        from pymc_marketing.mmm import MichaelisMentenSaturation
        rng = np.random.default_rng(0)
        adstock = MichaelisMentenSaturation()
        prior = adstock.sample_prior(random_seed=rng)
        curve = adstock.sample_curve(prior)
        adstock.plot_curve(curve, random_seed=rng)
        plt.show()
    """
    lookup_name = "michaelis_menten"
[docs]
    def function(self, x, alpha, lam):
        """Michaelis-Menten saturation function."""
        return pt.as_tensor_variable(michaelis_menten(x, alpha, lam)) 
    default_priors = {
        "alpha": Prior("Gamma", mu=2, sigma=1),
        "lam": Prior("HalfNormal", sigma=1),
    } 
[docs]
class HillSaturation(SaturationTransformation):
    """Wrapper around Hill saturation function.
    For more information, see :func:`pymc_marketing.mmm.transformers.hill_function`.
    .. plot::
        :context: close-figs
        import matplotlib.pyplot as plt
        import numpy as np
        from pymc_marketing.mmm import HillSaturation
        rng = np.random.default_rng(0)
        adstock = HillSaturation()
        prior = adstock.sample_prior(random_seed=rng)
        curve = adstock.sample_curve(prior)
        adstock.plot_curve(curve, random_seed=rng)
        plt.show()
    """
    lookup_name = "hill"
[docs]
    def function(self, x, slope, kappa, beta):
        """Hill saturation function."""
        return beta * hill_function(x, slope, kappa) 
    default_priors = {
        "slope": Prior("HalfNormal", sigma=1.5),
        "kappa": Prior("HalfNormal", sigma=1.5),
        "beta": Prior("HalfNormal", sigma=1.5),
    } 
[docs]
class HillSaturationSigmoid(SaturationTransformation):
    """Wrapper around Hill saturation sigmoid function.
    For more information, see :func:`pymc_marketing.mmm.transformers.hill_saturation_sigmoid`.
    .. plot::
        :context: close-figs
        import matplotlib.pyplot as plt
        import numpy as np
        from pymc_marketing.mmm import HillSaturationSigmoid
        rng = np.random.default_rng(0)
        adstock = HillSaturationSigmoid()
        prior = adstock.sample_prior(random_seed=rng)
        curve = adstock.sample_curve(prior)
        adstock.plot_curve(curve, random_seed=rng)
        plt.show()
    """
    lookup_name = "hill_sigmoid"
    function = hill_saturation_sigmoid
    default_priors = {
        "sigma": Prior("HalfNormal", sigma=1.5),
        "beta": Prior("HalfNormal", sigma=1.5),
        "lam": Prior("HalfNormal", sigma=1.5),
    } 
[docs]
class RootSaturation(SaturationTransformation):
    """Wrapper around Root saturation function.
    For more information, see :func:`pymc_marketing.mmm.transformers.root_saturation`.
    .. plot::
        :context: close-figs
        import matplotlib.pyplot as plt
        import numpy as np
        from pymc_marketing.mmm import RootSaturation
        rng = np.random.default_rng(0)
        saturation = RootSaturation()
        prior = saturation.sample_prior(random_seed=rng)
        curve = saturation.sample_curve(prior)
        saturation.plot_curve(curve, random_seed=rng)
        plt.show()
    """
    lookup_name = "root"
[docs]
    def function(self, x, alpha, beta):
        """Root saturation function."""
        return beta * root_saturation(x, alpha) 
    default_priors = {
        "alpha": Prior("Beta", alpha=1, beta=2),
        "beta": Prior("Gamma", mu=1, sigma=1),
    } 
[docs]
class NoSaturation(SaturationTransformation):
    """Wrapper around linear saturation function.
    .. plot::
        :context: close-figs
        import matplotlib.pyplot as plt
        import numpy as np
        from pymc_marketing.mmm import NoSaturation
        rng = np.random.default_rng(0)
        saturation = NoSaturation()
        prior = saturation.sample_prior(random_seed=rng)
        curve = saturation.sample_curve(prior)
        saturation.plot_curve(curve, random_seed=rng)
        plt.show()
    """
    lookup_name = "no_saturation"
[docs]
    def function(self, x, beta):
        """Linear saturation function."""
        return pt.as_tensor_variable(beta * x) 
    default_priors = {"beta": Prior("HalfNormal", sigma=1)} 
[docs]
def saturation_from_dict(data: dict) -> SaturationTransformation:
    """Get a saturation function from a dictionary."""
    data = data.copy()
    cls = SATURATION_TRANSFORMATIONS[data.pop("lookup_name")]
    if "priors" in data:
        data["priors"] = {
            key: deserialize(value) for key, value in data["priors"].items()
        }
    return cls(**data) 
def _is_saturation(data):
    return "lookup_name" in data and data["lookup_name"] in SATURATION_TRANSFORMATIONS
register_deserialization(_is_saturation, saturation_from_dict)