Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Creating Causal Identification module #1166

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
18ccc61
Creating Causal Identification module
cetagostini Nov 4, 2024
6a08373
Pre-commit
cetagostini Nov 4, 2024
9f4af46
Merge branch 'main' into causal_identification
wd60622 Nov 6, 2024
a53b7d7
adding missing libraries
cetagostini Nov 6, 2024
c45e5c1
Merge branch 'main' into causal_identification
cetagostini Nov 6, 2024
8c26976
Merge branch 'main' into causal_identification
cetagostini Nov 12, 2024
7b09ef6
Pushing for push
cetagostini Nov 13, 2024
8d51555
Another random push
cetagostini Nov 14, 2024
171bd10
Final v1 push
cetagostini Nov 16, 2024
4f281a7
Merge branch 'main' into causal_identification
cetagostini Nov 16, 2024
a77d871
Adding pre-commit
cetagostini Nov 16, 2024
4299b95
Adding to index
cetagostini Nov 16, 2024
d5effba
Functions in the notebook
cetagostini Nov 16, 2024
a81aee6
More adjustment in notebook functions
cetagostini Nov 16, 2024
8ce8d56
Error on description
cetagostini Nov 17, 2024
e4e09a9
Merge branch 'main' into causal_identification
cetagostini Nov 21, 2024
3c4f5c7
Requested changes
cetagostini Nov 21, 2024
abd01d3
Trying to solve dependency error test.
cetagostini Nov 25, 2024
49216d8
Solving errors
cetagostini Nov 25, 2024
dcb52d6
Pydantic
cetagostini Nov 25, 2024
dab6784
Merge branch 'main' into causal_identification
cetagostini Nov 27, 2024
3fb3cc1
add support for save and load
wd60622 Nov 28, 2024
fb886b0
support for backwards compat
wd60622 Nov 28, 2024
f044ca6
Merge branch 'main' into causal_identification
cetagostini Nov 28, 2024
3856be0
A fancy commit
cetagostini Nov 28, 2024
b31a86f
Modify
cetagostini Nov 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 190 additions & 0 deletions pymc_marketing/mmm/causal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import networkx as nx
from dowhy.graph import get_backdoor_paths, get_instruments, build_graph_from_str, build_graph
from dowhy.causal_identifier.auto_identifier import (
construct_backdoor_estimand,
construct_frontdoor_estimand
)
from typing import list, Set, Optional
wd60622 marked this conversation as resolved.
Show resolved Hide resolved

class CausalGraphModel:
"""
A class representing a causal model based on a Directed Acyclic Graph (DAG).
It provides methods to analyze causal relationships, determine adjustment sets,
and assess the possibility of backdoor and frontdoor adjustments.

Parameters
----------
graph : nx.DiGraph
A directed acyclic graph representing the causal relationships among variables.
treatment : list[str]
A list of treatment variable names.
outcome : list[str]
A list of outcome variable names.
"""
def __init__(self, graph: nx.DiGraph, treatment: list[str], outcome: list[str]):
self.graph = graph
self.treatment = treatment
self.outcome = outcome

@classmethod
def from_string(cls, graph_str: str, treatment: list[str], outcome: list[str]):
"""
Constructs a CausalModel from a string representation of a graph.

Parameters
----------
graph_str : str
A string representation of the graph (e.g., in DOT format).
treatment : list[str]
A list of treatment variable names.
outcome : list[str]
A list of outcome variable names.

Returns
-------
CausalModel
An instance of CausalModel constructed from the given graph string.
"""
graph = build_graph_from_str(graph_str)
return cls(graph, treatment, outcome)

@classmethod
def from_nodes_and_edges(cls, action_nodes: list[str], outcome_nodes: list[str],
common_cause_nodes: Optional[list[str]] = None,
instrument_nodes: Optional[list[str]] = None,
mediator_nodes: Optional[list[str]] = None):
"""
Constructs a CausalModel from lists of nodes categorized by their roles in the causal graph.

Parameters
----------
action_nodes : list[str]
list of treatment (action) variable names.
outcome_nodes : list[str]
list of outcome variable names.
common_cause_nodes : Optional[list[str]], default=None
list of common cause (confounder) variable names.
instrument_nodes : Optional[list[str]], default=None
list of instrumental variable names.
mediator_nodes : Optional[list[str]], default=None
list of mediator variable names.

Returns
-------
CausalModel
An instance of CausalModel constructed from the specified nodes.
"""
graph = build_graph(
action_nodes=action_nodes,
outcome_nodes=outcome_nodes,
common_cause_nodes=common_cause_nodes,
instrument_nodes=instrument_nodes,
mediator_nodes=mediator_nodes
)
return cls(graph, action_nodes, outcome_nodes)

def get_backdoor_paths(self) -> dict[str, dict[str, list[list[str]]]]:
"""
Finds all backdoor paths between treatment and outcome variables and computes adjustment sets.

Returns
-------
dict[str, dict[str, list[list[str]]]]
A dictionary where each key is a treatment variable, and the value is another dictionary containing:
- 'adjustment_sets': A list of adjustment sets (lists of variable names) for backdoor adjustment.
- 'minimal_adjustment_set': The minimal adjustment set (with the least number of variables) required to block all backdoor paths.
"""
backdoor_dict = {}
for treatment_node in self.treatment:
paths = get_backdoor_paths(self.graph, [treatment_node], self.outcome)

# Exclude treatment and outcome nodes from each backdoor path to obtain valid adjustment sets
adjustment_sets = {
tuple(sorted(set(path) - {treatment_node} - set(self.outcome)))
for path in paths
}

backdoor_dict[treatment_node] = {
"adjustment_sets": [list(adjustment_set) for adjustment_set in adjustment_sets],
"minimal_adjustment_set": min(adjustment_sets, key=len) if adjustment_sets else []
}

return backdoor_dict

def is_backdoor_adjustment_possible(self) -> bool:
"""
Determines whether backdoor adjustment is possible for the causal model.

Returns
-------
bool
True if backdoor adjustment is possible (i.e., there exists a backdoor path), False otherwise.
"""
backdoor_paths = self.get_backdoor_paths()
return any(backdoor_paths[node]["minimal_adjustment_set"] for node in backdoor_paths)

def get_minimal_adjustment_sets(self) -> Optional[Set[str]]:
"""
Computes the minimal adjustment set(s) required for backdoor adjustment using DoWhy.

Returns
-------
Optional[Set[str]]
A set of variable names representing the minimal adjustment set, or None if not identifiable.
"""
try:
estimand = construct_backdoor_estimand(
self.graph, self.treatment[0], self.outcome[0]
)
return estimand.get_backdoor_variables()
except Exception as e:
print("Error identifying backdoor adjustment set:", e)
return None
wd60622 marked this conversation as resolved.
Show resolved Hide resolved

def is_frontdoor_adjustment_possible(self) -> bool:
"""
Determines whether frontdoor adjustment is possible for the causal model.

Returns
-------
bool
True if frontdoor adjustment is possible, False otherwise.
"""
try:
frontdoor_estimand = construct_frontdoor_estimand(
self.graph, self.treatment[0], self.outcome[0]
)
return frontdoor_estimand is not None
except Exception:
return False
wd60622 marked this conversation as resolved.
Show resolved Hide resolved

def get_instrumental_variables(self) -> list[str]:
"""
Identifies instrumental variables in the causal graph using DoWhy.

Returns
-------
list[str]
A list of variable names that are instrumental variables, or an empty list if none are found.
"""
try:
instruments = get_instruments(self.graph, self.treatment, self.outcome)
return instruments
except Exception as e:
print("Error identifying instruments:", e)
return []

def get_unique_minimal_adjustment_elements(self) -> Set[str]:
"""
Extracts unique variables from all minimal adjustment sets across all treatments.

Returns
-------
Set[str]
A set of unique variable names that are part of the minimal adjustment sets.
"""
backdoor_info = self.get_backdoor_paths()
unique_elements = set()
for node, info in backdoor_info.items():
unique_elements.update(info["minimal_adjustment_set"])
return unique_elements
48 changes: 48 additions & 0 deletions pymc_marketing/mmm/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import seaborn as sns
from pydantic import Field, InstanceOf, validate_call
from xarray import DataArray, Dataset
import networkx as nx

from pymc_marketing.hsgp_kwargs import HSGPKwargs
from pymc_marketing.mmm.base import BaseValidateMMM
Expand All @@ -54,6 +55,7 @@
from pymc_marketing.mmm.validating import ValidateControlColumns
from pymc_marketing.model_config import parse_model_config
from pymc_marketing.prior import Prior
from pymc_marketing.mmm.causal import CausalGraphModel

__all__ = ["BaseMMM", "MMM"]

Expand Down Expand Up @@ -113,6 +115,13 @@ def __init__(
adstock_first: bool = Field(
True, description="Whether to apply adstock first."
),
dag: str | nx.DiGraph | None = Field(
None,
description="Optional DAG provided as a string or networkx DiGraph for causal identification.",
),
outcome_column: str = Field(
None, description="Name of the outcome variable to use in the causal graph."
),
) -> None:
"""Define the constructor method.

Expand Down Expand Up @@ -149,6 +158,10 @@ def __init__(
Number of Fourier modes to model yearly seasonality, by default None.
adstock_first : bool, optional
Whether to apply adstock first, by default True.
dag : Optional[str | nx.DiGraph], optional
Optional DAG provided as a string or networkx DiGraph for causal modeling, by default None.
outcome_column : str, optional
Name of the outcome variable, by default None.
"""
self.control_columns = control_columns
self.time_varying_intercept = time_varying_intercept
Expand Down Expand Up @@ -178,6 +191,41 @@ def __init__(
)

self.yearly_seasonality = yearly_seasonality

# Begin addition for DAG and CausalGraphModel
if dag is not None and outcome_column is not None:
if isinstance(dag, str):

causal_model = CausalGraphModel.from_string(
graph_str=dag,
treatment=channel_columns,
outcome=[outcome_column],
)
elif isinstance(dag, nx.DiGraph):

causal_model = CausalGraphModel(
graph=dag,
treatment=channel_columns,
outcome=[outcome_column],
)
else:
raise ValueError("dag must be either a string or a networkx DiGraph")

# Get minimal adjustment sets
minimal_adjustment_set = causal_model.get_minimal_adjustment_sets()

if minimal_adjustment_set is not None:
# Update control_columns with minimal adjustment set
self.control_columns = list(
set(self.control_columns).union(minimal_adjustment_set)
)
# Check if seasonality_variable is in the minimal adjustment set
if "yearly_seasonality" not in minimal_adjustment_set:
# Set yearly_seasonality to None to disable it
self.yearly_seasonality = None
else:
warnings.warn("No minimal adjustment set found.")

if self.yearly_seasonality is not None:
self.yearly_fourier = YearlyFourier(
n_order=self.yearly_seasonality,
Expand Down
Loading