How to make nonlinear predictions on categorical variables in causal inference, given a causal graph?


I have been learning causal inference and discovery recently and have been struggling with this question for a long time.

From my understanding of the literature, causal inference seems quite different from traditional machine learning. In traditional machine learning, once the model is trained, given a set of X it directly predicts the value of Y.

In causal inference, however, the model answers interventional questions: if X1 changes from 1 to 2, for example, it returns the causal effect on Y.

So how can I answer the prediction question using causal inference?

Here are some simulated data using this graph:

import networkx as nx
import matplotlib.pyplot as plt

# Create a directed graph
G = nx.DiGraph()

# Add nodes X, Y, and Z
G.add_nodes_from(['X', 'Y', 'Z'])

# Add edges representing causal relationships
G.add_edge('X', 'Y')
G.add_edge('Z', 'Y')

# Draw the graph
pos = nx.spring_layout(G)
nx.draw_networkx(G, pos, with_labels=True, node_color='lightblue', node_size=500, font_size=12, edge_color='gray')
plt.title('Causal Graph')
plt.show()

(figure: the causal graph, with edges X → Y and Z → Y)

# Create the nonlinear relationships:
import numpy as np
import pandas as pd

# Generate X values
X = np.linspace(0, 10, 100)

# Generate Z values
Z = np.linspace(10, 20, len(X))

# Generate Y values using a non-linear relationship with X
Y = np.sin(X) + np.cos(Z) + np.random.normal(0, 0.1, len(X))

# Combine X, Z, Y into one pandas DataFrame
df = pd.DataFrame({'X': X, 'Z': Z, 'Y': Y})

# Print the DataFrame
print(df)

           X         Z         Y
0    0.00000  10.00000 -0.781419
1    0.10101  10.10101 -0.691126
2    0.20202  10.20202 -0.603684
3    0.30303  10.30303 -0.418206
4    0.40404  10.40404 -0.087543
..       ...       ...       ...
95   9.59596  19.59596  0.748085
96   9.69697  19.69697  0.455545
97   9.79798  19.79798  0.365235
98   9.89899  19.89899  0.023566
99  10.00000  20.00000 -0.193462

[100 rows x 3 columns]

# plot the data
import matplotlib.pyplot as plt

# Plot X, Y, and Z
plt.plot(df['X'], label='X')
plt.plot(df['Y'], label='Y')
plt.plot(df['Z'], label='Z')

# Add labels and legend
plt.xlabel('Index')
plt.ylabel('Value')
plt.legend()

# Show the plot
plt.show()

(figure: line plot of X, Y, and Z against the index)

The problem is:

How do you predict Y when X = 10 and Z = 20, pretending you only know the causal graph but not the detailed causal functions?

I have tried using Microsoft's Causica to identify the causal graph, and I have also looked at causal inference methods, but those address causal discovery and effect estimation, not prediction.


There are 2 answers

Answer from Scriddie (4 votes):

Causal prediction requires causal structure and function estimation

[Comment: you'll usually get more answers for questions regarding causality on CrossValidated]

Your question points to an important and often misunderstood aspect of causal structure. The causal graph encodes which variables influence which others, but not how they do so. For that, you need not just a causal graph but a full Structural Causal Model (SCM). An SCM contains a functional definition of each variable in terms of its parents and an exogenous noise term.

The causal graph could tell you that some variable X is not a cause of some other variable Y, in which case you know to predict no change in Y, no matter the value of X. If X is a cause of Y, however, there is no way to predict Y correctly for some X without knowing the functions.

In summary, causal structure does not remove the need for function estimation, nor for extrapolation when predicting at unseen values.
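
For illustration, a minimal sketch of that function-estimation step (the gradient-boosting regressor is an arbitrary choice here, not something the answer prescribes). The graph says that X and Z are the parents of Y, so estimating f in Y = f(X, Z) + noise is an ordinary supervised-learning problem; note also that the query point (X = 10, Z = 20) sits on the boundary of the training data, so the model has to extrapolate.

import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

# Recreate the simulated data from the question
X = np.linspace(0, 10, 100)
Z = np.linspace(10, 20, len(X))
Y = np.sin(X) + np.cos(Z) + np.random.normal(0, 0.1, len(X))

# The graph only tells us pa(Y) = {X, Z}; the functional form
# still has to be estimated from data
model = GradientBoostingRegressor().fit(np.column_stack([X, Z]), Y)

# (10, 20) lies on the edge of the training range: an extrapolation
print(model.predict([[10.0, 20.0]]))  # ground truth: sin(10) + cos(20) ≈ -0.14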

Answer from Aleksander Molak (0 votes):

@Scriddie pointed out a very important distinction between the graph and the functional form.

Assuming that you already know the graphical structure, there are a couple of ways you can address the challenge of predicting the values:

  1. You can use a causal ML model. For non-linear cases you'll need a model capable of capturing non-linearities. One possibility here is to use double machine learning with expressive base-learners, e.g. boosted trees. This can be implemented using EconML (see some examples here: https://github.com/PacktPublishing/Causal-Inference-and-Discovery-in-Python/blob/main/Chapter_10.ipynb). This approach assumes that you pre-define your treatment and outcome, and it learns to predict the effects of changes in the treatment value, given confounders and effect modifiers (see the first sketch after this list).

  2. You can also learn a complete SCM, i.e. the functional forms of all structural equations in your model. This can be achieved e.g. by using DoWhy's GCM API; see an example here: https://www.pywhy.org/dowhy/v0.9.1/user_guide/gcm_based_inference/introduction.html (and the second sketch after this list).

  3. You can use a causal probabilistic programming language like PyMC (https://www.pymc-labs.com/blog-posts/causal-analysis-with-pymc-answering-what-if-with-the-new-do-operator/) or ChiRho (https://github.com/BasisResearch/chirho) to build a fully specified SCM with (Bayesian) uncertainties (see the third sketch below).
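
A minimal sketch of option 1, assuming EconML's CausalForestDML with boosted-tree nuisance models. The data is re-simulated with a randomized treatment, since in the question's data X and Z move in lockstep; note also that DML-style estimators are partially linear in the treatment, so what you get is an (average) effect of the 1 → 2 change, not a point prediction of Y.

import numpy as np
from econml.dml import CausalForestDML
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.default_rng(0)
n = 2000
Z = rng.uniform(10, 20, n)   # effect modifier
T = rng.uniform(0, 10, n)    # treatment (the question's X), randomized
Y = np.sin(T) + np.cos(Z) + rng.normal(0, 0.1, n)

# Double machine learning with expressive base-learners
est = CausalForestDML(model_y=GradientBoostingRegressor(),
                      model_t=GradientBoostingRegressor())
est.fit(Y, T, X=Z.reshape(-1, 1))

# Estimated effect on Y of moving the treatment from 1 to 2, at Z = 15
print(est.effect(np.array([[15.0]]), T0=1, T1=2))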
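
A minimal sketch of option 2, assuming DoWhy's GCM API (dowhy >= 0.9): assign a mechanism to each node of the known graph, fit it to the data, then sample Y under the intervention do(X = 10, Z = 20); the sample mean is the point prediction.

import networkx as nx
import numpy as np
import pandas as pd
from dowhy import gcm

# Recreate the simulated data from the question
X = np.linspace(0, 10, 100)
Z = np.linspace(10, 20, len(X))
Y = np.sin(X) + np.cos(Z) + np.random.normal(0, 0.1, len(X))
df = pd.DataFrame({'X': X, 'Z': Z, 'Y': Y})

# SCM over the known graph X -> Y <- Z
causal_model = gcm.StructuralCausalModel(nx.DiGraph([('X', 'Y'), ('Z', 'Y')]))

# Automatically pick a (possibly non-linear) mechanism per node, then fit
gcm.auto.assign_causal_mechanisms(causal_model, df)
gcm.fit(causal_model, df)

# Sample Y under do(X=10, Z=20) and average
samples = gcm.interventional_samples(causal_model,
                                     {'X': lambda x: 10, 'Z': lambda z: 20},
                                     num_samples_to_draw=1000)
print(samples['Y'].mean())  # ground truth: sin(10) + cos(20) ≈ -0.14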
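
And a sketch of option 3 with PyMC's do-operator (available in recent PyMC versions; see the linked blog post). For brevity the structural equation for Y is written in its true parametric form, which in a real application you would not know; a flexible fitted mechanism (e.g. splines or a GP) would take its place.

import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x", mu=5.0, sigma=3.0)
    z = pm.Normal("z", mu=15.0, sigma=3.0)
    sigma = pm.HalfNormal("sigma", 0.1)
    # Toy mechanism for Y; in practice this would be estimated from data
    y = pm.Normal("y", mu=pm.math.sin(x) + pm.math.cos(z), sigma=sigma)

# Graph surgery: do(X=10, Z=20)
intervened = pm.do(model, {"x": 10.0, "z": 20.0})

# Sample Y under the intervention
with intervened:
    idata = pm.sample_prior_predictive(1000)
print(float(idata.prior["y"].mean()))  # ≈ sin(10) + cos(20) ≈ -0.14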