Source code for norse.torch.module.receptive_field

"""
These receptive fields are derived from scale-space theory, specifically in the paper `Normative theory of visual receptive fields by Lindeberg, 2021 <https://www.sciencedirect.com/science/article/pii/S2405844021000025>`_.

For use with spiking / binary signals, see the paper `Translation and Scale Invariance for Event-Based Object tracking by Pedersen et al., 2023 <https://dl.acm.org/doi/10.1145/3584954.3584996>`_.
"""

from typing import Callable, NamedTuple, Optional, Tuple

import torch

from norse.torch.module.snn import SNNCell
from norse.torch.module.leaky_integrator_box import LIBoxCell, LIBoxParameters
from norse.torch.functional.receptive_field import (
    spatial_receptive_fields_with_derivatives,
    spatial_parameters,
    temporal_scale_distribution,
)


class SpatialReceptiveField2d(torch.nn.Module):
    """Creates a spatial receptive field as 2-dimensional convolutions.

    The `rf_parameters` are a tensor of shape `(n, 5)` where `n` is the number of receptive fields.
    If the `optimize_fields` flag is set to `True`, the `rf_parameters` will be optimized during training.

    Example:
        >>> import torch
        >>> from norse.torch import SpatialReceptiveField2d
        >>> parameters = torch.tensor([[1., 1., 1., 0., 0.]])
        >>> m = SpatialReceptiveField2d(1, 9, parameters)
        >>> m.weights.shape
        torch.Size([1, 1, 9, 9])
        >>> y = m(torch.empty(1, 1, 9, 9))
        >>> y.shape
        torch.Size([1, 1, 1, 1])

    Arguments:
        in_channels (int): Number of input channels
        size (int): Size of the receptive field
        rf_parameters (torch.Tensor): Parameters for the receptive fields in the order (scale, angle, ratio, dx, dy)
        aggregate (bool): If `True`, every receptive field spans all input channels and the responses are
            summed across channels (`n` output channels). If `False`, each field is applied to each input
            channel separately (`n * in_channels` output channels). Defaults to `True`.
        domain (float): The domain of the receptive field. Defaults to `8`.
        optimize_fields (bool): If `True`, the `rf_parameters` will be optimized during training. Defaults to `True`.
        **kwargs: Additional arguments for the `torch.nn.functional.conv2d` function.
    """

    def __init__(
        self,
        in_channels: int,
        size: int,
        rf_parameters: torch.Tensor,
        aggregate: bool = True,
        domain: float = 8,
        optimize_fields: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()
        self.rf_parameters = (
            torch.nn.Parameter(rf_parameters) if optimize_fields else rf_parameters
        )
        self.in_channels = in_channels
        self.size = size
        self.aggregate = aggregate
        self.domain = domain
        self.out_channels = (
            rf_parameters.shape[0] if aggregate else rf_parameters.shape[0] * in_channels
        )
        self.kwargs = kwargs
        # conv2d runs without a bias unless the caller passes one explicitly
        self.kwargs.setdefault("bias", None)

        # Always build the kernels once up front. Comparing against a zero
        # snapshot would skip this step when rf_parameters happen to be all zero.
        self.has_updated = False
        self._compute_weights()

        if optimize_fields:

            def update_hook(module, grad_input, grad_output):
                # Flag the kernels as stale after every backward pass
                module.has_updated = True

            self.register_full_backward_hook(update_hook)

    def _compute_weights(self) -> None:
        """Rebuild ``self.weights`` (and ``self.out_channels``) from ``rf_parameters``."""
        # Snapshot the parameters so later calls can detect whether they changed
        self.rf_parameters_previous = self.rf_parameters.detach().clone()
        fields = spatial_receptive_fields_with_derivatives(
            self.rf_parameters,
            size=self.size,
            domain=self.domain,
        )
        n_fields = fields.shape[0]
        kernel_shape = fields.shape[1:]
        if self.aggregate:
            # Each field covers every input channel: (n, in_channels, *kernel_shape)
            self.out_channels = n_fields
            self.weights = fields.unsqueeze(1).repeat(1, self.in_channels, 1, 1)
        else:
            # Output channel ``i * n_fields + k`` applies field ``k`` to input
            # channel ``i`` only; every other input channel gets a zero kernel.
            self.out_channels = n_fields * self.in_channels
            idx = torch.arange(self.in_channels, device=fields.device)
            blocks = fields.new_zeros(
                self.in_channels, n_fields, self.in_channels, *kernel_shape
            )
            blocks[idx, :, idx] = fields
            self.weights = blocks.reshape(
                self.out_channels, self.in_channels, *kernel_shape
            )
        self.weights.requires_grad_(True)

    def _update_weights(self) -> None:
        """Rebuild the kernels if they have been flagged as stale.

        Nothing happens unless ``has_updated`` is set, either by the backward hook or by
        a wrapping module. The kernels are then rebuilt in two cases:

        * ``rf_parameters`` changed since the last computation.
        * Autograd is tracking ``rf_parameters``. A backward pass without
          ``retain_graph=True`` frees the graph behind the cached ``weights``, so reusing
          them in another backward pass would fail.
        """
        if not self.has_updated:
            return
        changed = not torch.all(
            torch.eq(self.rf_parameters_previous, self.rf_parameters)
        )
        tracked = torch.is_grad_enabled() and self.rf_parameters.requires_grad
        if changed or tracked:
            self.has_updated = False
            self._compute_weights()

    def forward(self, x: torch.Tensor):
        """Convolve ``x`` with the receptive-field kernels, refreshing them first if needed."""
        self._update_weights()
        return torch.nn.functional.conv2d(x, self.weights, **self.kwargs)
class SampledSpatialReceptiveField2d(torch.nn.Module):
    """
    Creates a spatial receptive field as 2-dimensional convolutions, sampled over a set of
    scales, angles, ratios, and derivatives.

    This module allows for the optimization of the input parameters for scales, angles, and
    ratios (not derivatives). When the respective ``optimize_*`` flags are set to True, it
    updates those parameters (and, by extension, the receptive fields) accordingly.

    This module is a wrapper around the `SpatialReceptiveField2d` module and will forward the kwargs.

    Example:
        >>> import torch
        >>> from norse.torch import SampledSpatialReceptiveField2d
        >>> scales = torch.tensor([1.0, 2.0])
        >>> angles = torch.tensor([0.0, 1.0])
        >>> ratios = torch.tensor([0.5, 1.0])
        >>> derivatives = torch.tensor([[0, 0]])
        >>> m = SampledSpatialReceptiveField2d(1, 9, scales, angles, ratios, derivatives,
        ...                                    optimize_scales=False, optimize_angles=False,
        ...                                    optimize_ratios=True)
        >>> optim = torch.optim.SGD(list(m.parameters()), lr=1)
        >>> y = m(torch.ones(1, 1, 9, 9))
        >>> y.sum().backward()
        >>> optim.step()  # Will update the ratios
        >>> m.ratios  # doctest: +SKIP
    """

    def __init__(
        self,
        in_channels: int,
        size: int,
        scales: torch.Tensor,
        angles: torch.Tensor,
        ratios: torch.Tensor,
        derivatives: torch.Tensor,
        optimize_scales: bool = True,
        optimize_angles: bool = True,
        optimize_ratios: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.scales = torch.nn.Parameter(scales) if optimize_scales else scales
        self.angles = torch.nn.Parameter(angles) if optimize_angles else angles
        self.ratios = torch.nn.Parameter(ratios) if optimize_ratios else ratios
        self.derivatives = derivatives
        self.has_updated = False
        # The submodule holds plain (non-Parameter) rf_parameters; the optimizable
        # quantities live on this module and are pushed down in _update_weights.
        self.submodule = SpatialReceptiveField2d(
            in_channels=in_channels,
            size=size,
            rf_parameters=self._rf_parameters(),
            optimize_fields=False,
            **kwargs,
        )
        if optimize_angles or optimize_scales or optimize_ratios:

            def update_hook(module, grad_input, grad_output):
                module.has_updated = True

            self.register_full_backward_hook(update_hook)

    def _rf_parameters(self) -> torch.Tensor:
        """Build the submodule's parameter matrix from the current scales/angles/ratios."""
        return spatial_parameters(
            self.scales, self.angles, self.ratios, self.derivatives
        )

    def forward(self, x: torch.Tensor):
        """Refresh the receptive fields if needed and apply them to ``x``."""
        self._update_weights()
        return self.submodule(x)

    def _update_weights(self):
        # Once a backward pass has happened, re-derive the submodule's parameters on
        # every call and let the submodule decide whether its kernels need rebuilding.
        if self.has_updated:
            self.submodule.rf_parameters = self._rf_parameters()
            self.submodule.has_updated = True
            self.submodule._update_weights()


class ParameterizedSpatialReceptiveField2d(torch.nn.Module):
    """
    A parameterized version of the `SpatialReceptiveField2d` module. Every kernel has its
    own scale, angle, and ratio, and these are optimized individually during training.
    This is in contrast to the `SampledSpatialReceptiveField2d` module, where the scales,
    angles, and ratios are shared generating values from which the kernels are sampled.

    The per-kernel values are initialized from
    ``spatial_parameters(scales, angles, ratios, derivatives)``. The remaining parameter
    columns are kept fixed.

    This module is a wrapper around the `SpatialReceptiveField2d` module and will forward the kwargs.

    Example:
        >>> import torch
        >>> from norse.torch import ParameterizedSpatialReceptiveField2d
        >>> scales = torch.tensor([1.0, 2.0])
        >>> angles = torch.tensor([0.0, 1.0])
        >>> ratios = torch.tensor([0.5, 1.0])
        >>> m = ParameterizedSpatialReceptiveField2d(1, 9, scales, angles, ratios, 1,
        ...                                          optimize_scales=False, optimize_angles=False,
        ...                                          optimize_ratios=True)
    """

    def __init__(
        self,
        in_channels: int,
        size: int,
        scales: torch.Tensor,
        angles: torch.Tensor,
        ratios: torch.Tensor,
        derivatives: torch.Tensor,
        optimize_scales: bool = True,
        optimize_angles: bool = True,
        optimize_ratios: bool = True,
        **kwargs,
    ):
        super().__init__()
        # One row per kernel; columns 0-2 are (scale, angle, ratio)
        self.initial_parameters = spatial_parameters(
            scales, angles, ratios, derivatives
        )
        self.scales = (
            torch.nn.Parameter(self.initial_parameters[:, 0])
            if optimize_scales
            else self.initial_parameters[:, 0]
        )
        self.angles = (
            torch.nn.Parameter(self.initial_parameters[:, 1])
            if optimize_angles
            else self.initial_parameters[:, 1]
        )
        self.ratios = (
            torch.nn.Parameter(self.initial_parameters[:, 2])
            if optimize_ratios
            else self.initial_parameters[:, 2]
        )
        self.submodule = SpatialReceptiveField2d(
            in_channels=in_channels,
            size=size,
            rf_parameters=self._rf_parameters(),
            optimize_fields=False,
            **kwargs,
        )
        self.has_updated = False
        if optimize_angles or optimize_scales or optimize_ratios:

            def update_hook(module, grad_input, grad_output):
                module.has_updated = True

            self.register_full_backward_hook(update_hook)

    def _rf_parameters(self) -> torch.Tensor:
        """Reassemble the per-kernel parameter matrix from the (possibly optimized) columns."""
        return torch.concat(
            [
                torch.stack([self.scales, self.angles, self.ratios], 1),
                # Columns beyond (scale, angle, ratio) are never optimized
                self.initial_parameters[:, 3:],
            ],
            1,
        )

    def forward(self, x: torch.Tensor):
        """Refresh the receptive fields if needed and apply them to ``x``."""
        self._update_weights()
        return self.submodule(x)

    def _update_weights(self):
        # Once a backward pass has happened, re-derive the submodule's parameters on
        # every call and let the submodule decide whether its kernels need rebuilding.
        if self.has_updated:
            self.submodule.rf_parameters = self._rf_parameters()
            self.submodule.has_updated = True
            self.submodule._update_weights()
class TemporalReceptiveField(torch.nn.Module):
    """Creates ``n_scales`` temporal receptive fields for arbitrary n-dimensional inputs.

    The scale spaces are selected in a range of [min_scale, max_scale] using an exponential
    distribution, scattered using ``torch.linspace``.

    Parameters:
        shape (torch.Size): The shape of the incoming tensor, where the first dimension denotes channels
        n_scales (int): The number of temporal scale spaces to iterate over.
        activation (Callable[..., SNNCell]): The activation neuron class, called as
            ``activation(p=..., dt=dt)``. Defaults to LIBoxCell
        activation_state_map (Callable): A function that takes a tensor and provides a neuron parameter tuple.
            Required if activation is changed, since the default behaviour provides LIBoxParameters
            (passing the tensor as ``tau_mem_inv``).
        min_scale (float): The minimum scale space. Defaults to 1.
        max_scale (Optional[float]): The maximum scale. Defaults to None. If set, c is ignored.
        c (float): The base from which to generate scale values. Should be a value between 1 and 2, exclusive.
            Defaults to sqrt(2). Ignored if max_scale is set.
        time_constants (Optional[torch.Tensor]): Hardcoded time constants. If set, these replace the
            automatically generated, logarithmically distributed scales. Defaults to None.
        dt (float): Neuron simulation timestep. Defaults to 0.001.
    """

    def __init__(
        self,
        shape: torch.Size,
        n_scales: int = 4,
        activation: Callable[..., SNNCell] = LIBoxCell,
        activation_state_map: Callable[
            [torch.Tensor], NamedTuple
        ] = lambda t: LIBoxParameters(tau_mem_inv=t),
        min_scale: float = 1,
        max_scale: Optional[float] = None,
        c: float = 1.41421,
        time_constants: Optional[torch.Tensor] = None,
        dt: float = 0.001,
    ):
        super().__init__()
        if time_constants is None:
            # With the default activation_state_map these values end up as
            # ``tau_mem_inv`` (inverse time constants, scaled by 1 / dt).
            taus = (1 / dt) / temporal_scale_distribution(
                n_scales, min_scale=min_scale, max_scale=max_scale, c=c
            )
            # One value per channel, with singleton trailing dims so it broadcasts
            # over the remaining input dimensions
            broadcast_shape = [shape[0]] + [1] * (len(shape) - 1)
            self.time_constants = torch.stack(
                [
                    torch.full(broadcast_shape, tau, dtype=torch.float32)
                    for tau in taus
                ]
            )
        else:
            self.time_constants = time_constants
        self.ps = torch.nn.Parameter(self.time_constants)
        # pytype: disable=missing-parameter
        self.neurons = activation(p=activation_state_map(self.ps), dt=dt)
        # pytype: enable=missing-parameter
        self.rf_dimension = len(shape)
        self.n_scales = n_scales

    def forward(self, x: torch.Tensor, state: Optional[NamedTuple] = None):
        """Feed ``n_scales`` copies of ``x`` through the scale-specific neurons.

        The copies are stacked along a new axis placed directly before the last
        ``len(shape)`` dimensions of ``x``. The result is returned as produced by
        ``self.neurons(x_repeated, state)``; for norse cells this is presumably an
        ``(output, state)`` tuple.
        """
        scale_dim = -self.rf_dimension - 1
        x_repeated = torch.stack([x] * self.n_scales, dim=scale_dim)
        return self.neurons(x_repeated, state)