| from typing import Optional |
|
|
| from pathlib import Path |
| import torch |
| from torch import nn |
| from spikingjelly.activation_based import surrogate, neuron, functional |
|
|
| import math |
| from dataclasses import dataclass |
| import warnings |
|
|
|
|
|
|
# Module-wide defaults handed to the spikingjelly LIF neurons built below.
tau = 2.0  # forwarded as ``tau=`` to LIFNode by the encoders and the top-level model
backend = "torch"  # forwarded as ``backend=`` to the LIFNodes of the transformer layers
detach_reset = True  # forwarded as ``detach_reset=`` to every LIFNode in this file
|
|
|
|
|
|
@dataclass(eq=False)
class CPG(nn.Module):
    """Pre-computed binary positional code of shape ``(l_max, num_neurons)``.

    Each column is a thresholded sinusoid of the position index: even columns
    are 1 where ``sin(position * freq) >= 0.8`` and odd columns are 1 where
    ``cos(position * freq) >= 0.8``; every other entry is 0.  The frequencies
    follow ``exp(-log(w_max) * 2i / num_neurons)``.

    ``eq=False`` keeps the identity-based ``__eq__``/``__hash__`` inherited
    from ``nn.Module``; the default dataclass ``eq=True`` would set
    ``__hash__`` to ``None`` and make the module unusable as a submodule.
    """

    # Number of columns of the code (one per "neuron").
    num_neurons: int = 40
    # Base of the geometric frequency progression.
    w_max: float = 10000.0
    # Number of positions (rows) that are pre-computed.
    l_max: int = 5000

    def __post_init__(self):
        # The dataclass-generated __init__ does not call nn.Module.__init__,
        # so the module machinery has to be set up here.
        super().__init__()

        position = torch.arange(0, self.l_max, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.num_neurons, 2).float()
            * (-math.log(self.w_max) / self.num_neurons)
        )
        # When num_neurons is odd there is one odd column fewer than even
        # columns, so the cosine half uses only the leading frequencies.
        div_term_single = div_term[: self.num_neurons // 2]

        # heaviside(x, 1.0) maps x < 0 to 0 and x >= 0 to 1.
        on_value = torch.tensor([1.0])
        code = torch.zeros(self.l_max, self.num_neurons)
        code[:, 0::2] = torch.heaviside(
            torch.sin(position * div_term) - 0.8, on_value
        )
        code[:, 1::2] = torch.heaviside(
            torch.cos(position * div_term_single) - 0.8, on_value
        )
        # Kept as a plain tensor attribute (not a registered buffer).
        self._cpg = code

    @property
    def cpg(self):
        """Return the pre-computed ``(l_max, num_neurons)`` code tensor."""
        return self._cpg
|
|
|
|
class CPGLinear(nn.Module):
    """Linear layer whose output is offset by a projection of the CPG code.

    Computes ``inp_linear(dropout(x)) + cpg_linear(cpg[:x.size(-2)])``.  The
    CPG code is stored as a frozen (``requires_grad=False``) parameter.

    Args:
        input_size: Size of the last dimension of ``x``.
        output_size: Size of the last dimension of the result.
        cpg: Object exposing ``cpg`` (the code tensor) and ``num_neurons``.
            When ``None`` a fresh default ``CPG()`` is built for this layer,
            so layers never alias one import-time default instance.
        dropout: Dropout probability applied to ``x`` only.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        cpg: Optional["CPG"] = None,
        dropout: float = 0.1,
    ):
        super().__init__()
        if cpg is None:
            # Built per instance instead of once in the signature: a default
            # evaluated at definition time would be shared by every layer.
            cpg = CPG()
        self.cpg = nn.Parameter(cpg.cpg, requires_grad=False)
        self.inp_linear = nn.Linear(input_size, output_size)
        self.cpg_linear = nn.Linear(cpg.num_neurons, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        """Project ``x`` and add the projected code of its first positions.

        The second-to-last dimension of ``x`` is used as the position axis:
        that many leading rows of the code are projected and broadcast-added
        to the projected input.
        """
        cpg = self.cpg[: x.size(-2)]
        x = self.dropout(x)
        return self.inp_linear(x) + self.cpg_linear(cpg)
|
|
|
|
|
|
|
|
class RepeatEncoder(nn.Module):
    """Encode a tensor by tiling it along a new leading axis and firing a LIF.

    Args:
        output_size: Number of copies stacked along the new leading axis.
    """

    def __init__(self, output_size: int):
        super().__init__()
        self.out_size = output_size
        self.lif = neuron.LIFNode(
            surrogate_function=surrogate.ATan(),
            detach_reset=detach_reset,
            step_mode="m",
            tau=tau,
        )

    def forward(self, inputs: torch.Tensor):
        """Return LIF spikes of ``inputs`` repeated ``out_size`` times.

        A 3-D input ``(d0, d1, d2)`` becomes ``(out_size, d0, d2, d1)`` before
        it is handed to the LIF node (the permute requires the tiled tensor
        to be 4-D).
        """
        # One repeat factor for the new leading axis, 1 for each existing axis.
        unit_factors = [1] * inputs.dim()
        tiled = inputs.repeat(self.out_size, *unit_factors)
        swapped = tiled.permute(0, 1, 3, 2)
        return self.lif(swapped)
|
|
|
|
class DeltaEncoder(nn.Module):
    """Encode the step-to-step differences of a sequence as spikes.

    Args:
        output_size: Output width of the ``Linear(1, output_size)`` layer; it
            ends up as the leading axis of the returned tensor.
    """

    def __init__(self, output_size: int):
        super().__init__()
        self.norm = nn.BatchNorm2d(1)
        self.enc = nn.Linear(1, output_size)
        self.lif = neuron.LIFNode(
            surrogate_function=surrogate.ATan(),
            detach_reset=detach_reset,
            step_mode="m",
            tau=tau,
        )

    def forward(self, inputs: torch.Tensor):
        """Return LIF spikes of the normalised, projected first differences.

        For a 3-D input ``(d0, d1, d2)`` the differences are taken along axis
        1 (the first step is zero) and the result handed to the LIF node has
        shape ``(output_size, d0, d2, d1)``.
        """
        # First difference along axis 1, with an explicit zero for step 0.
        leading_zero = torch.zeros_like(inputs[:, :1])
        increments = inputs[:, 1:, :] - inputs[:, :-1, :]
        delta = torch.cat((leading_zero, increments), dim=1)

        # (d0, d1, d2) -> (d0, 1, d2, d1): a single channel for BatchNorm2d.
        single_channel = delta.unsqueeze(1).permute(0, 1, 3, 2)
        normed = self.norm(single_channel)

        # Move the singleton channel last so Linear(1, output_size) expands it,
        # then bring the expanded axis to the front.
        projected = self.enc(normed.permute(0, 2, 3, 1))
        return self.lif(projected.permute(3, 0, 1, 2))
|
|
|
|
class ConvEncoder(nn.Module):
    """Encode a sequence with a ``(1, kernel_size)`` convolution and a LIF.

    Args:
        output_size: Number of convolution output channels; it ends up as the
            leading axis of the returned tensor.
        kernel_size: Kernel width along the last axis of the convolved input;
            the padding on that axis is ``kernel_size // 2``.
    """

    def __init__(self, output_size: int, kernel_size: int = 3):
        super().__init__()
        conv = nn.Conv2d(
            in_channels=1,
            out_channels=output_size,
            kernel_size=(1, kernel_size),
            stride=1,
            padding=(0, kernel_size // 2),
        )
        norm = nn.BatchNorm2d(output_size)
        self.encoder = nn.Sequential(conv, norm)
        self.lif = neuron.LIFNode(
            surrogate_function=surrogate.ATan(),
            detach_reset=detach_reset,
            step_mode="m",
            tau=tau,
        )

    def forward(self, inputs: torch.Tensor):
        """Return LIF spikes of the convolved input.

        A 3-D input ``(d0, d1, d2)`` is viewed as a one-channel image
        ``(d0, 1, d2, d1)``, convolved along its last axis, and the channel
        axis is moved to the front before the LIF node is applied.
        """
        image = inputs.permute(0, 2, 1).unsqueeze(1)
        features = self.encoder(image)
        channels_first = features.permute(1, 0, 2, 3)
        return self.lif(channels_first)
|
|
|
|
|
|
# Registry of spike encoders: backend name -> encoder kind -> encoder class.
# Both backend names currently map to the same three classes.
SpikeEncoder = {
    backend_name: {
        "repeat": RepeatEncoder,
        "conv": ConvEncoder,
        "delta": DeltaEncoder,
    }
    for backend_name in ("snntorch", "spikingjelly")
}
|
|
|
|
|
|
|
|
|
|
class SSA(nn.Module):
    """Spiking self-attention over a 4-D ``(T, B, L, D)`` tensor.

    Queries, keys and values are each produced by Linear -> BatchNorm1d ->
    LIF, split into ``heads`` heads, combined as ``(q @ k^T) * qk_scale @ v``
    (no softmax), passed through a LIF and a final Linear -> BatchNorm1d ->
    LIF output stage.

    Args:
        length: Accepted for interface compatibility; not used here.
        tau: ``tau`` of every LIF node in this layer.
        common_thr: ``v_threshold`` of the LIF nodes; the attention LIF uses
            half of it.
        dim: Feature width ``D``; must be divisible by ``heads``.
        heads: Number of attention heads.
        qkv_bias: Accepted for interface compatibility; not used here.
        qk_scale: Factor applied to ``q @ k^T``.
    """

    def __init__(
        self, length, tau, common_thr, dim, heads=8, qkv_bias=False, qk_scale=0.25
    ):
        super().__init__()
        assert dim % heads == 0, f"dim {dim} should be divided by num_heads {heads}."

        self.dim = dim
        self.heads = heads
        self.qk_scale = qk_scale

        def build_lif(threshold):
            # All LIF nodes of this layer differ only in their threshold.
            return neuron.LIFNode(
                tau=tau,
                step_mode="m",
                detach_reset=detach_reset,
                surrogate_function=surrogate.ATan(),
                v_threshold=threshold,
                backend=backend,
            )

        # Submodules are created in a fixed order (q, k, v, attention, output).
        self.q_m = nn.Linear(dim, dim)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = build_lif(common_thr)

        self.k_m = nn.Linear(dim, dim)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = build_lif(common_thr)

        self.v_m = nn.Linear(dim, dim)
        self.v_bn = nn.BatchNorm1d(dim)
        self.v_lif = build_lif(common_thr)

        self.attn_lif = build_lif(common_thr / 2)

        self.last_m = nn.Linear(dim, dim)
        self.last_bn = nn.BatchNorm1d(dim)
        self.last_lif = build_lif(common_thr)

    def _spiking_heads(self, tokens, linear, norm, lif, shape):
        """Run Linear -> BatchNorm1d -> LIF and split the result into heads.

        ``tokens`` is the input with its first two axes flattened; the return
        value has shape ``(T, B, heads, L, D // heads)``.
        """
        T, B, L, D = shape
        projected = linear(tokens)
        # BatchNorm1d normalises axis 1, so the feature axis is moved there.
        normed = norm(projected.transpose(-1, -2)).transpose(-1, -2)
        spikes = lif(normed.reshape(T, B, L, D).contiguous())
        per_head = spikes.reshape(T, B, L, self.heads, D // self.heads)
        return per_head.permute(0, 1, 3, 2, 4).contiguous()

    def forward(self, x):
        """Apply spiking self-attention; input and output are ``(T, B, L, D)``."""
        T, B, L, D = x.shape
        shape = (T, B, L, D)
        tokens = x.flatten(0, 1)

        q = self._spiking_heads(tokens, self.q_m, self.q_bn, self.q_lif, shape)
        k = self._spiking_heads(tokens, self.k_m, self.k_bn, self.k_lif, shape)
        v = self._spiking_heads(tokens, self.v_m, self.v_bn, self.v_lif, shape)

        # Scaled dot-product without softmax.
        scores = (q @ k.transpose(-2, -1)) * self.qk_scale
        mixed = scores @ v

        # Merge the heads back into the feature axis.
        merged = mixed.transpose(2, 3).reshape(T, B, L, D).contiguous()
        merged = self.attn_lif(merged)

        out = self.last_m(merged.flatten(0, 1))
        out = self.last_bn(out.transpose(-1, -2)).transpose(-1, -2)
        return self.last_lif(out.reshape(T, B, L, D).contiguous())
|
|
|
|
class MLP(nn.Module):
    """Two-layer spiking feed-forward block built from ``CPGLinear`` layers.

    Each layer is CPGLinear -> BatchNorm1d -> LIF.  The input is a 4-D
    ``(T, B, L, in_features)`` tensor and the output is
    ``(T, B, L, out_features)``.

    Args:
        length: Accepted for interface compatibility; not used here.
        tau: ``tau`` of both LIF nodes.
        common_thr: ``v_threshold`` of both LIF nodes.
        in_features: Width of the last input dimension.
        hidden_features: Width between the two layers; falls back to
            ``in_features`` when omitted.
        out_features: Width of the last output dimension; falls back to
            ``in_features`` when omitted.
    """

    def __init__(
        self,
        length,
        tau,
        common_thr,
        in_features,
        hidden_features=None,
        out_features=None,
    ):
        super().__init__()
        out_features = out_features or in_features
        # Same fallback as out_features; without it the default ``None``
        # would be passed on as a layer width.
        hidden_features = hidden_features or in_features
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features

        self.fc1 = CPGLinear(in_features, hidden_features)
        self.bn1 = nn.BatchNorm1d(hidden_features)
        self.lif1 = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
            v_threshold=common_thr,
            backend=backend,
        )

        self.fc2 = CPGLinear(hidden_features, out_features)
        self.bn2 = nn.BatchNorm1d(out_features)
        self.lif2 = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
            v_threshold=common_thr,
            backend=backend,
        )

    def forward(self, x):
        """Apply both layers; ``x`` is ``(T, B, L, in_features)``."""
        T, B, L, _ = x.shape
        # Batch-first with (T, L) merged, so the linear layers see T * L
        # positions per sample.
        x = x.transpose(0, 1).flatten(1, 2)
        x = self.fc1(x)
        # BatchNorm1d normalises axis 1, so the feature axis is moved there.
        x = (
            self.bn1(x.transpose(-1, -2))
            .transpose(-1, -2)
            .reshape(B, T, L, self.hidden_features)
            .contiguous()
        )
        # The LIF node receives the tensor with T as the leading axis.
        x = self.lif1(x.transpose(0, 1)).transpose(0, 1)
        x = x.flatten(1, 2)
        x = self.fc2(x)
        x = (
            self.bn2(x.transpose(-1, -2))
            .transpose(-1, -2)
            # The width here is out_features, which may differ from the
            # input width.
            .reshape(B, T, L, self.out_features)
            .contiguous()
        )
        x = self.lif2(x.transpose(0, 1))
        return x
|
|
|
|
class Block(nn.Module):
    """Transformer block: spiking self-attention then MLP, each with a residual.

    Args:
        length, tau, common_thr: Forwarded to both ``SSA`` and ``MLP``.
        dim: Feature width of the block (``SSA`` ``dim`` and ``MLP``
            ``in_features``).
        d_ff: Hidden width of the ``MLP``.
        heads, qkv_bias, qk_scale: Forwarded to ``SSA``.
    """

    def __init__(
        self,
        length,
        tau,
        common_thr,
        dim,
        d_ff,
        heads=8,
        qkv_bias=False,
        qk_scale=0.125,
    ):
        super().__init__()
        # Settings that both sub-layers receive unchanged.
        shared = dict(length=length, tau=tau, common_thr=common_thr)
        self.attn = SSA(
            dim=dim,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            **shared,
        )
        self.mlp = MLP(in_features=dim, hidden_features=d_ff, **shared)

    def forward(self, x):
        """Return ``h + mlp(h)`` where ``h = x + attn(x)``."""
        attended = x + self.attn(x)
        return attended + self.mlp(attended)
|
|
|
|
class Spikformer_CPG(nn.Module):
    """Spiking transformer with CPG positional code for sequence forecasting.

    Pipeline: optional per-sample normalisation -> convolutional spike
    encoder -> ``CPGLinear`` embedding -> LIF -> stack of ``Block`` -> mean
    over the leading (time-step) axis -> linear head reshaped to
    ``(batch, args.pre_length, args.feature_size)``.

    Args:
        args: Namespace providing ``T``, ``blocks``, ``feature_size``,
            ``pre_length``, ``seq_length`` and ``normalize``.
        dim: Embedding width used by the encoder, the blocks and the head.
        d_ff: Hidden width of each block's MLP; 1024 when omitted.
        num_pe_neuron: ``num_neurons`` of the ``CPG`` used by the embedding.
        pe_type, pe_mode: Stored on the instance; not otherwise used here.
        neuron_pe_scale: Accepted for interface compatibility; not used here.
        depths: Number of ``Block`` layers that are built.
        common_thr: ``v_threshold`` of the LIF nodes.
        max_length: Forwarded to each ``Block`` as ``length``.
        num_steps: ``output_size`` of the convolutional spike encoder.
        heads, qkv_bias, qk_scale: Forwarded to each ``Block``.
        input_size, weight_file: Accepted for interface compatibility; not
            used here (the input width is taken from ``args.feature_size``).
    """

    def __init__(
        self,
        args,
        dim: int = 256,
        d_ff: Optional[int] = None,
        num_pe_neuron: int = 40,
        pe_type: str = "neuron",
        pe_mode: str = "concat",
        neuron_pe_scale: float = 10000.0,
        depths: int = 2,
        common_thr: float = 1.0,
        max_length: int = 5000,
        num_steps: int = 4,
        heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: float = 0.125,
        input_size: Optional[int] = None,
        weight_file: Optional[Path] = None,
    ):
        super().__init__()
        # Record the widths that are actually used to build the layers.
        self.dim = dim
        self.d_ff = 1024 if d_ff is None else d_ff
        # NOTE(review): T and depths mirror ``args`` while the layers below are
        # sized from ``num_steps`` and ``depths`` -- confirm which is intended.
        self.T = args.T
        self.depths = args.blocks
        self.pe_type = pe_type
        self.pe_mode = pe_mode
        self.num_pe_neuron = num_pe_neuron
        self.input_size = args.feature_size
        self.pre_length = args.pre_length
        self.args = args

        self._snn_backend = "spikingjelly"

        self.temporal_encoder = SpikeEncoder[self._snn_backend]["conv"](num_steps)
        self.encoder = CPGLinear(self.input_size, dim, CPG(num_neurons=num_pe_neuron))

        self.init_lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
            v_threshold=common_thr,
            backend=backend,
        )

        self.blocks = nn.ModuleList(
            [
                Block(
                    length=max_length,
                    tau=tau,
                    common_thr=common_thr,
                    dim=dim,
                    d_ff=self.d_ff,
                    heads=heads,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                )
                for _ in range(depths)
            ]
        )

        self.apply(self._init_weights)

        # NOTE(review): the head is created after the init pass above, so it
        # keeps PyTorch's default Linear initialisation -- confirm intended.
        self.fc = nn.Linear(args.seq_length * dim, args.pre_length * args.feature_size)

    def _init_weights(self, m):
        """Initialise Linear weights ~ N(0, 0.02) and reset biases/LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)

    def forward(self, x: torch.Tensor):
        """Forecast from ``x`` and return ``(out, aux)``.

        ``out`` has shape ``(batch, args.pre_length, args.feature_size)``;
        ``aux`` is ``{'gate_l0': tensor(0.0)}`` on the device of ``out``.
        ``x`` is assumed to be ``(batch, args.seq_length, args.feature_size)``
        -- TODO confirm against the data loader.
        """
        # Clear the state of every spiking neuron before a new forward pass.
        functional.reset_net(self)

        normalize = self.args.normalize
        if normalize:
            # Statistics over axis 1, detached so they act as constants.
            mean = x.mean(dim=1, keepdim=True).detach()
            x = x - mean
            std = torch.sqrt(
                torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5
            ).detach()
            x = x / std

        x = self.temporal_encoder(x)
        T, B, _, L = x.shape
        # Batch-first with (T, L) merged for the CPGLinear embedding.
        x = x.permute(1, 0, 3, 2)
        x = x.flatten(1, 2)
        x = self.encoder(x)
        x = x.reshape(B, T, L, -1).permute(1, 0, 2, 3)
        x = self.init_lif(x)

        for blk in self.blocks:
            x = blk(x)

        # Average over the leading axis, then flatten (L, dim) for the head.
        out = x.mean(0)
        out = self.fc(out.flatten(-2, -1)).reshape(-1, self.pre_length, self.input_size)
        if normalize:
            # Undo the input normalisation.
            out = out * std + mean
        aux = {'gate_l0': torch.tensor(0.0, device=out.device)}
        return out, aux
|
|
|
|