Source code for infomeasure.estimators.utils.discrete_transfer_entropy

"""Functions for efficient computation of discrete transfer entropy."""

from numpy import log, ndarray

from .discrete_interaction_information import (
    conditional_mutual_information_global,
    conditional_mutual_information_local,
)
from .ordinal import reduce_joint_space


[docs] def combined_te_form( slice_method, *data, local: bool = False, log_func: callable = log, miller_madow_correction: str | float | int = None, **slice_kwargs, ) -> float | ndarray: """ Calculate the Transfer Entropy using the combined TE formula. Parameters ---------- slice_method : function The slicing method to use for the symbolized data. *data : array-like The source, destination, and if applicable, conditional data. local : bool, optional Whether to calculate the local transfer entropy. If False, the global transfer entropy is calculated. Default is False. log_func : callable, optional The logarithm function to use. Default is the natural logarithm. miller_madow_correction : str | float | int, optional If not None, apply the Miller-Madow correction to the global mutual information in the information unit of the passed value. ``log_func`` and ``miller_madow_correction`` should be the same base. **slice_kwargs : dict The history lengths for the source, destination, and if applicable, conditional data. Returns ------- float The Transfer Entropy value. """ cmi_func = ( conditional_mutual_information_local if local else conditional_mutual_information_global ) sliced_data = slice_method( *data, **slice_kwargs, construct_joint_spaces=False, ) if len(sliced_data) == 3: src_history, dest_history, dest_future = sliced_data return cmi_func( dest_future, src_history, cond=reduce_joint_space(dest_history), log_func=log_func, miller_madow_correction=miller_madow_correction, ) elif len(sliced_data) == 4: src_history, dest_history, dest_future, cond_history = sliced_data return cmi_func( dest_future, src_history, cond=reduce_joint_space((dest_history, cond_history)), log_func=log_func, miller_madow_correction=miller_madow_correction, ) else: raise ValueError( "Invalid number of data arrays. " "The slice method returned an invalid number of sliced data." )