# Source code for networkx.algorithms.d_separation

"""
Algorithm for testing d-separation in DAGs.

*d-separation* is a test for conditional independence in probability
distributions that can be factorized using DAGs.  It is a purely
graphical test that uses the underlying graph and makes no reference
to the actual distribution parameters.  See [1]_ for a formal
definition.

The implementation is based on the conceptually simple linear time
algorithm presented in [2]_.  Refer to [3]_, [4]_ for a couple of
alternative algorithms.


Examples
--------

>>>
>>> # HMM graph with five states and observation nodes
... g = nx.DiGraph()
>>> g.add_edges_from(
...     [
...         ("S1", "S2"),
...         ("S2", "S3"),
...         ("S3", "S4"),
...         ("S4", "S5"),
...         ("S1", "O1"),
...         ("S2", "O2"),
...         ("S3", "O3"),
...         ("S4", "O4"),
...         ("S5", "O5"),
...     ]
... )
>>>
>>> # states/obs before 'S3' are d-separated from states/obs after 'S3'
... nx.d_separated(g, {"S1", "S2", "O1", "O2"}, {"S4", "S5", "O4", "O5"}, {"S3"})
True


References
----------

.. [1] Pearl, J.  (2009).  Causality.  Cambridge: Cambridge University Press.

.. [2] Darwiche, A.  (2009).  Modeling and reasoning with Bayesian networks. 
   Cambridge: Cambridge University Press.

.. [3] Shachter, R.  D.  (1998).
   Bayes-ball: the rational pastime (for determining irrelevance and requisite
   information in belief networks and influence diagrams).
   In Proceedings of the Fourteenth Conference on Uncertainty in Artificial
   Intelligence (pp. 480–487).
   San Francisco, CA, USA: Morgan Kaufmann Publishers Inc.

.. [4] Koller, D., & Friedman, N. (2009).
   Probabilistic graphical models: principles and techniques. The MIT Press.

"""

from collections import deque

import networkx as nx
from networkx.utils import UnionFind, not_implemented_for

# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["d_separated"]


@not_implemented_for("undirected")
def d_separated(G, x, y, z):
    """
    Return whether node sets ``x`` and ``y`` are d-separated by ``z``.

    Parameters
    ----------
    G : graph
        A NetworkX DAG.

    x : set
        First set of nodes in ``G``.

    y : set
        Second set of nodes in ``G``.

    z : set
        Set of conditioning nodes in ``G``. Can be empty set.

    Returns
    -------
    b : bool
        A boolean that is true if ``x`` is d-separated from ``y`` given ``z`` in ``G``.

    Raises
    ------
    NetworkXError
        The *d-separation* test is commonly used with directed
        graphical models which are acyclic.  Accordingly, the algorithm
        raises a :exc:`NetworkXError` if the input graph is not a DAG.

    NodeNotFound
        If any of the input nodes are not found in the graph,
        a :exc:`NodeNotFound` exception is raised.

    Notes
    -----
    The test works on a copy of ``G`` in three steps:

    1. Repeatedly remove leaves (nodes with no outgoing edges) that are not
       in ``x | y | z``.
    2. Remove every edge leaving a node of ``z``.
    3. Report ``x`` and ``y`` as d-separated exactly when no weakly
       connected component of the result contains nodes of both.

    If ``x`` or ``y`` is empty, the result is ``True``.  A node that belongs
    to both ``x`` and ``y`` makes the result ``False``.  Parallel edges in a
    multigraph are handled correctly during leaf pruning.
    """
    if not nx.is_directed_acyclic_graph(G):
        raise nx.NetworkXError("graph should be directed acyclic")

    union_xyz = x.union(y).union(z)

    if any(n not in G.nodes for n in union_xyz):
        raise nx.NodeNotFound("one or more specified nodes not found in the graph")

    # Nothing can connect an empty set to anything; the original algorithm
    # also returned True here, just after doing all the graph work first.
    if not x or not y:
        return True

    G_copy = G.copy()

    # Step 1: prune leaves that are not in x | y | z until none remain.
    # Removing a leaf may turn its parents into leaves, so they are queued.
    leaves = deque(n for n in G_copy.nodes if G_copy.out_degree[n] == 0)
    while leaves:
        leaf = leaves.popleft()
        if leaf in union_xyz:
            continue
        # Snapshot the parents before the node (and its in-edges) disappear.
        parents = list(G_copy.predecessors(leaf))
        G_copy.remove_node(leaf)
        # Check the out-degree *after* removal rather than testing for 1
        # beforehand: in a MultiDiGraph a parent whose only child was `leaf`
        # may be linked to it by several parallel edges.
        for parent in parents:
            if G_copy.out_degree[parent] == 0:
                leaves.append(parent)

    # Step 2: remove outgoing edges from the conditioning set.
    G_copy.remove_edges_from(list(G_copy.out_edges(z)))

    # Step 3: merge each weakly connected component, then all of `x` and all
    # of `y`; they share a representative iff some node of `x` is connected
    # to some node of `y`.
    disjoint_set = UnionFind(G_copy.nodes())
    for component in nx.weakly_connected_components(G_copy):
        disjoint_set.union(*component)
    disjoint_set.union(*x)
    disjoint_set.union(*y)

    return disjoint_set[next(iter(x))] != disjoint_set[next(iter(y))]