#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Fourier seasonality transformations.
This modules provides Fourier seasonality transformations for use in
Marketing Mix Models. The Fourier seasonality is a set of sine and cosine
functions that can be used to model periodic patterns in the data.
There are two types of Fourier seasonality transformations available:
- Yearly Fourier: A yearly seasonality with a period of 365.25 days
- Monthly Fourier: A monthly seasonality with a period of 365.25 / 12 days
- Weekly Fourier: A weekly seasonality with a period of 7 days
.. plot::
    :context: close-figs
    import matplotlib.pyplot as plt
    import numpy as np
    import arviz as az
    from pymc_marketing.mmm import YearlyFourier
    from pymc_extras.prior import Prior
    plt.style.use('arviz-darkgrid')
    prior = Prior(
        "Normal",
        mu=[0, 0, -1, 0],
        sigma=Prior("Gamma", mu=0.10, sigma=0.1, dims="fourier"),
        dims=("hierarchy", "fourier"),
    )
    yearly = YearlyFourier(n_order=2, prior=prior)
    coords = {"hierarchy": ["A", "B"]}
    prior = yearly.sample_prior(coords=coords)
    curve = yearly.sample_curve(prior)
    fig, _ = yearly.plot_curve(curve, subplot_kwargs={"ncols": 1})
    fig.suptitle("Yearly Fourier Seasonality")
    plt.show()
Examples
--------
Use yearly fourier seasonality for custom Marketing Mix Model.
.. code-block:: python
    import pandas as pd
    import pymc as pm
    from pymc_marketing.mmm import YearlyFourier
    yearly = YearlyFourier(n_order=3)
    dates = pd.date_range("2023-01-01", periods=52, freq="W-MON")
    dayofyear = dates.dayofyear.to_numpy()
    with pm.Model() as model:
        fourier_trend = yearly.apply(dayofyear)
Plot the prior fourier seasonality trend.
.. code-block:: python
    import matplotlib.pyplot as plt
    prior = yearly.sample_prior()
    curve = yearly.sample_curve(prior)
    yearly.plot_curve(curve)
    plt.show()
Change the prior distribution of the fourier seasonality.
.. code-block:: python
    from pymc_marketing.mmm import YearlyFourier
    from pymc_extras.prior import Prior
    prior = Prior("Normal", mu=0, sigma=0.10)
    yearly = YearlyFourier(n_order=6, prior=prior)
Even make it hierarchical...
.. code-block:: python
    from pymc_marketing.mmm import YearlyFourier
    from pymc_extras.prior import Prior
    # "fourier" is the default prefix!
    prior = Prior(
        "Laplace",
        mu=Prior("Normal", dims="fourier"),
        b=Prior("HalfNormal", sigma=0.1, dims="fourier"),
        dims=("fourier", "hierarchy"),
    )
    yearly = YearlyFourier(n_order=3, prior=prior)
All the plotting will still work! Just pass any coords.
.. code-block:: python
    import matplotlib.pyplot as plt
    coords = {"hierarchy": ["A", "B", "C"]}
    prior = yearly.sample_prior(coords=coords)
    curve = yearly.sample_curve(prior)
    yearly.plot_curve(curve)
    plt.show()
Out of sample predictions with fourier seasonality by changing the day of year
used in the model.
.. code-block:: python
    import pandas as pd
    import pymc as pm
    from pymc_marketing.mmm import YearlyFourier
    periods = 52 * 3
    dates = pd.date_range("2022-01-01", periods=periods, freq="W-MON")
    training_dates = dates[: 52 * 2]
    testing_dates = dates[52 * 2 :]
    yearly = YearlyFourier(n_order=3)
    coords = {
        "date": training_dates,
    }
    with pm.Model(coords=coords) as model:
        dayofyear = pm.Data(
            "dayofyear",
            training_dates.dayofyear.to_numpy(),
            dims="date",
        )
        trend = pm.Deterministic(
            "trend",
            yearly.apply(dayofyear),
            dims="date",
        )
        idata = pm.sample_prior_predictive().prior
    with model:
        pm.set_data(
            {"dayofyear": testing_dates.dayofyear.to_numpy()},
            coords={"date": testing_dates},
        )
        out_of_sample = pm.sample_posterior_predictive(
            idata,
            var_names=["trend"],
        ).posterior_predictive["trend"]
Use yearly and monthly fourier seasonality together.
By default, the prefix of the fourier seasonality is set to "fourier". However,
the prefix can be changed upon initialization in order to avoid variable name
conflicts.
.. code-block:: python
    import pandas as pd
    import pymc as pm
    from pymc_marketing.mmm import (
        MonthlyFourier,
        YearlyFourier,
    )
    yearly = YearlyFourier(n_order=6, prefix="yearly")
    monthly = MonthlyFourier(n_order=3, prefix="monthly")
    dates = pd.date_range("2023-01-01", periods=52, freq="W-MON")
    dayofyear = dates.dayofyear.to_numpy()
    coords = {
        "date": dates,
    }
    with pm.Model(coords=coords) as model:
        yearly_trend = yearly.apply(dayofyear)
        monthly_trend = monthly.apply(dayofyear)
        trend = pm.Deterministic(
            "trend",
            yearly_trend + monthly_trend,
            dims="date",
        )
    with model:
        prior_samples = pm.sample_prior_predictive().prior
"""
import datetime
from abc import abstractmethod
from collections.abc import Callable, Iterable
from typing import Any, Self
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    InstanceOf,
    field_serializer,
    model_validator,
)
from pymc_extras.deserialize import deserialize, register_deserialization
from pymc_extras.prior import Prior, VariableFactory, create_dim_handler
from pymc_marketing.constants import DAYS_IN_MONTH, DAYS_IN_WEEK, DAYS_IN_YEAR
from pymc_marketing.plot import SelToString, plot_curve, plot_hdi, plot_samples
X_NAME: str = "day"
NON_GRID_NAMES: frozenset[str] = frozenset({X_NAME})
[docs]
def generate_fourier_modes(
    periods: pt.TensorLike,
    n_order: int,
) -> pt.TensorVariable:
    """Create fourier modes for a given period.
    Parameters
    ----------
    periods : pt.TensorLike
        Periods to generate fourier modes for.
    n_order : int
        Number of fourier modes to generate.
    Returns
    -------
    pt.TensorVariable
        Fourier modes.
    """
    multiples = pt.arange(1, n_order + 1)
    x = 2 * pt.pi * periods
    values = x[:, None] * multiples
    return pt.concatenate(
        [
            pt.sin(values),
            pt.cos(values),
        ],
        axis=1,
    ) 
[docs]
class FourierBase(BaseModel):
    """Base class for Fourier seasonality transformations.
    Parameters
    ----------
    n_order : int
        Number of fourier modes to use.
    days_in_period : float
        Number of days in a period.
    prefix : str, optional
        Alternative prefix for the fourier seasonality, by default None or
        "fourier"
    prior : Prior | VariableFactory, optional
        Prior distribution or VariableFactory for the fourier seasonality beta parameters, by
        default `Prior("Laplace", mu=0, b=1)`
    variable_name : str, optional
        Name of the variable that multiplies the fourier modes. By default None,
        in which case it is set to the `{prefix}_beta`.
    """
    n_order: int = Field(..., gt=0)
    days_in_period: float = Field(..., gt=0)
    prefix: str = Field("fourier")
    prior: InstanceOf[Prior] | InstanceOf[VariableFactory] = Field(
        Prior("Laplace", mu=0, b=1)
    )
    variable_name: str | None = Field(None)
    model_config = ConfigDict(extra="forbid")
[docs]
    def model_post_init(self, __context: Any) -> None:
        """Model post initialization for a Pydantic model."""
        if self.variable_name is None:
            self.variable_name = f"{self.prefix}_beta"
        if not self.prior.dims and isinstance(self.prior, Prior):
            self.prior = self.prior.deepcopy()
            self.prior.dims = self.prefix
        elif not self.prior.dims:
            self.prior.dims = self.prefix 
    @model_validator(mode="after")
    def _check_variable_name(self) -> Self:
        if self.variable_name == self.prefix:
            raise ValueError("Variable name cannot be the same as the prefix")
        return self
    @model_validator(mode="after")
    def _check_prior_has_right_dimensions(self) -> Self:
        if self.prefix not in self.prior.dims:
            raise ValueError(f"Prior distribution must have dimension {self.prefix}")
        return self
[docs]
    @field_serializer("prior", when_used="json")
    def serialize_prior(prior: Any) -> dict[str, Any]:
        """Serialize the prior distribution.
        Parameters
        ----------
        prior : VariableFactory | Prior
            The prior distribution to serialize.
        Returns
        -------
        dict[str, Any]
            The serialized prior distribution.
        """
        return prior.to_dict() 
    @property
    def nodes(self) -> list[str]:
        """Fourier node names for model coordinates."""
        return [
            f"{func}_{i}" for func in ["sin", "cos"] for i in range(1, self.n_order + 1)
        ]
[docs]
    def get_default_start_date(
        self,
        start_date: str | datetime.datetime | None = None,
    ) -> str | datetime.datetime:
        """Get the start date for the Fourier curve.
        If `start_date` is provided, validate its type.
        Otherwise, provide the default start date based on the subclass implementation.
        Parameters
        ----------
        start_date : str or datetime.datetime, optional
            Provided start date. Can be a string or a datetime object.
        Returns
        -------
        str or datetime.datetime
            The validated start date.
        Raises
        ------
        TypeError
            If `start_date` is neither a string nor a datetime object.
        """
        if start_date is None:
            return self._get_default_start_date()
        elif isinstance(start_date, str) | isinstance(start_date, datetime.datetime):
            return start_date
        else:
            raise TypeError(
                "start_date must be a datetime.datetime object, a string, or None"
            ) 
    @abstractmethod
    def _get_default_start_date(self) -> datetime.datetime:
        """Provide the default start date. Must be implemented by subclasses.
        Returns
        -------
        datetime.datetime
            The default start date.
        """
        pass  # pragma: no cover
    @abstractmethod
    def _get_days_in_period(self, dates: pd.DatetimeIndex) -> pd.Index:
        """Return the relevant day within the characteristic periodicity.
        Returns
        -------
        int or float
            The relevant period within the characteristic periodicity
        """
        pass
[docs]
    def apply(
        self,
        dayofperiod: pt.TensorLike,
        result_callback: Callable[[pt.TensorVariable], None] | None = None,
    ) -> pt.TensorVariable:
        """Apply fourier seasonality to day of year.
        Must be used within a PyMC model context.
        Parameters
        ----------
        dayofperiod : pt.TensorLike
            Day of year or weekday
        result_callback : Callable[[pt.TensorVariable], None], optional
            Callback function to apply to the result, by default None
        Returns
        -------
        pt.TensorVariable
            Fourier seasonality
        Examples
        --------
        Save off the result before summing through the prefix dimension.
        .. code-block:: python
            import pandas as pd
            import pymc as pm
            from pymc_marketing.mmm import YearlyFourier
            fourier = YearlyFourier(n_order=3)
            def callback(result):
                pm.Deterministic("fourier_trend", result, dims=("date", "fourier"))
            dates = pd.date_range("2023-01-01", periods=52, freq="W-MON")
            coords = {
                "date": dates,
            }
            with pm.Model(coords=coords) as model:
                dayofyear = dates.dayofyear.to_numpy()
                fourier.apply(dayofyear, result_callback=callback)
        """
        periods = dayofperiod / self.days_in_period
        model = pm.modelcontext(None)
        model.add_coord(self.prefix, self.nodes)
        beta = self.prior.create_variable(self.variable_name)
        fourier_modes = generate_fourier_modes(periods=periods, n_order=self.n_order)
        DUMMY_DIM = "DATE"
        prefix_idx = self.prior.dims.index(self.prefix)
        result_dims = (DUMMY_DIM, *self.prior.dims)
        dim_handler = create_dim_handler(result_dims)
        result = dim_handler(fourier_modes, (DUMMY_DIM, self.prefix)) * dim_handler(
            beta, self.prior.dims
        )
        if result_callback is not None:
            result_callback(result)
        return result.sum(axis=prefix_idx + 1) 
[docs]
    def sample_prior(self, coords: dict | None = None, **kwargs) -> xr.Dataset:
        """Sample the prior distributions.
        Parameters
        ----------
        coords : dict, optional
            Coordinates for the prior distribution, by default None
        kwargs
            Additional keywords for sample_prior_predictive
        Returns
        -------
        xr.Dataset
            Prior distribution.
        """
        coords = coords or {}
        coords[self.prefix] = self.nodes
        return self.prior.sample_prior(coords=coords, name=self.variable_name, **kwargs) 
[docs]
    def sample_curve(
        self,
        parameters: az.InferenceData | xr.Dataset,
        use_dates: bool = False,
        start_date: str | datetime.datetime | None = None,
    ) -> xr.DataArray:
        """Create full period of the Fourier seasonality.
        Parameters
        ----------
        parameters : az.InferenceData | xr.Dataset
            Inference data or dataset containing the Fourier parameters.
            Can be posterior or prior.
        use_dates : bool, optional
            If True, use datetime coordinates for the x-axis. Defaults to False.
        start_date : datetime.datetime, optional
            Starting date for the Fourier curve. If not provided and use_dates is True,
            it will be derived from the current year or month. Defaults to None.
        Returns
        -------
        xr.DataArray
            Full period of the Fourier seasonality.
        """
        full_period = np.arange(self.days_in_period + 1)
        coords: dict[str, npt.NDArray[Any]] = {}
        if use_dates:
            start_date = self.get_default_start_date(start_date=start_date)
            date_range = pd.date_range(
                start=start_date,
                periods=int(np.ceil(self.days_in_period) + 1),
                freq="D",
            )
            coords["date"] = date_range.to_numpy()
            dayofperiod = self._get_days_in_period(date_range).to_numpy()
        else:
            coords["day"] = full_period
            dayofperiod = full_period
        for key, values in parameters[self.variable_name].coords.items():
            if key in {"chain", "draw", self.prefix}:
                continue
            coords[key] = values.to_numpy()
        with pm.Model(coords=coords):
            name = f"{self.prefix}_trend"
            pm.Deterministic(
                name,
                self.apply(dayofperiod=dayofperiod),
                dims=tuple(coords.keys()),
            )
            return pm.sample_posterior_predictive(
                parameters,
                var_names=[name],
            ).posterior_predictive[name] 
[docs]
    def plot_curve(
        self,
        curve: xr.DataArray,
        n_samples: int = 10,
        hdi_probs: float | list[float] | None = None,
        random_seed: np.random.Generator | None = None,
        subplot_kwargs: dict | None = None,
        sample_kwargs: dict | None = None,
        hdi_kwargs: dict | None = None,
        axes: npt.NDArray[plt.Axes] | None = None,
        same_axes: bool = False,
        colors: Iterable[str] | None = None,
        legend: bool | None = None,
        sel_to_string: SelToString | None = None,
    ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
        """Plot the seasonality for one full period.
        Parameters
        ----------
        curve : xr.DataArray
            Sampled full period of the fourier seasonality.
        n_samples : int, optional
            Number of samples
        hdi_probs : float | list[float], optional
            HDI probabilities. Defaults to None which uses arviz default for
            stats.ci_prob which is 94%
        random_seed : int | random number generator, optional
            Random number generator. Defaults to None
        subplot_kwargs : dict, optional
            Keyword arguments for the subplot, by default None
        sample_kwargs : dict, optional
            Keyword arguments for the plot_full_period_samples method, by default None
        hdi_kwargs : dict, optional
            Keyword arguments for the plot_full_period_hdi method, by default None
        axes : npt.NDArray[plt.Axes], optional
            Matplotlib axes, by default None
        same_axes : bool, optional
            Use the same axes for all plots, by default False
        colors : Iterable[str], optional
            Colors for the different plots, by default None
        legend : bool, optional
            Show the legend, by default None
        sel_to_string : SelToString, optional
            Function to convert the selection to a string, by default None
        Returns
        -------
        tuple[plt.Figure, npt.NDArray[plt.Axes]]
            Matplotlib figure and axes.
        """
        if "date" in curve.coords:
            x_coord_name = "date"
        elif "day" in curve.coords:
            x_coord_name = "day"
        else:
            raise ValueError("Curve must have either 'day' or 'date' as a coordinate")
        return plot_curve(
            curve,
            non_grid_names={x_coord_name},
            n_samples=n_samples,
            hdi_probs=hdi_probs,
            random_seed=random_seed,
            subplot_kwargs=subplot_kwargs,
            sample_kwargs=sample_kwargs,
            hdi_kwargs=hdi_kwargs,
            axes=axes,
            same_axes=same_axes,
            colors=colors,
            legend=legend,
            sel_to_string=sel_to_string,
        ) 
[docs]
    def plot_curve_hdi(
        self,
        curve: xr.DataArray,
        hdi_kwargs: dict | None = None,
        subplot_kwargs: dict[str, Any] | None = None,
        plot_kwargs: dict[str, Any] | None = None,
        axes: npt.NDArray[plt.Axes] | None = None,
    ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
        """Plot full period of the fourier seasonality.
        Parameters
        ----------
        curve : xr.DataArray
            The curve to plot.
        hdi_kwargs : dict, optional
            Keyword arguments for the az.hdi function. Defaults to None.
        plot_kwargs : dict, optional
            Keyword arguments for the fill_between function. Defaults to None.
        subplot_kwargs : dict, optional
            Keyword arguments for plt.subplots
        axes : npt.NDArray[plt.Axes], optional
            The exact axes to plot on. Overrides any subplot_kwargs
        Returns
        -------
        tuple[plt.Figure, npt.NDArray[plt.Axes]]
        """
        if "date" in curve.coords:
            x_coord_name = "date"
        elif "day" in curve.coords:
            x_coord_name = "day"
        else:
            raise ValueError("Curve must have either 'day' or 'date' as a coordinate")
        return plot_hdi(
            curve,
            non_grid_names={x_coord_name},
            hdi_kwargs=hdi_kwargs,
            subplot_kwargs=subplot_kwargs,
            plot_kwargs=plot_kwargs,
            axes=axes,
        ) 
[docs]
    def plot_curve_samples(
        self,
        curve: xr.DataArray,
        n: int = 10,
        rng: np.random.Generator | None = None,
        plot_kwargs: dict[str, Any] | None = None,
        subplot_kwargs: dict[str, Any] | None = None,
        axes: npt.NDArray[plt.Axes] | None = None,
    ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
        """Plot samples from the curve.
        Parameters
        ----------
        curve : xr.DataArray
            Samples from the curve.
        n : int, optional
            Number of samples to plot, by default 10
        rng : np.random.Generator, optional
            Random number generator, by default None
        plot_kwargs : dict, optional
            Keyword arguments for the plot function, by default None
        subplot_kwargs : dict, optional
            Keyword arguments for the subplot, by default None
        axes : npt.NDArray[plt.Axes], optional
            Matplotlib axes, by default None
        Returns
        -------
        tuple[plt.Figure, npt.NDArray[plt.Axes]]
            Matplotlib figure and axes.
        """
        if "date" in curve.coords:
            x_coord_name = "date"
        elif "day" in curve.coords:
            x_coord_name = "day"
        else:
            raise ValueError("Curve must have either 'day' or 'date' as a coordinate")
        return plot_samples(
            curve,
            non_grid_names={x_coord_name},
            n=n,
            rng=rng,
            axes=axes,
            subplot_kwargs=subplot_kwargs,
            plot_kwargs=plot_kwargs,
        ) 
[docs]
    def to_dict(self) -> dict[str, Any]:
        """Serialize the Fourier seasonality.
        Returns
        -------
        dict[str, Any]
            Serialized Fourier seasonality
        """
        return {
            "class": self.__class__.__name__,
            "data": self.model_dump(mode="json"),
        } 
[docs]
    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Self:
        """Deserialize the Fourier seasonality.
        Parameters
        ----------
        data : dict[str, Any]
            Serialized Fourier seasonality
        Returns
        -------
        FourierBase
            Deserialized Fourier seasonality
        """
        data = data["data"]
        data["prior"] = deserialize(data["prior"])
        return cls(**data) 
 
[docs]
class YearlyFourier(FourierBase):
    """Yearly fourier seasonality.
    .. plot::
        :context: close-figs
        import arviz as az
        import matplotlib.pyplot as plt
        import numpy as np
        from pymc_marketing.mmm import YearlyFourier
        from pymc_extras.prior import Prior
        az.style.use("arviz-white")
        seed = sum(map(ord, "Yearly"))
        rng = np.random.default_rng(seed)
        mu = np.array([0, 0, -1, 0])
        b = 0.15
        dist = Prior("Laplace", mu=mu, b=b, dims="fourier")
        yearly = YearlyFourier(n_order=2, prior=dist)
        prior = yearly.sample_prior(random_seed=rng)
        curve = yearly.sample_curve(prior)
        _, axes = yearly.plot_curve(curve)
        axes[0].set(title="Yearly Fourier Seasonality")
        plt.show()
    n_order : int
        Number of fourier modes to use.
    prefix : str, optional
        Alternative prefix for the fourier seasonality, by default None or
        "fourier"
    prior : Prior | VariableFactory, optional
        Prior distribution or VariableFactory for the fourier seasonality beta parameters, by
        default `Prior("Laplace", mu=0, b=1)`
    name : str, optional
        Name of the variable that multiplies the fourier modes, by default None
    variable_name : str, optional
        Name of the variable that multiplies the fourier modes, by default None
    """
    days_in_period: float = DAYS_IN_YEAR
    def _get_default_start_date(self) -> datetime.datetime:
        """Get the default start date for yearly seasonality.
        Returns January 1st of the current year.
        """
        current_year = datetime.datetime.now().year
        return datetime.datetime(year=current_year, month=1, day=1)
    def _get_days_in_period(self, dates: pd.DatetimeIndex) -> pd.Index:
        """Return the dayofyear within the yearly periodicity.
        Returns
        -------
        int or float
            The relevant period within the characteristic periodicity
        """
        return dates.dayofyear 
[docs]
class MonthlyFourier(FourierBase):
    """Monthly fourier seasonality.
    .. plot::
        :context: close-figs
        import arviz as az
        import matplotlib.pyplot as plt
        import numpy as np
        from pymc_marketing.mmm import MonthlyFourier
        from pymc_extras.prior import Prior
        az.style.use("arviz-white")
        seed = sum(map(ord, "Monthly"))
        rng = np.random.default_rng(seed)
        mu = np.array([0, 0, 0.5, 0])
        b = 0.075
        dist = Prior("Laplace", mu=mu, b=b, dims="fourier")
        monthly = MonthlyFourier(n_order=2, prior=dist)
        prior = monthly.sample_prior(samples=100)
        curve = monthly.sample_curve(prior)
        _, axes = monthly.plot_curve(curve)
        axes[0].set(title="Monthly Fourier Seasonality")
        plt.show()
    n_order : int
        Number of fourier modes to use.
    prefix : str, optional
        Alternative prefix for the fourier seasonality, by default None or
        "fourier"
    prior : Prior | VariableFactory, optional
        Prior distribution or VariableFactory for the fourier seasonality beta parameters, by
        default `Prior("Laplace", mu=0, b=1)`
    name : str, optional
        Name of the variable that multiplies the fourier modes, by default None
    variable_name : str, optional
        Name of the variable that multiplies the fourier modes, by default None
    """
    days_in_period: float = DAYS_IN_MONTH
    def _get_default_start_date(self) -> datetime.datetime:
        """Get the default start date for monthly seasonality.
        Returns the first day of the current month.
        """
        now = datetime.datetime.now()
        return datetime.datetime(year=now.year, month=now.month, day=1)
    def _get_days_in_period(self, dates: pd.DatetimeIndex) -> pd.Index:
        """Return the dayofyear within the yearly periodicity.
        Returns
        -------
        int or float
            The relevant period within the characteristic periodicity
        """
        return dates.dayofyear 
[docs]
class WeeklyFourier(FourierBase):
    """Weekly fourier seasonality.
    .. plot::
        :context: close-figs
        import arviz as az
        import matplotlib.pyplot as plt
        import numpy as np
        from pymc_marketing.mmm import WeeklyFourier
        from pymc_extras.prior import Prior
        az.style.use("arviz-white")
        seed = sum(map(ord, "Weekly"))
        rng = np.random.default_rng(seed)
        mu = np.array([0, 0, 0.5, 0])
        b = 0.075
        dist = Prior("Laplace", mu=mu, b=b, dims="fourier")
        weekly = WeeklyFourier(n_order=2, prior=dist)
        prior = weekly.sample_prior(samples=100)
        curve = weekly.sample_curve(prior)
        _, axes = weekly.plot_curve(curve)
        axes[0].set(title="Weekly Fourier Seasonality")
        plt.show()
    n_order : int
        Number of fourier modes to use.
    prefix : str, optional
        Alternative prefix for the fourier seasonality, by default None or
        "fourier"
    prior : Prior | VariableFactory, optional
        Prior distribution or VariableFactory for the fourier seasonality beta parameters, by
        default `Prior("Laplace", mu=0, b=1)`
    name : str, optional
        Name of the variable that multiplies the fourier modes, by default None
    variable_name : str, optional
        Name of the variable that multiplies the fourier modes, by default None
    """
    days_in_period: float = DAYS_IN_WEEK
    def _get_default_start_date(self) -> datetime.datetime:
        """Get the default start date for weekly seasonality.
        Returns the first day of the current month.
        """
        now = datetime.datetime.now()
        return datetime.datetime.fromisocalendar(
            year=now.year, week=now.isocalendar().week, day=1
        )
    def _get_days_in_period(self, dates: pd.DatetimeIndex) -> pd.Index:
        """Return dayofyear associated with dates.
        With no gaps, a pure weekly Fourier basis sin(2π k t/7) will evaluate the same whether
        you pass t=dayofyear or t=weekday, because the sine/cos only depend on t mod 7.
        However, there are empirical reasons to prefer this representation,
        since it avoids divergences when used in combination with tvp.
        Returns
        -------
        int or float
            The relevant period within the characteristic periodicity
        """
        return dates.dayofyear 
def _is_yearly_fourier(data: Any) -> bool:
    return data.get("class") == "YearlyFourier"
def _is_monthly_fourier(data: Any) -> bool:
    return data.get("class") == "MonthlyFourier"
def _is_weekly_fourier(data: Any) -> bool:
    return data.get("class") == "WeeklyFourier"
register_deserialization(
    is_type=_is_yearly_fourier,
    deserialize=lambda data: YearlyFourier.from_dict(data),
)
register_deserialization(
    is_type=_is_monthly_fourier,
    deserialize=lambda data: MonthlyFourier.from_dict(data),
)
register_deserialization(
    is_type=_is_weekly_fourier, deserialize=lambda data: WeeklyFourier.from_dict(data)
)