Source code for foresight.ei

# -*- coding: UTF-8 -*-

from functools import reduce

import torch
import torch.nn as nn
import torch.nn.functional as F


def H(x, dim=0):
    r"""Compute the Shannon entropy of x.

    Given a tensor x, compute the Shannon entropy along one of its axes.
    If x.shape == (N,) then a scalar (0-d tensor) is returned. If
    x.shape == (N, N) then the entropy can be computed along the vertical
    or horizontal axis by passing dim=0 or dim=1, respectively. Note that
    the function does not check that the axis along which the entropy is
    computed represents a valid probability distribution.

    Args:
        x (torch.tensor): tensor containing a probability distribution
        dim (int): dimension along which to compute the entropy

    Returns:
        (torch.tensor) of one order lower than the input x
    """
    r = x * torch.log2(x)
    r[r != r] = 0  # 0 * log2(0) produces NaN; replace it with 0
    return -torch.sum(r, dim=dim)

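# Illustrative usage sketch (not part of the original module): the entropy of
# a uniform distribution over 4 outcomes is 2 bits, while a deterministic
# distribution carries 0 bits; the NaN-masking above lets H tolerate zeros.
def _example_entropy():
    uniform = torch.tensor([0.25, 0.25, 0.25, 0.25])
    assert H(uniform).item() == 2.0
    deterministic = torch.tensor([1.0, 0.0, 0.0, 0.0])
    assert H(deterministic).item() == 0.0
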
def soft_norm(W):
    r"""Turns a 2D matrix W into a transition probability matrix.

    The weight/adjacency matrix of an ANN does not on its own allow
    EI to be computed, because the out-weights of a given neuron are
    not a probability distribution (they do not necessarily sum to 1).
    We therefore must normalize them. Applies a softmax to each row
    of W so that the out-weights are normalized.

    Args:
        W (torch.tensor) of shape (N, M)

    Returns:
        (torch.tensor) of shape (N, M)
    """
    return F.softmax(W, dim=1)

def lin_norm(W):
    r"""Turns a 2D matrix W into a transition probability matrix.

    Applies a ReLU to each row (to remove negative values) and then
    normalizes each row by its sum. Rows that sum to 0 are left as
    all zeros.

    Args:
        W (torch.tensor) of shape (N, M)

    Returns:
        (torch.tensor) of shape (N, M)
    """
    W = F.relu(W)
    row_sums = torch.sum(W, dim=1)
    row_sums[row_sums == 0] = 1  # avoid division by zero for all-zero rows
    row_sums = row_sums.reshape((-1, 1))
    return W / row_sums

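# Illustrative usage sketch (not part of the original module): both norms turn
# an arbitrary weight matrix into a row-stochastic transition matrix, but
# lin_norm zeroes negative weights while soft_norm keeps them in play.
def _example_norms():
    W = torch.tensor([[1.0, -2.0, 3.0],
                      [0.5,  0.5, 0.0]])
    for P in (soft_norm(W), lin_norm(W)):
        assert torch.allclose(torch.sum(P, dim=1), torch.ones(2))
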
def sig_norm(W):
    r"""Turns a 2D matrix W into a transition probability matrix.

    Applies the logistic function to each element and normalizes
    each row by its sum.

    Args:
        W (torch.tensor) of shape (N, M)

    Returns:
        (torch.tensor) of shape (N, M)
    """
    W = torch.sigmoid(W)
    row_sums = torch.sum(W, dim=1).reshape((-1, 1))
    return W / row_sums


def linear_create_matrix(module, in_shape, out_shape):
    r"""Returns the 2D connectivity matrix of an nn.Linear layer.

    This matrix has shape (input_activations, output_activations), so
    each row contains the out-weights of one neuron. To compute the
    effective information, normalize across the rows of the returned
    matrix.

    Args:
        module (nn.Module): layer in a feedforward network
        in_shape (tuple): shape of the module's input
        out_shape (tuple): shape of the module's output

    Returns:
        2d torch.tensor
    """
    with torch.no_grad():
        W = module.weight.t()  # nn.Linear stores its weight as (out, in)
        assert W.shape[0] == in_shape[-1]
        assert W.shape[1] == out_shape[-1]
        return W


def conv2d_create_matrix(module, in_shape, out_shape):
    r"""Returns the 2D connectivity matrix of an nn.Conv2d layer.

    This matrix has shape (input_activations, output_activations), so
    each row contains the out-weights of one neuron. To compute the
    effective information, normalize across the rows of the returned
    matrix.

    Args:
        module (nn.Module): layer in a feedforward network
        in_shape (tuple): shape of the module's input
        out_shape (tuple): shape of the module's output

    Returns:
        2d torch.tensor
    """
    with torch.no_grad():
        assert len(in_shape) == 4 and len(out_shape) == 4
        p_h, p_w = module.padding
        samples, channels, in_height, in_width = in_shape
        W_in_shape = (samples, channels, in_height + 2*p_h, in_width + 2*p_w)
        W = torch.zeros(out_shape[1:] + W_in_shape[1:])  # [1:] to ignore batch size
        weight = module.weight
        s_h, s_w = module.stride
        k_h, k_w = module.kernel_size
        # stamp the kernel into the receptive field of each output activation
        for c_out in range(out_shape[1]):
            for h in range(0, out_shape[2]):
                for w in range(0, out_shape[3]):
                    in_h, in_w = h*s_h, w*s_w
                    W[c_out][h][w][:, in_h:in_h+k_h, in_w:in_w+k_w] = weight[c_out]
        ins = reduce(lambda x, y: x*y, in_shape[1:])
        outs = reduce(lambda x, y: x*y, out_shape[1:])
        if p_h != 0:
            W = W[:, :, :, :, p_h:-p_h, :]  # get rid of vertical padding
        if p_w != 0:
            W = W[:, :, :, :, :, p_w:-p_w]  # get rid of horizontal padding
        return W.reshape((outs, ins)).t()


def avgpool2d_create_matrix(module, in_shape, out_shape):
    r"""Returns the 2D connectivity matrix of an nn.AvgPool2d layer.

    This matrix has shape (input_activations, output_activations), so
    each row contains the out-weights of one neuron. To compute the
    effective information, normalize across the rows of the returned
    matrix.
    Args:
        module (nn.Module): layer in a feedforward network
        in_shape (tuple): shape of the module's input
        out_shape (tuple): shape of the module's output

    Returns:
        2d torch.tensor
    """
    with torch.no_grad():
        assert module.padding == 0
        assert len(in_shape) == 4 and len(out_shape) == 4
        W = torch.zeros(out_shape[1:] + in_shape[1:])  # [1:] to ignore batch size
        if type(module.stride) is tuple:
            assert len(module.stride) == 2, \
                "stride tuple must have 2 elements for 2d Pool"
            s_h, s_w = module.stride
        else:
            s_h = s_w = module.stride
        if type(module.kernel_size) is tuple:
            assert len(module.kernel_size) == 2, \
                "kernel_size tuple must have 2 elements for 2d Pool"
            k_h, k_w = module.kernel_size
        else:
            k_h = k_w = module.kernel_size
        weight = 1 / (k_h * k_w)  # every input in the window contributes equally
        for c_out in range(out_shape[1]):
            for h in range(0, out_shape[2]):
                for w in range(0, out_shape[3]):
                    in_h, in_w = h*s_h, w*s_w
                    # pooling is per-channel: output channel c_out draws only
                    # from input channel c_out
                    W[c_out][h][w][c_out, in_h:in_h+k_h, in_w:in_w+k_w] = weight
        ins = reduce(lambda x, y: x*y, in_shape[1:])
        outs = reduce(lambda x, y: x*y, out_shape[1:])
        return W.reshape((outs, ins)).t()


r"""
The modules for which a create_matrix() function has been defined.
The create_matrix() functions generate a 2D connectivity matrix for each
layer. If the network is feedforward, with no skip-connections, then
determinism and degeneracy can be computed using each layer's
connectivity matrix, without computing the whole-network connectivity
matrix.
"""
VALID_MODULES = {
    nn.Linear: linear_create_matrix,
    nn.Conv2d: conv2d_create_matrix,
    nn.AvgPool2d: avgpool2d_create_matrix
}

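# Illustrative sanity check (not part of the original module): the matrix
# returned by conv2d_create_matrix should reproduce a bias-free convolution as
# a single matrix product over the flattened input, which is what lets the EI
# machinery treat a conv layer as an ordinary weighted graph.
def _example_conv_matrix():
    conv = nn.Conv2d(2, 3, kernel_size=3, stride=1, padding=1, bias=False)
    x = torch.randn(1, 2, 8, 8)
    with torch.no_grad():
        out = conv(x)
        W = conv2d_create_matrix(conv, tuple(x.shape), tuple(out.shape))
        # W has shape (2*8*8, 3*8*8) = (ins, outs)
        assert torch.allclose(x.reshape(-1) @ W, out.reshape(-1), atol=1e-5)
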
def get_shapes(model, input):
    r"""Get a dictionary {module: (in_shape, out_shape), ...} for the modules in `model`.

    Because PyTorch uses a dynamic computation graph, the number of
    activations that a given module returns is not intrinsic to the
    definition of the module, but can depend on the shape of its input.
    We therefore need to pass data through the network to determine its
    connectivity. This function passes `input` into `model` and gets the
    shapes of the tensor inputs and outputs of each child module in
    model, provided that they are instances of VALID_MODULES.

    Args:
        model (nn.Module): feedforward neural network
        input (torch.tensor): a valid input to the network

    Returns:
        Dictionary {nn.Module: tuple(in_shape, out_shape)}
    """
    shapes = {}
    hooks = []

    def register_hook(module):
        def hook(module, input, output):
            shapes[module] = (tuple(input[0].shape), tuple(output.shape))
        if type(module) in VALID_MODULES:
            hooks.append(module.register_forward_hook(hook))

    model.apply(register_hook)
    model(input)
    for hook in hooks:
        hook.remove()
    return shapes

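# Illustrative usage sketch (not part of the original module): shapes are
# recorded only for modules listed in VALID_MODULES, so the ReLU below does
# not appear in the result.
def _example_get_shapes():
    model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
    shapes = get_shapes(model, torch.randn(1, 10))
    assert len(shapes) == 2
    for module, (in_shape, out_shape) in shapes.items():
        print(type(module).__name__, in_shape, '->', out_shape)
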
def determinism(model, input=None, shapes=None, norm=lin_norm, device='cpu'):
    r"""Compute the determinism of neural network `model`.

    Determinism is the average entropy of the out-weights of each node
    (neuron) in the graph:

    .. math::
        \text{determinism} = \langle H(W^\text{out}) \rangle

    If a `shapes` argument is provided, then `input` is not used and
    need not be provided. If no `shapes` argument is provided, then an
    `input` argument must be provided to build the computation graph.

    Args:
        model (nn.Module): neural network defined with PyTorch
        input (torch.tensor): an input to the model (needed to build the
            computation graph)
        shapes (dict): dictionary mapping child modules to their input
            and output shapes (created by the get_shapes() function)
        norm (func): function used to normalize the out-weights of each
            neuron
        device (str): must be 'cpu' or 'cuda'

    Returns:
        determinism (float)
    """
    if shapes is None:
        if input is None:
            raise ValueError("Missing argument `input` needed to compute shapes.")
        shapes = get_shapes(model, input)
    H_sum = 0
    N = 0
    for module, (in_shape, out_shape) in shapes.items():
        create_matrix = VALID_MODULES[type(module)]
        W = create_matrix(module, in_shape, out_shape)
        if norm:
            W = norm(W)
        H_sum += torch.sum(H(W, dim=1)).item()
        N += W.shape[0]
    return H_sum / N

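# Illustrative usage sketch (not part of the original module): computing the
# shapes once and passing them via `shapes=` avoids re-running the forward
# pass when several of these measures are needed for the same model.
def _example_determinism():
    model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))
    shapes = get_shapes(model, torch.randn(1, 10))
    print('determinism:', determinism(model, shapes=shapes))
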
def degeneracy(model, input=None, shapes=None, norm=lin_norm, device='cpu'):
    r"""Compute the degeneracy of neural network `model`.

    Degeneracy is the entropy of the cumulative, normalized in-weights
    of the neurons in the graph:

    .. math::
        \text{degeneracy} = H( \langle W^\text{out} \rangle )

    If a `shapes` argument is provided, then `input` is not used and
    need not be provided. If no `shapes` argument is provided, then an
    `input` argument must be provided to build the computation graph.

    Args:
        model (nn.Module): neural network defined with PyTorch
        input (torch.tensor): an input to the model (needed to build the
            computation graph)
        shapes (dict): dictionary mapping child modules to their input
            and output shapes (created by the get_shapes() function)
        norm (func): function used to normalize the out-weights of each
            neuron
        device (str): must be 'cpu' or 'cuda'

    Returns:
        degeneracy (float)
    """
    if shapes is None:
        if input is None:
            raise ValueError("Missing argument `input` needed to compute shapes.")
        shapes = get_shapes(model, input)
    in_weights = torch.zeros((0,)).to(device)
    for module, (in_shape, out_shape) in shapes.items():
        create_matrix = VALID_MODULES[type(module)]
        W = create_matrix(module, in_shape, out_shape)
        if norm:
            W = norm(W)
        in_weights = torch.cat((in_weights, torch.sum(W, dim=0)))
    total_weight = torch.sum(in_weights).item()
    return H(in_weights / total_weight).item()

def ei(model, input=None, shapes=None, norm=lin_norm, device='cpu'):
    r"""Compute the effective information of neural network `model`.

    Effective information (EI) is a useful measure of the information
    contained in the weighted connectivity structure of a network. It is
    used in theoretical neuroscience to study emergent structure in
    networks, and is defined by:

    .. math::
        \text{EI} = \text{determinism} - \text{degeneracy}

    In the usual formulation each term includes a log2(N) offset, and
    the offsets cancel in the difference. Since the determinism() and
    degeneracy() functions above omit that offset, EI as computed here
    is degeneracy() - determinism(); explicitly:

    .. math::
        \text{EI} = H( \langle W^\text{out} \rangle ) - \langle H(W^\text{out}) \rangle

    This is equal to the average KL-divergence between the normalized
    out-weights of the neurons and the distribution of in-weights across
    the neurons in the network.

    If a `shapes` argument is provided, then `input` is not used and
    need not be provided. If no `shapes` argument is provided, then an
    `input` argument must be provided to build the computation graph.

    Args:
        model (nn.Module): neural network defined with PyTorch
        input (torch.tensor): an input to the model (needed to build the
            computation graph)
        shapes (dict): dictionary mapping child modules to their input
            and output shapes (created by the get_shapes() function)
        norm (func): function used to normalize the out-weights of each
            neuron
        device (str): must be 'cpu' or 'cuda'

    Returns:
        ei (float)
    """
    if shapes is None:
        if input is None:
            raise ValueError("Missing argument `input` needed to compute shapes.")
        shapes = get_shapes(model, input)
    return degeneracy(model, shapes=shapes, norm=norm, device=device) \
        - determinism(model, shapes=shapes, norm=norm, device=device)

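# Illustrative usage sketch (not part of the original module): ei() returns
# exactly degeneracy minus determinism, so the three measures can be checked
# against one another on a shared set of shapes.
def _example_ei():
    model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))
    x = torch.randn(1, 10)
    shapes = get_shapes(model, x)
    det = determinism(model, shapes=shapes)
    deg = degeneracy(model, shapes=shapes)
    assert abs(ei(model, shapes=shapes) - (deg - det)) < 1e-6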