"""Mixin classes for estimators from .base.py."""
from io import UnsupportedOperation
from typing import Union, Callable
from numpy import (
issubdtype,
integer,
ndarray,
asarray,
mean as np_mean,
std,
sum as np_sum,
nan,
)
from numpy.random import default_rng
from infomeasure import Config
from infomeasure.utils.config import logger
from infomeasure.utils.data import StatisticalTestResult
[docs]
class RandomGeneratorMixin:
"""Mixin for random state generation.
Attributes
----------
rng : Generator
The random state generator.
"""
def __init__(self, *args, seed=None, **kwargs):
"""Initialize the random state generator."""
self.rng = default_rng(seed)
super().__init__(*args, **kwargs)
[docs]
class StatisticalTestingMixin(RandomGeneratorMixin):
"""Mixin for comprehensive statistical testing including *p*-values, *t*-scores,
and confidence intervals.
There are two methods to perform statistical tests:
- Permutation test: shuffle the data and calculate the measure.
- Bootstrap: resample the data and calculate the measure.
The :func:`statistical_test` method provides comprehensive statistical analysis
including *p*-value, *t*-score, and confidence intervals in a single call.
To be used as a mixin class with other :class:`Estimator` Estimator classes.
Inherit before the main class.
Notes
-----
The permutation test is a non-parametric statistical test to determine if the
observed effect is significant. The null hypothesis is that the measure is
not different from random, and the *p*-value is the proportion of permuted
measures greater than the observed measure.
Confidence intervals are calculated using percentiles of the null distribution
from the resampling procedure.
Raises
------
NotImplementedError
If the statistical test is not implemented for the estimator.
"""
def __init__(self, *args, **kwargs):
"""Initialize the statistical test mixin."""
self.original_data = None
super().__init__(*args, **kwargs)
if not any(
name in [cls.__name__ for cls in self.__class__.__mro__]
for name in [
"MutualInformationEstimator",
"ConditionalMutualInformationEstimator",
"TransferEntropyEstimator",
"ConditionalTransferEntropyEstimator",
]
):
raise NotImplementedError(
"Statistical test is not implemented for the estimator."
)
[docs]
def statistical_test(
self, n_tests: int = None, method: str = None
) -> StatisticalTestResult:
"""Perform comprehensive statistical test including *p*-value, *t*-score,
and confidence intervals.
Method can be "permutation_test" or "bootstrap".
- Permutation test: shuffle the data and calculate the measure.
- Bootstrap: resample the data and calculate the measure.
Parameters
----------
n_tests : int, optional
Number of permutations or bootstrap samples.
Needs to be a positive integer.
Default is the value set in the configuration.
method : str, optional
The method to calculate the statistical test.
Options are "permutation_test" or "bootstrap".
Default is the value set in the configuration.
Returns
-------
~infomeasure.utils.data.StatisticalTestResult
Comprehensive statistical test result containing *p*-value, *t*-score,
and metadata. Percentiles can be calculated on demand using
the percentile() method.
Raises
------
ValueError
If the chosen method is unknown.
io.UnsupportedOperation
If the statistical test is not supported for the estimator type.
"""
method, n_tests, test_values = self._statistical_test(method, n_tests)
# Make a test result
return self._statistical_test_result(
observed_value=self.global_val(),
test_values=test_values,
n_tests=n_tests,
method=method,
)
def _statistical_test(self, method, n_tests):
# Set defaults
if n_tests is None:
n_tests = Config.get("statistical_test_n_tests")
if method is None:
method = Config.get("statistical_test_method")
logger.debug(
"Calculating statistical test "
f"of the measure {self.__class__.__name__} "
f"using the {method} method with {n_tests} tests."
)
# Validate inputs
if not issubdtype(type(n_tests), integer) or n_tests < 1:
raise ValueError(
"Number of tests must be a positive integer, "
f"not {n_tests} ({type(n_tests)})."
)
class_names = [cls.__name__ for cls in self.__class__.__mro__]
if any(
name in class_names
for name in [
"MutualInformationEstimator",
"ConditionalMutualInformationEstimator",
]
):
if len(self.data) != 2:
raise UnsupportedOperation(
"Statistical test on mutual information is only supported "
"for two variables."
)
test_method = self._test_mi
elif any(
name in class_names
for name in [
"TransferEntropyEstimator",
"ConditionalTransferEntropyEstimator",
]
):
test_method = self._test_te
else:
raise NotImplementedError(
"Statistical test is not implemented for this estimator."
)
# Generate test values and calculate comprehensive result
test_values = test_method(n_tests, method)
return method, n_tests, test_values
@staticmethod
def _statistical_test_result(
observed_value: float,
test_values: Union[ndarray, list, tuple],
n_tests: int,
method: str,
) -> StatisticalTestResult:
"""
Calculate comprehensive statistical test result including *p*-value, *t*-score,
and confidence intervals.
Parameters
----------
observed_value : float
The observed value.
test_values : array-like
The test values from permutation/bootstrap sampling.
n_tests : int
Number of tests performed (permutations or bootstrap samples).
method : str
The statistical test method used ("permutation_test" or "bootstrap").
Returns
-------
StatisticalTestResult
Comprehensive statistical test result object.
Raises
------
ValueError
If the observed value is not numeric.
ValueError
If the test values are not array-like.
"""
# Input validation
if not isinstance(observed_value, (int, float)):
raise ValueError("Observed value must be numeric.")
if not isinstance(test_values, (ndarray, list, tuple)):
raise ValueError("Test values must be array-like.")
if len(test_values) < 2:
raise ValueError("Not enough test values for statistical test.")
test_values = asarray(test_values)
# Calculate basic statistics
null_mean = np_mean(test_values)
null_std = std(test_values, ddof=1) # Unbiased estimator (dividing by N-1)
# Compute *p*-value: proportion of test values greater than the observed value
p_value = np_sum(test_values > observed_value) / len(test_values)
# Compute *t*-score
t_score = (observed_value - null_mean) / null_std if null_std > 0 else nan
return StatisticalTestResult(
p_value=p_value,
t_score=t_score,
test_values=test_values.copy(),
observed_value=float(observed_value),
null_mean=null_mean,
null_std=null_std,
n_tests=n_tests,
method=method,
)
def _calculate_mi_with_data_selection(self, method_resample_src: Callable):
"""Calculate the measure for the resampled data using specific method."""
if len(self.original_data) != 2:
raise ValueError(
"MI with data selection is only supported for two variables."
)
# Shuffle the data
self.data = (
method_resample_src(self.original_data[0]),
self.original_data[1],
)
# Calculate the measure
res_permuted = self._calculate()
return (
res_permuted if isinstance(res_permuted, float) else np_mean(res_permuted)
)
def _test_mi(self, n_tests: int, method: str) -> ndarray:
"""Generate test values for mutual information using permutation test or
bootstrap.
Parameters
----------
n_tests : int
The number of permutations or bootstrap samples.
method : str
The method to use ("permutation_test" or "bootstrap").
Returns
-------
ndarray
Array of test values from resampling.
Raises
------
ValueError
If the method is invalid.
"""
# Store unshuffled data
self.original_data = self.data
# Set up resampling method
if method == "permutation_test":
method_resample_src = lambda data_src: self.rng.permutation(
data_src, axis=0
)
elif method == "bootstrap":
method_resample_src = lambda data_src: self.rng.choice(
data_src, size=data_src.shape[0], replace=True, axis=0
)
else:
raise ValueError(f"Invalid statistical test method: {method}.")
# Generate test values
permuted_values = [
self._calculate_mi_with_data_selection(method_resample_src)
for _ in range(n_tests)
]
# Restore the original data
self.data = self.original_data
return asarray(permuted_values)
def _test_te(self, n_tests: int, method: str) -> ndarray:
"""Generate test values for transfer entropy using permutation test or
bootstrap.
Parameters
----------
n_tests : int
The number of permutations or bootstrap samples.
method : str
The method to use ("permutation_test" or "bootstrap").
Returns
-------
ndarray
Array of test values from resampling.
Raises
------
ValueError
If the method is invalid.
"""
# Set up resampling method
if method == "permutation_test":
self.permute_src = self.rng
self.resample_src = False
elif method == "bootstrap":
self.permute_src = False
self.resample_src = self.rng
else:
raise ValueError(f"Invalid statistical test method: {method}.")
# Generate test values
permuted_values = [self._calculate() for _ in range(n_tests)]
if isinstance(permuted_values[0], ndarray):
permuted_values = [np_mean(x) for x in permuted_values]
# Deactivate the permutation/resample flags
self.permute_src, self.resample_src = False, False
return asarray(permuted_values)
[docs]
class DiscreteMIMixin:
"""Mixin for handling discrete mutual information computations.
Provides utilities and checks necessary for estimating discrete mutual
information and conditional mutual information.
Ensures that input data is suitable for these calculations and provides warnings
when pre-processing steps, such as symbolizing or discretizing, are required.
Attributes
----------
data : Any
The primary data to be used in mutual information estimation.
It should be symbolized or discretized if it contains floating-point types.
cond : Any, optional
The conditional data for conditional mutual information estimation.
If provided, it should also be symbolized or discretized if it contains
floating-point types.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _check_data_mi(self):
"""Check the input data for discrete mutual information calculations.
Verifies the types of the data attribute and condition attribute (if present)
to ensure they are suitable for mutual information estimation. Warns if the
data contains floating-point types and suggests appropriate transformations.
Notes
-----
This method checks if the data attribute contains elements with a floating-point
data type. If such types are detected, it logs a warning suggesting the need for
symbolization or discretization for mutual information calculations. Similarly,
if the `cond` attribute is present and has a floating-point type, the method logs
a warning suggesting preprocessing for conditional mutual information calculations.
This step ensures the validity and reliability of the mutual information estimation.
Attributes
----------
data : Any
The primary data used in the computation. It must be symbolized or
discretized for mutual information estimation if floating-point types
are present.
cond : Any, optional
The conditional data used for conditional mutual information estimation.
If present, it must also be symbolized or discretized if it contains
elements of floating-point types.
"""
if any(var.dtype.kind == "f" for var in self.data):
logger.warning(
"The data looks like a float array ("
f"{[var.dtype for var in self.data]}). "
"Make sure it is properly symbolized or discretized "
"for the mutual information estimation."
)
if hasattr(self, "cond") and self.cond.dtype.kind == "f":
logger.warning(
"The conditional data looks like a float array ("
f"{self.cond.dtype}). "
"Make sure it is properly symbolized or discretized "
"for the conditional mutual information estimation."
)
[docs]
class DiscreteTEMixin:
"""
Mixin class for discrete transfer entropy calculations.
Provides functionality to validate input data types for transfer entropy
estimation processes. Ensures that source, destination, and conditional
datasets are properly symbolized or discretized to prevent invalid results
from using continuous floating-point data.
Attributes
----------
source : array-like
The source data array utilized in transfer entropy calculations.
dest : array-like
The destination data array utilized in transfer entropy calculations.
cond : array-like, optional
The conditional data array utilized in transfer entropy calculations
when applicable.
"""
def _check_data_te(self):
"""Check the input data for discrete transfer entropy calculations.
Checks the data types of the source, destination, and conditional data
attributes involved in the transfer entropy estimation process.
Issues warnings if any of these datasets are floating-point,
as they may need proper symbolization or discretization in order to ensure the
validity of the calculations.
Notes
-----
Transfer entropy estimation requires input data to be symbolized or discretized,
as raw continuous floating-point arrays may lead to incorrect results.
This method specifically warns users when it detects floating-point arrays for
critical data inputs (source, destination, or conditional data).
Attributes
----------
source : array-like
The source data array whose data type is validated in this method.
dest : array-like
The destination data array whose data type is validated in this method.
cond : array-like, optional
The conditional data array whose data type is validated in this method if
present.
Warnings
--------
Issues a warning when the data type of `source` or `dest` is floating-point.
If the conditional data array (`cond`) exists and its data type is
floating-point, a separate warning will be issued.
"""
if self.source.dtype.kind == "f" or self.dest.dtype.kind == "f":
logger.warning(
"The data looks like a float array ("
f"source: {self.source.dtype}, dest: {self.dest.dtype}). "
"Make sure the data is properly symbolized or discretized "
"for the transfer entropy estimation."
)
if hasattr(self, "cond") and self.cond.dtype.kind == "f":
logger.warning(
"The conditional data looks like a float array ("
f"{self.cond.dtype}). "
"Make sure the data is properly symbolized or discretized "
"for the conditional transfer entropy estimation."
)
[docs]
class EffectiveValueMixin(StatisticalTestingMixin):
"""Mixin for effective value calculation.
To be used as a mixin class with :class:`TransferEntropyEstimator` derived classes.
Inherit before the main class.
Attributes
----------
res_effective : float | None
The effective transfer entropy.
Notes
-----
The effective value is the difference between the original
value and the value calculated for the permuted data.
"""
def __init__(self, *args, **kwargs):
"""Initialize the estimator with the effective value."""
self.res_effective = None
super().__init__(*args, **kwargs)
[docs]
def effective_val(self, method: str = None):
"""Return the effective value.
Calculates the effective value if not already done,
otherwise returns the stored value.
Returns
-------
effective : float
The effective value.
"""
_, _, test_values = self._statistical_test(n_tests=1, method=method)
return self.global_val() - test_values[0]
[docs]
class WorkersMixin:
"""Mixin that adds an attribute for the numbers of workers to use.
Attributes
----------
n_workers : int, optional
The number of workers to use. Default is 1.
-1: Use as many workers as CPU cores available.
"""
def __init__(self, *args, workers=1, **kwargs):
if workers == -1:
from multiprocessing import cpu_count
workers = cpu_count()
super().__init__(*args, **kwargs)
self.n_workers = workers