Source code for norse.torch.module.receptive_field

"""
These receptive fields are derived from scale-space theory, specifically in the paper `Normative theory of visual receptive fields by Lindeberg, 2021 <https://www.sciencedirect.com/science/article/pii/S2405844021000025>`_.

For use in spiking / binary signals, see the paper on `Translation and Scale Invariance for Event-Based Object tracking by Pedersen et al., 2023 <https://dl.acm.org/doi/10.1145/3584954.3584996>`_
"""

from typing import Callable, NamedTuple, Optional, Tuple

import torch

from norse.torch.module.snn import SNNCell
from norse.torch.module.leaky_integrator_box import LIBoxCell, LIBoxParameters
from norse.torch.functional.receptive_field import (
    spatial_receptive_fields_with_derivatives,
    spatial_parameters,
    temporal_scale_distribution,
)


[docs] class SpatialReceptiveField2d(torch.nn.Module): """Creates a spatial receptive field as 2-dimensional convolutions. The `rf_parameters` are a tensor of shape `(n, 5)` where `n` is the number of receptive fields. If the `optimize_fields` flag is set to `True`, the `rf_parameters` will be optimized during training. Example: >>> import torch >>> from norse.torch import SpatialReceptiveField2d >>> parameters = torch.tensor([[1., 1., 1., 0., 0., 0., 0.]]) >>> m = SpatialReceptiveField2d(1, 9, parameters) >>> m.weights.shape torch.Size([1, 1, 9, 9]) >>> y = m(torch.empty(1, 1, 9, 9)) >>> y.shape torch.Size([1, 1, 1, 1]) Arguments: in_channels (int): Number of input channels size (int): Size of the receptive field rf_parameters (torch.Tensor): Parameters for the receptive fields in the order (scale, angle, ratio, x, y, dx, dy) aggregate (bool): If `True`, the receptive fields will be aggregated across channels. Defaults to `True`. domain (float): The domain of the receptive field. Defaults to `8`. optimize_fields (bool): If `True`, the `rf_parameters` will be optimized during training. Defaults to `True`. **kwargs: Additional arguments for the `torch.nn.functional.conv2d` function. """
[docs] def __init__( self, in_channels: int, size: int, rf_parameters: torch.Tensor, aggregate: bool = True, domain: float = 8, optimize_fields: bool = True, optimize_log: bool = True, **kwargs, ) -> None: super().__init__() self.optimize_log = optimize_log self.optimize_fields = optimize_fields if optimize_fields: if not self.optimize_log: self.register_parameter( "rf_parameters", torch.nn.Parameter(rf_parameters), ) else: log_scales = torch.log(rf_parameters[:, 0]) log_ratios = torch.log(rf_parameters[:, 2]) log_rf_parameters = torch.cat( ( torch.stack( (log_scales, rf_parameters[:, 1], log_ratios), dim=1 ), rf_parameters[:, 3:], ), dim=1, ) self.register_parameter( "log_rf_parameters", torch.nn.Parameter(log_rf_parameters), ) self.register_buffer( "rf_parameters", torch.cat( ( torch.stack( ( torch.exp(self.log_rf_parameters[:, 0]), self.log_rf_parameters[:, 1], torch.exp(self.log_rf_parameters[:, 2]), ), dim=1, ), self.log_rf_parameters[:, 3:], ), dim=1, ), ) else: self.register_buffer( "rf_parameters", rf_parameters, ) self.register_buffer( "rf_parameters_previous", torch.zeros_like(self.rf_parameters), ) self.in_channels = in_channels self.size = size self.aggregate = aggregate self.domain = domain self.out_channels = ( rf_parameters.shape[0] if aggregate else rf_parameters.shape[0] * in_channels ) self.kwargs = kwargs if "bias" not in self.kwargs: self.kwargs["bias"] = None # Register update and update the fields for the first time self.has_updated = True self._update_weights() if optimize_fields: def update_hook(m, gi, go): m.has_updated = True self.register_full_backward_hook(update_hook)
def _set_weights(self, weights): if hasattr(self, "weights"): self.weights = weights else: self.register_buffer("weights", weights) def _exp_log_rf_parameters(self): return torch.cat( ( torch.stack( ( torch.exp(self.log_rf_parameters[:, 0]), self.log_rf_parameters[:, 1], torch.exp(self.log_rf_parameters[:, 2]), ), dim=1, ), self.log_rf_parameters[:, 3:], ), dim=1, ) def _update_weights(self): if self.has_updated: # Reset the flag self.has_updated = False if self.optimize_fields and self.optimize_log: self.rf_parameters = self._exp_log_rf_parameters() self.rf_parameters_previous = self.rf_parameters.detach().clone() # Calculate new weights fields = spatial_receptive_fields_with_derivatives( self.rf_parameters, size=self.size, domain=self.domain, ) if self.aggregate: self.out_channels = fields.shape[0] self._set_weights(fields.unsqueeze(1).repeat(1, self.in_channels, 1, 1)) else: self.out_channels = self.fields.shape[0] * self.in_channels empty_weights = torch.zeros( self.in_channels, fields.shape[0], self.size, self.size, device=self.rf_parameters.device, ) weights = [] for i in range(self.in_channels): in_weights = empty_weights.clone() in_weights[i] = fields weights.append(in_weights) self._set_weights(torch.concat(weights, 1).permute(1, 0, 2, 3)) def forward(self, x: torch.Tensor): self._update_weights() # Update weights if necessary return torch.nn.functional.conv2d(x, self.weights, **self.kwargs)
class SampledSpatialReceptiveField2d(torch.nn.Module): """ Creates a spatial receptive field as 2-dimensional convolutions, sampled over a set of scales, angles, ratios, x, y and derivatives. This module allows for the optimization of the input parameters for scales, angles, ratios, x and y (not derivatives) and will update the parameters (and, by extension, the receptive fields) accordingly if the respective parameters are set to True. This module is a wrapper around the `SpatialReceptiveField2d` module and will forward the kwargs. Example: >>> import torch >>> from norse.torch import SampledSpatialReceptiveField2d >>> scales = torch.tensor([1.0, 2.0]) >>> angles = torch.tensor([0.0, 1.0]) >>> ratios = torch.tensor([0.5, 1.0]) >>> x = torch.tensor([0.0, 1.0]) >>> y = torch.tensor([0.0, 1.0]) >>> derivatives = torch.tensor([[0, 0]]) >>> m = SampledSpatialReceptiveField2d(1, 9, scales, angles, ratios, derivatives, x, y, >>> optimize_scales=False, optimize_angles=False, >>> optimize_ratios=True, optimize_x=False, >>> optimize_y=False)) >>> optim = torch.optim.SGD(list(m.parameters()), lr=1) >>> y = m(torch.ones(1, 1, 9, 9)) >>> y.sum().backward() >>> optim.step() # Will update the ratios >>> m.ratios() # Are now _different_ than the initial ratios """ def __init__( self, in_channels: int, size: int, scales: torch.Tensor, angles: torch.Tensor, ratios: torch.Tensor, derivatives: torch.Tensor, x: torch.Tensor = torch.Tensor([0.0]), y: torch.Tensor = torch.Tensor([0.0]), optimize_scales: bool = True, optimize_angles: bool = True, optimize_ratios: bool = True, optimize_x: bool = True, optimize_y: bool = True, optimize_log: bool = True, **kwargs, ): super().__init__() x = x.to(scales.device) y = y.to(scales.device) self.optimize_log = optimize_log if not self.optimize_log: self.scales = torch.nn.Parameter(scales) if optimize_scales else scales self.ratios = torch.nn.Parameter(ratios) if optimize_ratios else ratios else: self.log_scales = ( torch.nn.Parameter(torch.log(scales)) if optimize_scales else torch.log(scales) ) self.log_ratios = ( torch.nn.Parameter(torch.log(ratios)) if optimize_ratios else torch.log(ratios) ) self.scales = torch.exp(self.log_scales) self.ratios = torch.exp(self.log_ratios) self.angles = torch.nn.Parameter(angles) if optimize_angles else angles self.x = torch.nn.Parameter(x) if optimize_x else x self.y = torch.nn.Parameter(y) if optimize_y else y self.derivatives = derivatives self.has_updated = True self.submodule = SpatialReceptiveField2d( in_channels=in_channels, size=size, rf_parameters=spatial_parameters( self.scales, self.angles, self.ratios, self.derivatives, self.x, self.y, ), optimize_fields=False, optimize_log=False, **kwargs, ) if ( optimize_angles or optimize_scales or optimize_ratios or optimize_x or optimize_y ): def update_hook(m, gi, go): self.has_updated = True self.register_full_backward_hook(update_hook) def forward(self, x: torch.Tensor): self._update_weights() return self.submodule(x) def _update_weights(self): if self.has_updated: if self.optimize_log: self.scales = torch.exp(self.log_scales) self.ratios = torch.exp(self.log_ratios) self.submodule.rf_parameters = spatial_parameters( self.scales, self.angles, self.ratios, self.derivatives, self.x, self.y ) self.submodule.has_updated = True self.submodule._update_weights() class ParameterizedSpatialReceptiveField2d(torch.nn.Module): """ A parameterized version of the `SpatialReceptiveField2d` module, where the scales, angles, ratios, x and y are optimized and updated for each kernel individually during training. This is opposite to the `SampledSpatialReceptiveField2d` module, where the scales, angles, ratios, x and y are updated individually (as generating functions for the kernels). This module wraps the `SpatialReceptiveField2d` module. This module is a wrapper around the `SpatialReceptiveField2d` module and will forward the kwargs. Example: >>> import torch >>> from norse.torch import ParameterizedSpatialReceptiveField2d >>> scales = torch.tensor([1.0, 2.0]) >>> angles = torch.tensor([0.0, 1.0]) >>> ratios = torch.tensor([0.5, 1.0]) >>> x = torch.tensor([0.0, 1.0]) >>> y = torch.tensor([0.0, 1.0]) >>> m = ParameterizedSpatialReceptiveField2d(1, 9, scales, angles, ratios, 1, x, y, >>> optimize_scales=False, optimize_angles=False, >>> optimize_ratios=True, optimize_x=True, >>> optimize_y=True) """ def __init__( self, in_channels: int, size: int, scales: torch.Tensor, angles: torch.Tensor, ratios: torch.Tensor, derivatives: torch.Tensor, x: torch.Tensor = torch.Tensor([0.0]), y: torch.Tensor = torch.Tensor([0.0]), optimize_scales: bool = True, optimize_angles: bool = True, optimize_ratios: bool = True, optimize_x: bool = True, optimize_y: bool = True, optimize_log: bool = True, **kwargs, ): super().__init__() x = x.to(scales.device) y = y.to(scales.device) self.optimize_log = optimize_log self.register_buffer( "initial_parameters", spatial_parameters(scales, angles, ratios, derivatives, x, y), ) if self.optimize_log: self.log_scales = ( torch.nn.Parameter(torch.log(self.initial_parameters[:, 0])) if optimize_scales else torch.log(self.initial_parameters[:, 0]) ) self.scales = torch.exp(self.log_scales) else: self.scales = ( torch.nn.Parameter(self.initial_parameters[:, 0]) if optimize_scales else self.initial_parameters[:, 0] ) self.angles = ( torch.nn.Parameter(self.initial_parameters[:, 1]) if optimize_angles else self.initial_parameters[:, 1] ) if optimize_log: self.log_ratios = ( torch.nn.Parameter(torch.log(self.initial_parameters[:, 2])) if optimize_ratios else torch.log(self.initial_parameters[:, 2]) ) self.ratios = torch.exp(self.log_ratios) else: self.ratios = ( torch.nn.Parameter(self.initial_parameters[:, 2]) if optimize_ratios else self.initial_parameters[:, 2] ) self.x = ( torch.nn.Parameter(self.initial_parameters[:, 3]) if optimize_x else self.initial_parameters[:, 3] ) self.y = ( torch.nn.Parameter(self.initial_parameters[:, 4]) if optimize_y else self.initial_parameters[:, 4] ) rf_parameters = torch.concat( [ torch.stack([self.scales, self.angles, self.ratios, self.x, self.y], 1), self.initial_parameters[:, 5:], ], 1, ) self.submodule = SpatialReceptiveField2d( in_channels=in_channels, size=size, rf_parameters=rf_parameters, optimize_fields=False, optimize_log=False, **kwargs, ) self.has_updated = True if ( optimize_angles or optimize_scales or optimize_ratios or optimize_x or optimize_y ): def update_hook(m, gi, go): self.has_updated = True self.register_full_backward_hook(update_hook) def forward(self, x: torch.Tensor): self._update_weights() return self.submodule(x) def _update_weights(self): if self.has_updated: if self.optimize_log: self.scales = torch.exp(self.log_scales) self.ratios = torch.exp(self.log_ratios) self.submodule.rf_parameters = torch.concat( [ torch.stack( [self.scales, self.angles, self.ratios, self.x, self.y], 1 ), self.initial_parameters[:, 5:], ], 1, ) self.submodule.has_updated = True self.submodule._update_weights()
[docs] class TemporalReceptiveField(torch.nn.Module): """Creates ``n_scales`` temporal receptive fields for arbitrary n-dimensional inputs. The scale spaces are selected in a range of [min_scale, max_scale] using an exponential distribution, scattered using ``torch.linspace``. Parameters: shape (torch.Size): The shape of the incoming tensor, where the first dimension denote channels n_scales (int): The number of temporal scale spaces to iterate over. activation (SNNCell): The activation neuron. Defaults to LIBoxCell activation_state_map (Callable): A function that takes a tensor and provides a neuron parameter tuple. Required if activation is changed, since the default behaviour provides LIBoxParameters. min_scale (float): The minimum scale space. Defaults to 1. max_scale (Optional[float]): The maximum scale. Defaults to None. If set, c is ignored. c (Optional[float]): The base from which to generate scale values. Should be a value between 1 to 2, exclusive. Defaults to sqrt(2). Ignored if max_scale is set. time_constants (Optional[torch.Tensor]): Hardcoded time constants. Will overwrite the automatically generated, logarithmically distributed scales, if set. Defaults to None. dt (float): Neuron simulation timestep. Defaults to 0.001. """
[docs] def __init__( self, shape: torch.Size, n_scales: int = 4, activation: SNNCell.type = LIBoxCell, activation_state_map: Callable[ [torch.Tensor], NamedTuple ] = lambda t: LIBoxParameters(tau_mem_inv=t), min_scale: float = 1, max_scale: Optional[float] = None, c: float = 1.41421, time_constants: Optional[torch.Tensor] = None, dt: float = 0.001, ): super().__init__() if time_constants is None: taus = (1 / dt) / temporal_scale_distribution( n_scales, min_scale=min_scale, max_scale=max_scale, c=c ) self.time_constants = torch.stack( [ torch.full( [shape[0], *[1 for i in range(len(shape) - 1)]], tau, dtype=torch.float32, ) for tau in taus ] ) else: self.time_constants = time_constants self.ps = torch.nn.Parameter(self.time_constants) # pytype: disable=missing-parameter self.neurons = activation(p=activation_state_map(self.ps), dt=dt) # pytype: enable=missing-parameter self.rf_dimension = len(shape) self.n_scales = n_scales
def forward(self, x: torch.Tensor, state: Optional[NamedTuple] = None): x_repeated = torch.stack( [x for _ in range(self.n_scales)], dim=-self.rf_dimension - 1 ) return self.neurons(x_repeated, state)