Source code for infomeasure.estimators.transfer_entropy.kraskov_stoegbauer_grassberger

"""Module for the Kraskov-Stoegbauer-Grassberger (KSG) transfer entropy estimator."""

from abc import ABC

from numpy import array, inf, ndarray, log, nextafter
from scipy.spatial import KDTree
from scipy.special import digamma

from ..base import (
    TransferEntropyEstimator,
    ConditionalTransferEntropyEstimator,
)
from ..utils.te_slicing import te_observations, cte_observations
from ... import Config
from ...utils.types import LogBaseType


[docs] class BaseKSGTEEstimator(ABC): r"""Base class for transfer entropy using the Kraskov-Stoegbauer-Grassberger (KSG) method. Attributes ---------- source, dest : array-like The source (X) and destination (Y) data used to estimate the transfer entropy. cond The conditional data used to estimate the conditional transfer entropy. k : int Number of nearest neighbors to consider. noise_level : float, None or False Standard deviation of Gaussian noise to add to the data. Adds :math:`\mathcal{N}(0, \text{noise}^2)` to each data point. minkowski_p : float, :math:`1 \leq p \leq \infty` The power parameter for the Minkowski metric. Default is np.inf for maximum norm. Use 2 for Euclidean distance. ksg_id : int, default=1 The KSG estimator variant to use (1 or 2). Type I uses strict inequality for neighbour counting and the corresponding formula. Type II uses non-strict inequality and a different formula. prop_time : int, optional Number of positions to shift the data arrays relative to each other (multiple of ``step_size``). Delay/lag/shift between the variables, representing propagation time. Assumed time taken by info to transfer from source to destination. Not compatible with the ``cond`` parameter / conditional TE. Alternatively called `offset`. step_size : int, optional Step size between elements for the state space reconstruction. src_hist_len, dest_hist_len : int, optional Number of past observations to consider for the source and destination data. cond_hist_len : int, optional Number of past observations to consider for the conditional data. Only used for conditional transfer entropy. """ def __init__( self, source, dest, *, # Enforce keyword-only arguments cond=None, k: int = 4, ksg_id: int = 1, noise_level=1e-8, minkowski_p=inf, prop_time: int = 0, step_size: int = 1, src_hist_len: int = 1, dest_hist_len: int = 1, cond_hist_len: int = 1, offset: int = None, base: LogBaseType = Config.get("base"), **kwargs, ): r"""Initialize the BaseKSGTEEstimator. Parameters ---------- source, dest : array-like The source (X) and destination (Y) data used to estimate the transfer entropy. cond The conditional data used to estimate the conditional transfer entropy. k : int Number of nearest neighbors to consider. noise_level : float, None or False Standard deviation of Gaussian noise to add to the data. Adds :math:`\mathcal{N}(0, \text{noise}^2)` to each data point. minkowski_p : float, :math:`1 \leq p \leq \infty` The power parameter for the Minkowski metric. Default is np.inf for maximum norm. Use 2 for Euclidean distance. ksg_id : int, default=1 The KSG estimator variant to use (1 or 2). prop_time : int, optional Number of positions to shift the data arrays relative to each other (multiple of ``step_size``). Delay/lag/shift between the variables, representing propagation time. Assumed time taken by info to transfer from source to destination. Not compatible with the ``cond`` parameter / conditional TE. Alternatively called `offset`. step_size : int, optional Step size between elements for the state space reconstruction. src_hist_len, dest_hist_len : int, optional Number of past observations to consider for the source and destination data. cond_hist_len : int, optional Number of past observations to consider for the conditional data. Only used for conditional transfer entropy. """ if cond is None: super().__init__( source, dest, prop_time=prop_time, src_hist_len=src_hist_len, dest_hist_len=dest_hist_len, step_size=step_size, offset=offset, base=base, **kwargs, ) else: super().__init__( source, dest, cond=cond, step_size=step_size, src_hist_len=src_hist_len, dest_hist_len=dest_hist_len, cond_hist_len=cond_hist_len, prop_time=prop_time, offset=offset, base=base, **kwargs, ) self.k = k if ksg_id not in {1, 2}: raise ValueError(f"ksg_id must be 1 or 2, but got {ksg_id}.") self.ksg_id = ksg_id self.noise_level = noise_level self.minkowski_p = minkowski_p
[docs] class KSGTEEstimator(BaseKSGTEEstimator, TransferEntropyEstimator): r"""Estimator for transfer entropy using the Kraskov-Stoegbauer-Grassberger (KSG) method. Attributes ---------- source, dest : array-like The source (X) and destination (Y) data used to estimate the transfer entropy. k : int Number of nearest neighbors to consider. noise_level : float, None or False Standard deviation of Gaussian noise to add to the data. Adds :math:`\mathcal{N}(0, \text{noise}^2)` to each data point. minkowski_p : float, :math:`1 \leq p \leq \infty` The power parameter for the Minkowski metric. Default is np.inf for maximum norm. Use 2 for Euclidean distance. prop_time : int, optional Number of positions to shift the data arrays relative to each other (multiple of ``step_size``). Delay/lag/shift between the variables, representing propagation time. Assumed time taken by info to transfer from source to destination. Alternatively called `offset`. step_size : int, optional Step size between elements for the state space reconstruction. src_hist_len, dest_hist_len : int Number of past observations to consider for the source and destination data. Notes ----- Changing the number of nearest neighbors ``k`` can change the outcome, but the default value of :math:`k=4` is recommended by :cite:p:`miKSG2004`. """ def _calculate(self) -> ndarray: """Calculate the transfer entropy of the data. Returns ------- local_te : array Local transfer entropy from X to Y for each point. """ # Ensure source and dest are numpy arrays source = self.source.astype(float).copy() dest = self.dest.astype(float).copy() # Add Gaussian noise to the data if the flag is set if self.noise_level: dest += self.rng.normal(0, self.noise_level, dest.shape) source += self.rng.normal(0, self.noise_level, source.shape) ( joint_space_data, data_dest_past_embedded, marginal_1_space_data, marginal_2_space_data, ) = te_observations( source, dest, src_hist_len=self.src_hist_len, dest_hist_len=self.dest_hist_len, step_size=self.step_size, permute_src=self.permute_src, resample_src=self.resample_src, ) # Create KDTree for efficient nearest neighbor search in joint space tree_joint = KDTree(joint_space_data) # Find distances to the k-th nearest neighbor in the joint space distances, _ = tree_joint.query( joint_space_data, k=self.k + 1, p=self.minkowski_p ) kth_distances = distances[:, -1] # get last column with k-th distances # Count points in marginal spaces tree_dest_past_present = KDTree(marginal_2_space_data) tree_source_past_dest_past = KDTree(marginal_1_space_data) tree_dest_past = KDTree(data_dest_past_embedded) if self.ksg_id == 1: r_strict = nextafter(kth_distances, -inf) count_dest_past_present = tree_dest_past_present.query_ball_point( marginal_2_space_data, r=r_strict, p=self.minkowski_p, return_length=True, ) - (kth_distances > 0).astype(int) count_source_past_dest_past = tree_source_past_dest_past.query_ball_point( marginal_1_space_data, r=r_strict, p=self.minkowski_p, return_length=True, ) - (kth_distances > 0).astype(int) count_dest_past = tree_dest_past.query_ball_point( data_dest_past_embedded, r=r_strict, p=self.minkowski_p, return_length=True, ) - (kth_distances > 0).astype(int) # Compute local transfer entropy local_te = ( digamma(self.k) - digamma(array(count_dest_past_present) + 1) - digamma(array(count_source_past_dest_past) + 1) + digamma(array(count_dest_past) + 1) ) else: count_dest_past_present = tree_dest_past_present.query_ball_point( marginal_2_space_data, r=kth_distances, p=self.minkowski_p, return_length=True, ) count_source_past_dest_past = tree_source_past_dest_past.query_ball_point( marginal_1_space_data, r=kth_distances, p=self.minkowski_p, return_length=True, ) count_dest_past = tree_dest_past.query_ball_point( data_dest_past_embedded, r=kth_distances, p=self.minkowski_p, return_length=True, ) # Compute local transfer entropy local_te = ( digamma(self.k) - 1.0 / self.k - digamma(array(count_dest_past_present)) - digamma(array(count_source_past_dest_past)) + digamma(array(count_dest_past)) ) return local_te / log(self.base) if self.base != "e" else local_te
[docs] class KSGCTEEstimator(BaseKSGTEEstimator, ConditionalTransferEntropyEstimator): """Estimator for conditional transfer entropy using the Kraskov-Stoegbauer-Grassberger (KSG) method. Attributes ---------- source, dest, cond : array-like The source (X), destination (Y), and conditional (Z) data used to estimate the conditional transfer entropy. k : int Number of nearest neighbors to consider. noise_level : float, None or False Standard deviation of Gaussian noise to add to the data. Adds :math:`\\mathcal{N}(0, \text{noise}^2)` to each data point. minkowski_p : float, :math:`1 \\leq p \\leq \\infty` The power parameter for the Minkowski metric. Default is np.inf for maximum norm. Use 2 for Euclidean distance. step_size : int Step size between elements for the state space reconstruction. src_hist_len, dest_hist_len, cond_hist_len : int, optional Number of past observations to consider for the source, destination, and conditional data. prop_time : int, optional Not compatible with the ``cond`` parameter / conditional TE. Notes ----- The estimator supports two variants: - **Type I** (``ksg_id=1``): Uses strict inequality for counting neighbors in marginal spaces (dist < eps). - **Type II** (``ksg_id=2``): Uses non-strict inequality (dist <= eps) and a modified formula. Changing the number of nearest neighbors ``k`` can change the outcome, but the default value of :math:`k=4` is recommended by :cite:p:`miKSG2004`. """ def _calculate(self) -> ndarray: """Calculate the conditional transfer entropy of the data. Returns ------- local_cte : array Local conditional transfer entropy from X to Y given Z for each point. """ # Ensure source, dest, and cond are numpy arrays source = self.source.astype(float).copy() dest = self.dest.astype(float).copy() cond = self.cond.astype(float).copy() # Add Gaussian noise to the data if the flag is set if self.noise_level: dest += self.rng.normal(0, self.noise_level, dest.shape) source += self.rng.normal(0, self.noise_level, source.shape) cond += self.rng.normal(0, self.noise_level, cond.shape) ( joint_space_data, data_dest_past_embedded, marginal_1_space_data, marginal_2_space_data, ) = cte_observations( source, dest, cond, src_hist_len=self.src_hist_len, dest_hist_len=self.dest_hist_len, cond_hist_len=self.cond_hist_len, step_size=self.step_size, ) # Create KDTree for efficient nearest neighbor search in joint space tree_joint = KDTree(joint_space_data) # Find distances to the k-th nearest distances, _ = tree_joint.query( joint_space_data, k=self.k + 1, p=self.minkowski_p ) kth_distances = distances[:, -1] # Count points in marginal spaces tree_cond_dest_past_present = KDTree(marginal_2_space_data) tree_source_past_cond_dest_past = KDTree(marginal_1_space_data) tree_dest_past_cond = KDTree(data_dest_past_embedded) if self.ksg_id == 1: r_strict = nextafter(kth_distances, -inf) count_cond_dest_past_present = tree_cond_dest_past_present.query_ball_point( marginal_2_space_data, r=r_strict, p=self.minkowski_p, return_length=True, ) - (kth_distances > 0).astype(int) count_source_past_cond_dest_past = ( tree_source_past_cond_dest_past.query_ball_point( marginal_1_space_data, r=r_strict, p=self.minkowski_p, return_length=True, ) - (kth_distances > 0).astype(int) ) count_dest_past_cond = tree_dest_past_cond.query_ball_point( data_dest_past_embedded, r=r_strict, p=self.minkowski_p, return_length=True, ) - (kth_distances > 0).astype(int) # Compute local conditional transfer entropy local_cte = ( digamma(self.k) - digamma(array(count_cond_dest_past_present) + 1) - digamma(array(count_source_past_cond_dest_past) + 1) + digamma(array(count_dest_past_cond) + 1) ) else: count_cond_dest_past_present = tree_cond_dest_past_present.query_ball_point( marginal_2_space_data, r=kth_distances, p=self.minkowski_p, return_length=True, ) count_source_past_cond_dest_past = ( tree_source_past_cond_dest_past.query_ball_point( marginal_1_space_data, r=kth_distances, p=self.minkowski_p, return_length=True, ) ) count_dest_past_cond = tree_dest_past_cond.query_ball_point( data_dest_past_embedded, r=kth_distances, p=self.minkowski_p, return_length=True, ) # Compute local conditional transfer entropy local_cte = ( digamma(self.k) - 1.0 / self.k - digamma(array(count_cond_dest_past_present)) - digamma(array(count_source_past_cond_dest_past)) + digamma(array(count_dest_past_cond)) ) local_cte = array(local_cte) return local_cte / log(self.base) if self.base != "e" else local_cte