# Source code for infomeasure.estimators.utils.discrete_interaction_information
"""Functions for interaction information, a multivariate generalization of
mutual information."""
from collections import Counter
from numpy import (
clip,
uint64,
log,
ndarray,
ones,
prod,
ravel,
unique,
zeros,
count_nonzero,
)
from numpy import (
sum as np_sum,
)
from scipy.sparse import find as sp_find
from scipy.stats.contingency import crosstab
from sparse import COO, asnumpy
[docs]
def mutual_information_global(
*data: tuple,
log_func: callable = log,
miller_madow_correction: str | float | int = None,
) -> float:
"""Estimate the global mutual information between multiple random variables.
Parameters
----------
*data : array-like, shape (n_samples,)
The data used to estimate the global mutual information.
You can pass an arbitrary number of data arrays as positional arguments.
log_func : callable, optional
The logarithm function to use. Default is the natural logarithm.
miller_madow_correction : str | float | int, optional
If not None, apply the Miller-Madow correction to the global mutual
information in the information unit of the passed value.
``log_func`` and ``miller_madow_correction`` should be the same base.
Returns
-------
float
The global mutual information between the random variables.
"""
# check that miller_madow_correction is None, float, int or "e"
if miller_madow_correction is not None and (
not isinstance(miller_madow_correction, (str, float, int))
or (isinstance(miller_madow_correction, str) and miller_madow_correction != "e")
):
raise ValueError(
f"miller_madow_correction must be None, float or 'e', "
f"got {miller_madow_correction}."
)
if all(d.ndim == 1 for d in data):
if len(data) == 2:
return _mutual_information_global_2d_int(
*data,
log_func=log_func,
miller_madow_correction=miller_madow_correction,
)
else:
return _mutual_information_global_nd_int(
*data,
log_func=log_func,
miller_madow_correction=miller_madow_correction,
)
else:
return _mutual_information_global_nd_other(
*data,
log_func=log_func,
miller_madow_correction=miller_madow_correction,
)
def _mutual_information_global_nd_int(
    *data: tuple,
    log_func: callable = log,
    miller_madow_correction: str | float | int = None,
) -> float:
    """Estimate the global mutual information between any number of variables.

    A sparse contingency table is built from the per-variable unique codes,
    and only its stored (observed) cells enter the sum.

    Parameters
    ----------
    *data : array-like
        One array of observations per random variable.
    log_func : callable, optional
        Logarithm used in the estimate. Default is the natural logarithm.
    miller_madow_correction : str | float | int, optional
        If not None, the Miller-Madow correction term for this base is added.

    Returns
    -------
    float
        The (unclipped) global mutual information; ``0.0`` when at least one
        variable takes a single value only.
    """
    # Encode each variable as integer codes; the codes are the COO coordinates.
    levels, codes = zip(*[unique(var, return_inverse=True, axis=0) for var in data])
    table = COO(
        coords=codes,
        data=ones(len(codes[0]), dtype=uint64),
        shape=tuple(len(level) for level in levels),
        fill_value=0,
    )
    # Coordinates and counts of the stored cells
    cell_coords = table.nonzero()
    cell_counts = table.data

    # Marginal counts of variable i: sum the table over every other axis.
    all_axes = tuple(range(table.ndim))
    marginal_counts = [
        asnumpy(table.sum(axis=all_axes[:axis] + all_axes[axis + 1 :]))
        for axis in range(table.ndim)
    ]
    # A variable with a single level has zero entropy, hence zero MI.
    if any(marginal.size == 1 for marginal in marginal_counts):
        return 0.0

    # Product of the marginal counts, evaluated at the stored cells only.
    marginal_product = prod(
        [
            marginal.take(coord).astype(uint64, copy=False)
            for marginal, coord in zip(marginal_counts, cell_coords)
        ],
        axis=0,
    )

    n_total = table.sum()
    p_joint = cell_counts / n_total
    log_counts = log_func(cell_counts)
    # -log(prod_i N_i) + k * log(N), i.e. minus the log of the product of
    # marginal probabilities.
    log_marginals = -log_func(marginal_product) + len(marginal_counts) * log_func(
        n_total
    )
    terms = p_joint * (log_counts - log_func(n_total)) + p_joint * log_marginals
    # Interaction information can be negative, so the sum is not clipped.
    total = terms.sum()

    if miller_madow_correction is None:
        return total
    return total + millermadow_mi_corr(
        table.shape, len(cell_counts), len(data[0]), miller_madow_correction
    )
def _mutual_information_global_2d_int(
*data: tuple[ndarray[int]],
log_func: callable = log,
miller_madow_correction: str | float | int = None,
) -> float:
"""Estimate the global mutual information between two random variables.
The approach relies on the contingency table of the two variables.
Instead of calculating the full outer product, only the non-zero elements are
considered.
Code adapted from
the :func:`mutual_info_score() <sklearn.metrics.mutual_info_score>` function in
scikit-learn.
"""
# Contingency table - COOrdinate sparse matrix
contingency_coo = crosstab(*data, sparse=True).count
# Non-zero indices and values
nzx, nzy, nzv = sp_find(contingency_coo)
# Normalized contingency table (joint probability)
contingency_sum = contingency_coo.sum()
p_joint = nzv / contingency_sum
# Marginal probabilities
pi = ravel(contingency_coo.sum(axis=1))
pj = ravel(contingency_coo.sum(axis=0))
# Early return if any of the marginal entropies is zero
if pi.size == 1 or pj.size == 1:
return 0.0
# Logarithm of the non-zero elements
log_p_joint = log_func(nzv)
# Calculate the expected logarithm values for the outer product of marginal
# probabilities, only for non-zero entries.
outer = pi.take(nzx).astype(uint64, copy=False) * pj.take(nzy).astype(
uint64, copy=False
)
log_outer = -log_func(outer) + log_func(pi.sum()) + log_func(pj.sum())
# Combine the terms to calculate the mutual information
mi = p_joint * (log_p_joint - log_func(contingency_sum)) + p_joint * log_outer
misum = clip(mi.sum(), 0.0, None)
if miller_madow_correction is None:
return misum
else:
corr = millermadow_mi_corr(
contingency_coo.shape, len(nzv), len(data[0]), miller_madow_correction
)
return misum + corr
def _mutual_information_global_nd_other(
*data: tuple[ndarray],
log_func: callable = log,
miller_madow_correction: str | float | int = None,
) -> float:
"""Alternative method to estimate the global mutual information between an
arbitrary number of random variables.
Same as :func:`_mutual_information_global_nd_int`, but for non-integer data.
"""
# data is a tuple of ndarrays, for joint data, concatenate these rows
# joint_data = [tuple(row) for row in column_stack(data)]
joint_data = [tuple(tuple(val) for val in row) for row in zip(*data)]
# Count joint and marginal occurrences
joint_counts = Counter(joint_data)
joint_total = sum(joint_counts.values())
marginal_counts = [Counter([tuple(val) for val in var]) for var in data]
marginal_totals = [sum(counts.values()) for counts in marginal_counts]
# Estimate probabilities
joint_prob = {key: val / joint_total for key, val in joint_counts.items()}
marginal_prob = [
{key: val / total for key, val in counts.items()}
for counts, total in zip(marginal_counts, marginal_totals)
]
# Calculate the mutual information
mi_sum = [
joint_prob[key]
* log_func(
joint_prob[key]
/ prod([marginal_prob[i][key[i]] for i in range(len(data))], axis=0)
)
for key in joint_prob
]
if len(mi_sum) == 0:
return 0.0
misum = np_sum(mi_sum)
if miller_madow_correction is None:
return misum
else:
corr = millermadow_mi_corr(
list(len(counter.keys()) for counter in marginal_counts),
len(joint_counts),
len(data[0]),
miller_madow_correction,
)
return misum + corr
[docs]
def mutual_information_local(
*data: tuple,
log_func: callable = log,
miller_madow_correction: str | float | int = None,
) -> ndarray:
"""Estimate the local mutual information between multiple random variables.
The mean of the local mutual information is the global mutual information.
Only calculating the global value is more efficient,
so evaluating the local mutual information should only be done
when explicitly needed.
Parameters
----------
*data : array-like, shape (n_samples,)
The data used to estimate the local mutual information.
You can pass an arbitrary number of data arrays as positional arguments.
log_func : callable, optional
The logarithm function to use. Default is the natural logarithm.
miller_madow_correction : str | float | int, optional
If not None, apply the Miller-Madow correction to the global mutual
information in the information unit of the passed value.
``log_func`` and ``miller_madow_correction`` should be the same base.
Returns
-------
ndarray
The local mutual information between the random variables.
"""
# check that miller_madow_correction is None, float, int or "e"
if miller_madow_correction is not None and (
not isinstance(miller_madow_correction, (str, float, int))
or (isinstance(miller_madow_correction, str) and miller_madow_correction != "e")
):
raise ValueError(
f"miller_madow_correction must be None, float or 'e', "
f"got {miller_madow_correction}."
)
# Contingency table - COOrdinate sparse matrix
uniques, indices = zip(*[unique(var, return_inverse=True, axis=0) for var in data])
contingency_coo = COO(
coords=indices,
data=ones(len(indices[0]), dtype=uint64),
shape=tuple(len(uniq) for uniq in uniques),
fill_value=0,
)
# Normalized contingency table (joint probability)
contingency_sum = contingency_coo.sum()
# Marginal probabilities
count_marginals = [
asnumpy(
contingency_coo.sum( # all axes, except i
axis=tuple(range(contingency_coo.ndim))[:i]
+ tuple(range(contingency_coo.ndim))[i + 1 :]
)
)
for i in range(contingency_coo.ndim)
]
# Early return if any of the marginal entropies is zero
if any(count_m.size == 1 for count_m in count_marginals):
return zeros(len(data[0]))
# To get local values we iterate over *data
# for each row in the input data: log( p(data) / p(data1) * p(data2) )
p_joint = contingency_coo[indices].data / contingency_sum
outer = prod(
[count[indices[i]] for i, count in enumerate(count_marginals)]
/ contingency_sum,
axis=0,
)
mi_local = -log_func(outer) + log_func(p_joint)
if miller_madow_correction is None:
return mi_local
else:
corr = millermadow_mi_corr(
contingency_coo.shape,
len(contingency_coo.data),
len(data[0]),
miller_madow_correction,
)
return mi_local + corr
[docs]
def conditional_mutual_information_global(
*data: tuple,
cond: ndarray,
log_func: callable = log,
miller_madow_correction: str | float | int = None,
) -> float:
"""Estimate the global conditional mutual information
between multiple random variables and a conditioning variable.
Parameters
----------
*data : array-like, shape (n_samples,)
The data used to estimate the global mutual information.
You can pass an arbitrary number of data arrays as positional arguments.
cond : array-like, shape (n_samples,)
The conditioning variable.
log_func : callable, optional
The logarithm function to use. Default is the natural logarithm.
miller_madow_correction : str | float | int, optional
If not None, apply the Miller-Madow correction to the global mutual
information in the information unit of the passed value.
``log_func`` and ``miller_madow_correction`` should be the same base.
Returns
-------
float
The global conditional mutual information between the random variables.
Notes
-----
If wanting a condition of joint random variables, one must join them beforehand
into one dimension. This is due to the complexity of the calculation, keeping it
arbitrary enough.
One can join discrete random variables through
:py:func:`~infomeasure.estimators.utils.ordinal.reduce_joint_space`, but when using
CMI and CTE, this will happen automatically when passing a tuple of RVs as the
``cond``.
Raises
------
ValueError
If the conditioning variable is not one-dimensional.
"""
# check that miller_madow_correction is None, float, int or "e"
if miller_madow_correction is not None and (
not isinstance(miller_madow_correction, (str, float, int))
or (isinstance(miller_madow_correction, str) and miller_madow_correction != "e")
):
raise ValueError(
f"miller_madow_correction must be None, float or 'e', "
f"got {miller_madow_correction}."
)
if cond.ndim != 1:
raise ValueError("The conditioning variable must be one-dimensional.")
return _conditional_mutual_information_global_nd_int(
*data,
cond=cond,
log_func=log_func,
miller_madow_correction=miller_madow_correction,
)
def _conditional_mutual_information_global_nd_int(
    *data: tuple,
    cond: ndarray,
    log_func: callable = log,
    miller_madow_correction: str | float | int = None,
) -> float:
    """Estimate the global conditional mutual information between any number
    of random variables, given a conditioning variable.

    A sparse contingency table over all variables plus the condition (as the
    last axis) is built, and only its stored cells enter the sum.

    Parameters
    ----------
    *data : array-like
        One array of observations per random variable.
    cond : ndarray
        The conditioning variable.
    log_func : callable, optional
        Logarithm used in the estimate. Default is the natural logarithm.
    miller_madow_correction : str | float | int, optional
        If not None, the Miller-Madow correction term for this base is added.

    Returns
    -------
    float
        The (unclipped) global conditional mutual information.
    """
    # Encode every variable (condition last) as integer codes.
    levels, codes = zip(
        *[unique(var, return_inverse=True, axis=0) for var in (data + (cond,))]
    )
    table = COO(
        coords=codes,
        data=ones(len(codes[0]), dtype=uint64),
        shape=tuple(len(level) for level in levels),
        fill_value=0,
    )
    # Coordinates and counts of the stored cells
    cell_coords = table.nonzero()
    cell_counts = table.data
    cond_coords = cell_coords[-1]

    # Counts of (variable i, condition): sum over every other data axis.
    data_axes = tuple(range(table.ndim - 1))
    pair_counts = [
        asnumpy(table.sum(axis=data_axes[:axis] + data_axes[axis + 1 :]))
        for axis in data_axes
    ]
    # Counts of the condition alone: sum over all data axes.
    cond_counts = asnumpy(table.sum(axis=data_axes))

    # Early return if any of the pair tables holds a single cell
    if any(pair.size == 1 for pair in pair_counts):
        return 0.0

    # Product of the (variable, condition) counts at the stored cells only.
    pair_product = prod(
        [
            pair[coord, cond_coords].astype(uint64, copy=False)
            for pair, coord in zip(pair_counts, cell_coords[:-1])
        ],
        axis=0,
    )

    n_total = table.sum()
    p_joint = cell_counts / n_total
    cond_at_cells = cond_counts.take(cond_coords).astype(uint64, copy=False)
    log_joint = log_func(cell_counts * cond_at_cells) - log_func(n_total)
    # NOTE(review): one log(N) is added per variable except the last; for two
    # variables this gives the usual CMI -- confirm the intended definition
    # for more than two variables.
    log_pairs = -log_func(pair_product) + sum(
        log_func(pair.sum()) for pair in pair_counts[:-1]
    )
    terms = p_joint * log_joint + p_joint * log_pairs
    # Interaction information can be negative, so the sum is not clipped.
    total = terms.sum()

    if miller_madow_correction is None:
        return total
    return total + millermadow_mi_corr(
        list(count_nonzero(pair) for pair in pair_counts),
        len(cell_counts),
        len(data[0]),
        miller_madow_correction,
        k_cond=len(cond_counts),
    )
[docs]
def conditional_mutual_information_local(
*data: tuple,
cond: ndarray,
log_func: callable = log,
miller_madow_correction: str | float | int = None,
) -> ndarray:
"""Estimate the local conditional mutual information between multiple
random variables and a conditioning variable.
The mean of the local conditional mutual information is the
global conditional mutual information.
Only calculating the global value is more efficient,
so evaluating the local conditional mutual information should only be done
when explicitly needed.
Parameters
----------
*data : array-like, shape (n_samples,)
The data used to estimate the local mutual information.
You can pass an arbitrary number of data arrays as positional arguments.
cond : array-like, shape (n_samples,)
The conditioning variable.
log_func : callable, optional
The logarithm function to use. Default is the natural logarithm.
miller_madow_correction : str | float | int, optional
If not None, apply the Miller-Madow correction to the global mutual
information in the information unit of the passed value.
``log_func`` and ``miller_madow_correction`` should be the same base.
Returns
-------
ndarray
The local conditional mutual information between the random variables.
"""
# check that miller_madow_correction is None, float, int or "e"
if miller_madow_correction is not None and (
not isinstance(miller_madow_correction, (str, float, int))
or (isinstance(miller_madow_correction, str) and miller_madow_correction != "e")
):
raise ValueError(
f"miller_madow_correction must be None, float or 'e', "
f"got {miller_madow_correction}."
)
# Contingency table - COOrdinate sparse matrix
uniques, indices = zip(
*[unique(var, return_inverse=True, axis=0) for var in (data + (cond,))]
)
contingency_coo = COO(
coords=indices,
data=ones(len(indices[0]), dtype=uint64),
shape=tuple(len(uniq) for uniq in uniques),
fill_value=0,
)
# Normalized contingency table (joint probability)
contingency_sum = contingency_coo.sum()
# Marginal-conditioned probabilities
count_marginals_cond = [
asnumpy(
contingency_coo.sum( # all axes, except i and cond
axis=tuple(range(contingency_coo.ndim - 1))[:i]
+ tuple(range(contingency_coo.ndim - 1))[i + 1 :]
)
)
for i in range(contingency_coo.ndim - 1)
]
count_cond = asnumpy(
contingency_coo.sum(axis=tuple(range(contingency_coo.ndim - 1)))
)
# Early return if any of the marginal entropies is zero
if any(count_m.size == 1 for count_m in count_marginals_cond):
return zeros(len(data[0])).astype(float)
# To get local values we iterate over *data
# for each row in the input data: log( p(data) / p(data1) * p(data2) )
p_joint = contingency_coo[indices].data / contingency_sum
p_cond = count_cond[indices[-1]] / contingency_sum
outer = prod(
[count[indices[i], indices[-1]] for i, count in enumerate(count_marginals_cond)]
/ contingency_sum,
axis=0,
)
misum = log_func(p_joint * p_cond) - log_func(outer)
if miller_madow_correction is None:
return misum
else:
corr = millermadow_mi_corr(
list(count_nonzero(marg) for marg in count_marginals_cond),
len(contingency_coo.data),
len(data[0]),
miller_madow_correction,
k_cond=len(count_cond),
)
return misum + corr
[docs]
def millermadow_mi_corr(k_i, k_joint, n, base, k_cond=None):
    """
    Compute the Miller-Madow correction term for mutual information.

    The term adjusts the finite-sample bias of a mutual information estimate
    and is built from the marginal and joint cardinalities of the observed
    distributions:

    ``(sum_i (K_i - 1) - (K_joint - 1) - (K_cond - 1)) / (2 N)``

    where the ``K_cond`` part is dropped when ``k_cond`` is None.

    Parameters
    ----------
    k_i : list[int]
        Marginal cardinalities, one entry per variable (the number of unique
        values of that variable).
    k_joint : int
        Cardinality of the joint distribution, i.e. the number of unique
        joint observations.
    n : int
        The sample size.
    base : str | float | int
        The logarithmic base of the mutual information. ``"e"`` leaves the
        term in natural units; any other value divides it by ``log(base)``.
    k_cond : int, optional
        Cardinality of the conditioning variable. When given, ``k_i`` can be
        passed as the cardinalities of the (variable, condition) pairs to
        obtain the correction for conditional MI.

    Returns
    -------
    float
        The Miller-Madow correction term to add to the mutual information.
    """
    marginal_term = sum(k_i) - len(k_i)  # sum_i (K_i - 1)
    cond_term = 0 if k_cond is None else k_cond - 1  # (K_cond - 1), if any
    # The remaining "- k_joint + 1" is -(K_joint - 1).
    numerator = marginal_term - k_joint + 1 - cond_term
    corr = numerator / (2 * n)
    if base == "e":
        return corr
    # Convert from natural units to the requested base.
    return corr / log(base)