Source code for infomeasure.estimators.utils.exponential_family

"""Helper functions for exponential family distributions.

Rényi entropy and Tsallis entropy are special cases within the more general
exponential family of distributions. This module provides helper
functions for these distributions.
"""

from numpy import asarray, exp as np_exp, mean as np_mean, pi
from scipy.spatial import KDTree
from scipy.special import digamma, gamma


def calculate_common_entropy_components(data, k):
    """Calculate common components for entropy estimators.

    Parameters
    ----------
    data : array-like
        The data used to estimate the entropy, one sample per row. The input
        is converted with ``numpy.asarray``; a one-dimensional input is
        treated as ``N`` samples of a single variable, i.e. reshaped to
        ``(N, 1)``.
    k : int
        The number of nearest neighbors used in the estimation.

    Returns
    -------
    tuple
        Volume of the unit ball, k-th nearest neighbor distances,
        number of data points, and dimensionality of the data.

    Raises
    ------
    ValueError
        If ``data`` is neither one- nor two-dimensional, if ``k`` is smaller
        than 1, or if the parameter ``k`` is selected too large.
    """
    # Accept anything array-like (lists, tuples, ...), not only ndarrays.
    data = asarray(data)
    if data.ndim == 1:
        # A flat vector is read as N scalar samples -> one column.
        data = data.reshape(-1, 1)
    elif data.ndim != 2:
        raise ValueError(
            "The data must be one- or two-dimensional, "
            f"but has {data.ndim} dimensions."
        )
    N, m = data.shape

    if k < 1:
        # With k = 0 the query below would only return the point itself and
        # the column indexing would fail with an opaque IndexError.
        raise ValueError("The number of nearest neighbors must be at least 1.")
    if k >= N:
        raise ValueError(
            "The number of nearest neighbors must be smaller "
            "than the number of data points."
        )

    # Volume of the unit ball in m-dimensional space
    V_m = pi ** (m / 2) / gamma(m / 2 + 1)

    # Build k-d tree for nearest neighbor search
    tree = KDTree(data)

    # Get the k-th nearest neighbor distances;
    # k+1 because the point itself is included as the closest match.
    distances, _ = tree.query(data, k=k + 1)
    rho_k = distances[:, k]

    return V_m, rho_k, N, m
def exponential_family_iq(k, q, V_m, rho_k, N, m):
    r"""Compute the :math:`I_q` term of the exponential family distribution.

    The value returned is the mean of :math:`\zeta_i^{1-q}` with
    :math:`\zeta_i = (N-1)\, C_k\, V_m\, \rho_{k,i}^m` and
    :math:`C_k = \left(\Gamma(k) / \Gamma(k+1-q)\right)^{1/(1-q)}`.

    Parameters
    ----------
    k : int
        The number of nearest neighbors used in the estimation.
    q : float | int
        The Rényi or Tsallis parameter, order or exponent.
        Sometimes denoted as :math:`\alpha` or :math:`q`.
        Should not be 1, since the exponent :math:`1/(1-q)` is undefined there.
    V_m : float
        Volume of the unit ball in m-dimensional space.
    rho_k : array-like
        The k-th nearest neighbor distances.
    N : int
        Number of data points.
    m : int
        Dimensionality of the data.

    Returns
    -------
    float
        The :math:`I_q` of the exponential family distribution
    """
    one_minus_q = 1 - q

    # C_k: ratio of gamma functions raised to 1 / (1 - q)
    gamma_ratio = gamma(k) / gamma(k + 1 - q)
    correction = gamma_ratio ** (1 / one_minus_q)

    # zeta_{N,i,k} for every sample i
    zeta = (N - 1) * correction * V_m * rho_k**m

    powered = zeta**one_minus_q
    return np_mean(powered)
def exponential_family_i1(k, V_m, rho_k, N, m, log_base_func):
    r"""Compute the :math:`I_1` term of the exponential family distribution.

    For :math:`q = 1` the exponential family distribution reduces to the
    Shannon entropy. The value returned is the mean of
    ``log_base_func(zeta_i)`` with
    :math:`\zeta_i = (N-1)\, e^{-\psi(k)}\, V_m\, \rho_{k,i}^m`,
    where :math:`\psi` is the digamma function.

    Parameters
    ----------
    k : int
        The number of nearest neighbors used in the estimation.
    V_m : float
        Volume of the unit ball in m-dimensional space.
    rho_k : array-like
        The k-th nearest neighbor distances.
    N : int
        Number of data points.
    m : int
        Dimensionality of the data.
    log_base_func : callable
        The logarithm function to use for the calculation
        with the chosen base.

    Returns
    -------
    float
        The :math:`I_1` of the exponential family distribution
    """
    # Sample-independent prefactor (N - 1) * exp(-psi(k))
    prefactor = (N - 1) * np_exp(-digamma(k))

    # zeta_{N,i,k} for every sample i
    zeta = prefactor * V_m * rho_k**m

    log_zeta = log_base_func(zeta)
    return np_mean(log_zeta)