Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 67 additions & 31 deletions causalpy/data/simulate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,26 @@
Functions that generate data sets used in examples
"""

from typing import Any

import numpy as np
import pandas as pd
from scipy.stats import dirichlet, gamma, norm, uniform
from statsmodels.nonparametric.smoothers_lowess import lowess

default_lowess_kwargs = {"frac": 0.2, "it": 0}
RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)
default_lowess_kwargs: dict[str, float] = {"frac": 0.2, "it": 0}
RANDOM_SEED: int = 8927
rng: np.random.Generator = np.random.default_rng(RANDOM_SEED)


def _smoothed_gaussian_random_walk(
gaussian_random_walk_mu, gaussian_random_walk_sigma, N, lowess_kwargs
):
gaussian_random_walk_mu: float,
gaussian_random_walk_sigma: float,
N: int,
lowess_kwargs: dict[str, Any],
) -> tuple[np.ndarray, np.ndarray]:
"""
Generates Gaussian random walk data and applies LOWESS
Generates Gaussian random walk data and applies LOWESS.

:param gaussian_random_walk_mu:
Mean of the random walk
Expand All @@ -48,12 +53,12 @@ def _smoothed_gaussian_random_walk(


def generate_synthetic_control_data(
N=100,
treatment_time=70,
grw_mu=0.25,
grw_sigma=1,
lowess_kwargs=default_lowess_kwargs,
):
N: int = 100,
treatment_time: int = 70,
grw_mu: float = 0.25,
grw_sigma: float = 1,
lowess_kwargs: dict[str, Any] | None = None,
) -> tuple[pd.DataFrame, np.ndarray]:
"""
Generates data for synthetic control example.

Expand All @@ -73,6 +78,8 @@ def generate_synthetic_control_data(
>>> from causalpy.data.simulate_data import generate_synthetic_control_data
>>> df, weightings_true = generate_synthetic_control_data(treatment_time=70)
"""
if lowess_kwargs is None:
lowess_kwargs = default_lowess_kwargs

# 1. Generate non-treated variables
df = pd.DataFrame(
Expand Down Expand Up @@ -108,8 +115,12 @@ def generate_synthetic_control_data(


def generate_time_series_data(
N=100, treatment_time=70, beta_temp=-1, beta_linear=0.5, beta_intercept=3
):
N: int = 100,
treatment_time: int = 70,
beta_temp: float = -1,
beta_linear: float = 0.5,
beta_intercept: float = 3,
) -> pd.DataFrame:
"""
Generates interrupted time series example data

Expand Down Expand Up @@ -155,7 +166,7 @@ def generate_time_series_data(
return df


def generate_time_series_data_seasonal(treatment_time):
def generate_time_series_data_seasonal(treatment_time: pd.Timestamp) -> pd.DataFrame:
"""
Generates 10 years of monthly data with seasonality
"""
Expand Down Expand Up @@ -183,7 +194,9 @@ def generate_time_series_data_seasonal(treatment_time):
return df


def generate_time_series_data_simple(treatment_time, slope=0.0):
def generate_time_series_data_simple(
treatment_time: pd.Timestamp, slope: float = 0.0
) -> pd.DataFrame:
"""Generate simple interrupted time series data, with no seasonality or temporal
structure.
"""
Expand All @@ -205,7 +218,7 @@ def generate_time_series_data_simple(treatment_time, slope=0.0):
return df


def generate_did():
def generate_did() -> pd.DataFrame:
"""
Generate Difference in Differences data

Expand All @@ -223,8 +236,14 @@ def generate_did():

# local functions
def outcome(
t, control_intercept, treat_intercept_delta, trend, Δ, group, post_treatment
):
t: np.ndarray,
control_intercept: float,
treat_intercept_delta: float,
trend: float,
Δ: float,
group: np.ndarray,
post_treatment: np.ndarray,
) -> np.ndarray:
"""Compute the outcome of each unit"""
return (
control_intercept
Expand Down Expand Up @@ -257,8 +276,8 @@ def outcome(


def generate_regression_discontinuity_data(
N=100, true_causal_impact=0.5, true_treatment_threshold=0.0
):
N: int = 100, true_causal_impact: float = 0.5, true_treatment_threshold: float = 0.0
) -> pd.DataFrame:
"""
Generate regression discontinuity example data

Expand All @@ -272,12 +291,12 @@ def generate_regression_discontinuity_data(
... ) # doctest: +SKIP
"""

def is_treated(x):
def is_treated(x: np.ndarray) -> np.ndarray:
"""Check if x was treated"""
return np.greater_equal(x, true_treatment_threshold)

def impact(x):
"""Assign true_causal_impact to all treaated entries"""
def impact(x: np.ndarray) -> np.ndarray:
"""Assign true_causal_impact to all treated entries"""
y = np.zeros(len(x))
y[is_treated(x)] = true_causal_impact
return y
Expand All @@ -289,8 +308,11 @@ def impact(x):


def generate_ancova_data(
N=200, pre_treatment_means=np.array([10, 12]), treatment_effect=2, sigma=1
):
N: int = 200,
pre_treatment_means: np.ndarray = np.array([10, 12]),
treatment_effect: float = 2,
sigma: float = 1,
) -> pd.DataFrame:
"""
Generate ANCOVA example data

Expand All @@ -310,7 +332,7 @@ def generate_ancova_data(
return df


def generate_geolift_data():
def generate_geolift_data() -> pd.DataFrame:
"""Generate synthetic data for a geolift example. This will consists of 6 untreated
countries. The treated unit `Denmark` is a weighted combination of the untreated
units. We additionally specify a treatment effect which takes effect after the
Expand Down Expand Up @@ -360,7 +382,7 @@ def generate_geolift_data():
return df


def generate_multicell_geolift_data():
def generate_multicell_geolift_data() -> pd.DataFrame:
"""Generate synthetic data for a geolift example. This will consists of 6 untreated
countries. The treated unit `Denmark` is a weighted combination of the untreated
units. We additionally specify a treatment effect which takes effect after the
Expand Down Expand Up @@ -422,7 +444,9 @@ def generate_multicell_geolift_data():
# -----------------


def generate_seasonality(n=12, amplitude=1, length_scale=0.5):
def generate_seasonality(
n: int = 12, amplitude: float = 1, length_scale: float = 0.5
) -> np.ndarray:
"""Generate monthly seasonality by sampling from a Gaussian process with a
Gaussian kernel, using numpy code"""
# Generate the covariance matrix
Expand All @@ -436,14 +460,26 @@ def generate_seasonality(n=12, amplitude=1, length_scale=0.5):
return seasonality


def periodic_kernel(x1, x2, period=1, length_scale=1, amplitude=1):
def periodic_kernel(
x1: np.ndarray,
x2: np.ndarray,
period: float = 1,
length_scale: float = 1,
amplitude: float = 1,
) -> np.ndarray:
"""Generate a periodic kernel for gaussian process"""
return amplitude**2 * np.exp(
-2 * np.sin(np.pi * np.abs(x1 - x2) / period) ** 2 / length_scale**2
)


def create_series(n=52, amplitude=1, length_scale=2, n_years=4, intercept=3):
def create_series(
n: int = 52,
amplitude: float = 1,
length_scale: float = 2,
n_years: int = 4,
intercept: float = 3,
) -> np.ndarray:
"""
Returns numpy tile with generated seasonality data repeated over
multiple years
Expand Down