# Source code for norse.torch.functional.receptive_field

"""
A module for creating receptive fields.
"""

from typing import List, Tuple, Union, Optional

import torch


def gaussian_kernel(
    size: int, c: torch.Tensor, x: torch.Tensor, y: torch.Tensor, domain: float = 8
) -> torch.Tensor:
    r"""
    Efficiently creates a differentiable 2d gaussian kernel.

    The kernel is sampled on a ``size x size`` grid spanning ``[-domain, domain]``
    along both axes and evaluates

    .. math::
        \frac{1}{2\pi\sqrt{|C|}} \exp\left(-\tfrac{1}{2} v^T C^{-1} v\right)

    where :math:`v` is the grid position minus the center ``(x, y)``. The values
    are not renormalized to sum to one over the grid.

    Arguments:
        size (int): The size of the kernel
        c (torch.Tensor): A 2x2 covariance matrix describing the eccentricity of the gaussian
        x (torch.Tensor): The receptive field's center position in the x-axis
        y (torch.Tensor): The receptive field's center position in the y-axis
        domain (float): The domain of the kernel. Defaults to 8 (sampling -8 to 8).

    Returns:
        torch.Tensor: A (size, size) kernel. Columns (dim 1) follow the x-axis and
        rows (dim 0) the y-axis. The sampling grid is created with the dtype and
        device of ``c``.
    """
    ci = torch.linalg.inv(c)
    cd = torch.linalg.det(c)
    fraction = 1 / (2 * torch.pi * torch.sqrt(cd))
    # Create the grid in c's dtype/device: torch.einsum does not promote dtypes, so a
    # default-dtype grid combined with e.g. a float64 covariance matrix would fail.
    axis = torch.linspace(-domain, domain, size, dtype=c.dtype, device=c.device)
    xs, ys = torch.meshgrid(axis, axis, indexing="xy")
    # (size, size, 2) offsets of every grid point from the center.
    coo = torch.stack([xs - x, ys - y], dim=2)
    # b = -v^T C^{-1} for every grid point, then exponent = -v^T C^{-1} v.
    b = torch.einsum("bimj,jk->bik", -coo.unsqueeze(2), ci)
    exponent = torch.einsum("bij,bij->bi", b, coo)
    return fraction * torch.exp(exponent / 2)
def covariance_matrix(
    sigma1: torch.Tensor, sigma2: torch.Tensor, phi: torch.Tensor
) -> torch.Tensor:
    """
    Creates a 2-dimensional covariance matrix from two standard deviations and
    the rotation angle of the first principal axis.

    The result equals ``R(phi) @ diag(sigma1 ** 2, sigma2 ** 2) @ R(phi).T`` with
    ``R(phi)`` the 2d rotation by ``phi``. ``sigma1 ** 2`` is the variance along the
    direction ``(cos(phi), sin(phi))`` and ``sigma2 ** 2`` the variance orthogonal to it.

    Arguments:
        sigma1 (torch.Tensor): Standard deviation along the direction ``phi``
        sigma2 (torch.Tensor): Standard deviation orthogonal to ``phi``
        phi (torch.Tensor): Rotation of the first principal axis in radians

    Returns:
        torch.Tensor: A symmetric 2x2 matrix in the default floating dtype, on the
        device of ``phi``.
    """
    lambda1 = torch.as_tensor(sigma1) ** 2
    lambda2 = torch.as_tensor(sigma2) ** 2
    phi = torch.as_tensor(phi)
    cos, sin = phi.cos(), phi.sin()
    cxx = lambda1 * cos**2 + lambda2 * sin**2
    cxy = (lambda1 - lambda2) * cos * sin
    cyy = lambda1 * sin**2 + lambda2 * cos**2
    cov = torch.ones(2, 2, device=phi.device)
    cov[0][0] = cxx
    cov[0][1] = cxy
    cov[1][0] = cxy
    cov[1][1] = cyy
    return cov


def derive_kernel(kernel, angle) -> torch.Tensor:
    """
    Takes the first spatial derivative of ``kernel`` in the direction ``angle``.

    Computes ``cos(angle) * d/d(dim 0) + sin(angle) * d/d(dim 1)`` using
    :func:`torch.gradient` with its default unit spacing.

    Arguments:
        kernel (torch.Tensor): A 2d kernel
        angle (float or torch.Tensor): Direction of the derivative in radians.
            Python numbers are accepted.

    Returns:
        torch.Tensor: The directional derivative, with the same shape as ``kernel``.
    """
    # torch.cos/torch.sin reject plain Python numbers, so normalize the angle first.
    angle = torch.as_tensor(angle)
    # NOTE(review): gaussian_kernel samples its grid with indexing="xy", so dim 0 of
    # such a kernel runs along the y-axis; confirm that weighting dim 0 by
    # cos(angle) is the intended orientation convention.
    gradx = torch.gradient(kernel, dim=0)[0] * torch.cos(angle)
    grady = torch.gradient(kernel, dim=1)[0] * torch.sin(angle)
    return gradx + grady


def calculate_normalization(dx: int, scale: float, gamma: float = 1):
    """
    Calculates scale normalization for a spatial receptive field at a given directional derivative

    Lindeberg: Feature detection with automatic scale selection, eq. 20
    https://doi.org/10.1023/A:1008045108935

    Returns ``scale ** (dx * (1 - gamma)) * (scale ** 2) ** (gamma / 2)``.

    Arguments:
        dx (int): The nth directional derivative
        scale (float): The scale of the receptive field
        gamma (float): A normalization parameter
    """
    # NOTE(review): for dx == 1 and a positive scale this reduces to `scale`
    # regardless of gamma; confirm against eq. 20 if other gammas are used.
    t = scale**2
    scale_norm = scale ** (dx * (1 - gamma))
    xi_norm = t ** (gamma / 2)
    return scale_norm * xi_norm


def derive_spatial_receptive_field_single(
    field: torch.Tensor, scale: float, angle: float, dx: int, dy: int
) -> torch.Tensor:
    """
    Calculate the derivative of a single spatial receptive field at a given angle and scale
    with respect to x and y derivatives.

    ``dx`` first-order derivatives are taken in the direction ``angle`` and ``dy`` in
    the direction ``angle + pi / 2``, interleaved one of each per step. Each one is
    multiplied by ``calculate_normalization(1, scale, 1)``.

    Example:
        >>> field = spatial_receptive_field(2.0, 0.0, 1.0, 0.0, 0.0, 16)
        >>> derived = derive_spatial_receptive_field_single(field, 2.0, 0.0, 1, 0)

    Arguments:
        field (torch.Tensor): The spatial receptive field
        scale (float): The scale of the receptive field
        angle (float): The angle of the receptive field in radians (float or tensor)
        dx (int): Number of derivatives along ``angle``
        dy (int): Number of derivatives along ``angle + pi / 2``

    Returns:
        torch.Tensor: The derived spatial receptive field
    """
    derived = field
    # dx/dy may arrive as float tensors (e.g. rows of a parameter tensor).
    dx = int(dx)
    dy = int(dy)
    while dx > 0 or dy > 0:
        if dx > 0:
            derived = derive_kernel(derived, angle) * calculate_normalization(
                1, scale, 1
            )
            dx -= 1
        if dy > 0:
            derived = derive_kernel(
                derived, angle + torch.pi / 2
            ) * calculate_normalization(1, scale, 1)
            dy -= 1
    return derived


def derive_spatial_receptive_field(
    field: torch.Tensor, scale: float, angle: float, derivatives: List[Tuple[int, int]]
) -> torch.Tensor:
    """
    Derive spatial receptive field at a given angle and scale with respect to a list of derivatives.
    Returns a tensor of shape (len(derivatives), size, size), where size is the size of the receptive field.

    Arguments:
        field (torch.Tensor): The spatial receptive field
        scale (float): The scale of the receptive field
        angle (float): The angle of the receptive field
        derivatives (List[Tuple[int, int]]): A list of (dx, dy) derivative pairs

    Returns:
        torch.Tensor: The derived fields stacked in the order of ``derivatives``
    """
    angle = torch.as_tensor(angle)
    kernels = []
    for dx, dy in derivatives:
        derived = derive_spatial_receptive_field_single(field, scale, angle, dx, dy)
        kernels.append(derived)
    return torch.stack(kernels)
def spatial_receptive_field(
    scale: torch.Tensor,
    angle: torch.Tensor,
    ratio: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    size: int,
    dx: int = 0,
    dy: int = 0,
    domain: float = 10,
) -> torch.Tensor:
    """
    Creates a (size x size) receptive field kernel at a given scale, angle and ratio
    with respect to x and y derivatives.

    The underlying Gaussian uses the covariance
    ``covariance_matrix(ratio, 1 / ratio, angle) * scale``. It is normalized to sum to
    one before ``derive_spatial_receptive_field_single`` applies the derivatives.

    Arguments:
        scale (torch.Tensor): The scale of the field; multiplies the covariance matrix
        angle (torch.Tensor): The rotation of the kernel in radians
        ratio (torch.Tensor): The eccentricity as a ratio; the standard deviations
            of the covariance are ``ratio`` and ``1 / ratio`` before scaling
        x (torch.Tensor): The receptive field's center position in the x-axis
        y (torch.Tensor): The receptive field's center position in the y-axis
        size (int): The size of the square kernel in pixels
        dx (int): Number of derivatives along ``angle``. Defaults to 0.
        dy (int): Number of derivatives along ``angle + pi / 2``. Defaults to 0.
        domain (float): The initial coordinates from which the field is sampled.
            Defaults to 10 (sampling -10 to 10).

    Returns:
        torch.Tensor: The (size, size) receptive field
    """
    angle = torch.as_tensor(angle)
    covariance = covariance_matrix(ratio, 1 / ratio, angle) * scale
    kernel = gaussian_kernel(size, covariance, x, y, domain=domain)
    # Normalize to unit sum before differentiating.
    kernel = kernel / kernel.sum()
    return derive_spatial_receptive_field_single(kernel, scale, angle, dx, dy)
def _extract_derivatives(
    derivatives: Union[int, List[Tuple[int, int]]]
) -> Tuple[List[Tuple[int, int]], int]:
    """
    Normalizes a derivative specification into an explicit list of (dx, dy) pairs.

    Arguments:
        derivatives (Union[int, List[Tuple[int, int]]]): Either a maximum order ``n``,
            which expands to every pair ``(dx, dy)`` with ``0 <= dx, dy <= n``, or an
            explicit, non-empty list of pairs.

    Returns:
        Tuple[List[Tuple[int, int]], int]: The pairs and the highest order among them.

    Raises:
        ValueError: If ``derivatives`` is a negative number, an empty list, or of an
            unsupported type.
    """
    if isinstance(derivatives, int):
        if derivatives < 0:
            # A negative order would expand to no pairs and break later tensor concatenation.
            raise ValueError(
                f"Derivatives expected a non-negative number, but got {derivatives}"
            )
        pairs = [
            (x, y) for x in range(derivatives + 1) for y in range(derivatives + 1)
        ]
        return pairs, derivatives
    if isinstance(derivatives, list):
        if not derivatives:
            raise ValueError(
                "Derivatives expected a non-empty list of tuples, but got []"
            )
        return derivatives, max(max(x, y) for (x, y) in derivatives)
    raise ValueError(
        f"Derivatives expected either a number or a list of tuples, but got {derivatives}"
    )


def spatial_parameters(
    scales: torch.Tensor,
    angles: torch.Tensor,
    ratios: torch.Tensor,
    derivatives: Union[int, List[Tuple[int, int]]],
    x: torch.Tensor,
    y: torch.Tensor,
    include_replicas: bool = False,
) -> torch.Tensor:
    """
    Combines the parameters of scales, angles, ratios, x and y coordinates of the center
    of the rf and derivatives as cartesian products to produce a set of parameters for
    spatial receptive fields.

    Arguments:
        scales (torch.Tensor): 1d tensor of scales
        angles (torch.Tensor): 1d tensor of angles in radians
        ratios (torch.Tensor): 1d tensor of eccentricity ratios
        derivatives (Union[int, List[Tuple[int, int]]]): A maximum derivative order or
            an explicit list of (dx, dy) pairs
        x (torch.Tensor): 1d tensor of center positions in the x-axis
        y (torch.Tensor): 1d tensor of center positions in the y-axis
        include_replicas (bool): If False and some ratio equals 1, the fields with
            ratio != 1 are listed before those with ratio == 1

    Returns:
        torch.Tensor: A 2d tensor whose rows are ``[scale, angle, ratio, x, y, dx, dy]``.
        Every spatial combination is repeated once per derivative pair.

    Raises:
        ValueError: If ``derivatives`` is invalid (see ``_extract_derivatives``).
    """
    if include_replicas or not (ratios == 1).any():
        parameters = torch.cartesian_prod(scales, angles, ratios, x, y)
    else:
        # NOTE(review): ratio == 1 fields are still combined with every angle here,
        # so include_replicas=False only reorders rows; confirm whether duplicate
        # rotations of symmetric fields were meant to be dropped.
        mask = ratios != 1
        asymmetric_fields = torch.cartesian_prod(scales, angles, ratios[mask], x, y)
        symmetric_rings = torch.cartesian_prod(scales, angles, ratios[~mask], x, y)
        parameters = torch.cat([asymmetric_fields, symmetric_rings])

    # Pair every spatial combination with every (dx, dy) derivative pair.
    pairs, _ = _extract_derivatives(derivatives)
    derivative_tensor = torch.tensor(pairs, device=scales.device).float()
    parameters_repeated = parameters.repeat_interleave(len(derivative_tensor), 0)
    derivatives_repeated = derivative_tensor.repeat(len(parameters), 1)
    return torch.cat([parameters_repeated, derivatives_repeated], 1)
def spatial_receptive_fields_with_derivatives(
    combinations: torch.Tensor,
    size: int,
    domain: float = 1,
) -> torch.Tensor:
    r"""
    Creates one receptive field per parameter row in ``combinations``.

    Each row is read positionally as ``(scale, angle, ratio, x, y, dx, dy)``, the
    column layout produced by ``spatial_parameters``, and is passed on to
    ``spatial_receptive_field``.

    Arguments:
        combinations (torch.Tensor): Parameter rows, one per receptive field
        size (int): The side length of every square kernel in pixels
        domain (float): Sampling domain forwarded to every field. Defaults to 1.

    Returns:
        torch.Tensor: The receptive fields stacked along a new leading dimension.
    """
    fields = []
    for row in combinations:
        field = spatial_receptive_field(
            scale=row[0],
            angle=row[1],
            ratio=row[2],
            x=row[3],
            y=row[4],
            size=size,
            dx=row[5],
            dy=row[6],
            domain=domain,
        )
        fields.append(field)
    return torch.stack(fields)
def temporal_scale_distribution(
    n_scales: int,
    min_scale: float = 1,
    max_scale: Optional[float] = None,
    c: Optional[float] = 1.41421,
):
    r"""
    Provides temporal scales according to [Lindeberg2016].
    The scales will be logarithmic by default, but can be changed by providing other values for c.

    .. math::
        \tau_k = c^{2(k - K)} \tau_{max}, \quad k = 1, \dots, K

        \mu_k = \sqrt{\tau_k - \tau_{k - 1}}

    ``min_scale`` and ``max_scale`` are :math:`\tau_1` and :math:`\tau_K`. The function
    returns :math:`\sqrt{\tau_k}`; the increments :math:`\mu_k` are not computed here.

    Arguments:
        n_scales (int): Number of scales to generate
        min_scale (float): The minimum scale
        max_scale (Optional[float]): The maximum scale. Defaults to None.
            If set, c is ignored and derived so that the scales span min_scale to max_scale.
        c (Optional[float]): The base from which to generate scale values.
            Should be a value between 1 to 2, exclusive. Defaults to sqrt(2)
            (also used when None is passed). Ignored if max_scale is set.

    Returns:
        torch.Tensor: A 1d tensor of length ``n_scales`` holding :math:`\sqrt{\tau_k}`.
        If ``max_scale`` is set and ``n_scales <= 1``, it holds only ``sqrt(min_scale)``.

    .. [Lindeberg2016] Lindeberg 2016, Time-Causal and Time-Recursive Spatio-Temporal
        Receptive Fields, https://link.springer.com/article/10.1007/s10851-015-0613-9.
    """
    if c is None:
        c = 1.41421
    xs = torch.linspace(1, n_scales, n_scales)
    if max_scale is not None:
        if n_scales <= 1:
            # Avoid division by zero when having a single scale
            return torch.tensor([min_scale]).sqrt()
        # Solve c^{2(1 - K)} * max_scale = min_scale so that tau_1 = min_scale
        # and tau_K = max_scale.
        c = (max_scale / min_scale) ** (1 / (2 * (n_scales - 1)))
    else:
        max_scale = (c ** (2 * (n_scales - 1))) * min_scale
    taus = c ** (2 * (xs - n_scales)) * max_scale
    return taus.sqrt()
def spatio_temporal_parameters(
    scales: torch.Tensor,
    angles: torch.Tensor,
    ratios: torch.Tensor,
    derivatives: Union[int, List[Tuple[int, int]]],
    temporal_scales: torch.Tensor,
    include_replicas: bool = False,
    x: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Combines the parameters of scales, angles, ratios, center positions, derivatives
    and temporal scales as cartesian products to produce a set of parameters for
    spatio-temporal receptive fields.

    Arguments:
        scales (torch.Tensor): 1d tensor of spatial scales
        angles (torch.Tensor): 1d tensor of angles in radians
        ratios (torch.Tensor): 1d tensor of eccentricity ratios
        derivatives (Union[int, List[Tuple[int, int]]]): A maximum derivative order or
            an explicit list of (dx, dy) pairs
        temporal_scales (torch.Tensor): Temporal scales; flattened to 1d
        include_replicas (bool): Forwarded to ``spatial_parameters``
        x (Optional[torch.Tensor]): 1d tensor of center positions in the x-axis.
            Defaults to a single centered position (0).
        y (Optional[torch.Tensor]): 1d tensor of center positions in the y-axis.
            Defaults to a single centered position (0).

    Returns:
        torch.Tensor: A 2d tensor whose rows are
        ``[scale, angle, ratio, x, y, dx, dy, temporal_scale]``. Every spatial row is
        paired with every temporal scale.
    """
    # spatial_parameters requires x and y; default to centered fields so that
    # existing positional callers (..., temporal_scales, include_replicas) keep working.
    if x is None:
        x = torch.zeros(1, dtype=scales.dtype, device=scales.device)
    if y is None:
        y = torch.zeros(1, dtype=scales.dtype, device=scales.device)
    spatial = spatial_parameters(
        scales, angles, ratios, derivatives, x, y, include_replicas=include_replicas
    )
    temporal = torch.as_tensor(
        temporal_scales, dtype=spatial.dtype, device=spatial.device
    ).flatten()
    # torch.cartesian_prod only accepts 1d inputs, so pair rows manually using the
    # same ordering: each spatial row is repeated once per temporal scale.
    spatial_repeated = spatial.repeat_interleave(len(temporal), 0)
    temporal_repeated = temporal.repeat(len(spatial)).unsqueeze(1)
    return torch.cat([spatial_repeated, temporal_repeated], 1)