| from typing import Optional |
| from pathlib import Path |
| import torch |
| from torch import nn |
| from spikingjelly.activation_based import surrogate, neuron, functional |
| import math |
| import copy |
|
|
|
|
# Shared LIF-neuron settings used by the encoders and cells defined below.
#   tau          -> passed as LIFNode(tau=...)
#   backend      -> passed as LIFNode(backend=...) for SpikeRNN_CPG.init_lif
#   detach_reset -> passed as LIFNode(detach_reset=...)
tau, backend, detach_reset = 2.0, "torch", True
|
|
|
|
|
|
def generate_ones_and_minus_ones_matrix(rows, cols):
    """Return a ``(rows, cols)`` float32 tensor whose entries are random -1 or +1.

    One ``torch.randint(0, 2, ...)`` draw is taken from the global torch RNG;
    a drawn 0 becomes -1 and a drawn 1 becomes +1.
    """
    bits = torch.randint(0, 2, (rows, cols))
    signs = bits * 2 - 1  # {0, 1} -> {-1, +1}
    return signs.float()
|
|
|
|
class RandomPE(nn.Module):
    """Position code made of a fixed random +/-1 pattern.

    A ``(max_len, width)`` table of random -1/+1 values is drawn once at
    construction and stored as the ``pe`` buffer. In ``forward`` every
    (time step, position) pair of the input receives its own table row.

    Args:
        d_model: feature size of the input; used as the code width in "add" mode.
        pe_mode: "concat" appends ``num_pe_neuron`` code channels to the input;
            "add" adds a ``d_model``-wide code to it.
        num_pe_neuron: code width in "concat" mode (ignored in "add" mode).
        neuron_pe_scale: stored on the instance for API parity with NeuronPE;
            not used by this class.
        dropout: dropout probability applied to the output.
        num_steps: accepted for signature parity with the other encoders; unused.

    Raises:
        ValueError: if ``pe_mode`` is neither "concat" nor "add".
    """

    def __init__(
        self,
        d_model,
        pe_mode="concat",
        num_pe_neuron=10,
        neuron_pe_scale=1000.0,
        dropout=0.1,
        num_steps=4,
    ):
        super().__init__()
        self.max_len = 5000
        self.pe_mode = pe_mode
        self.neuron_pe_scale = neuron_pe_scale
        self.dropout = nn.Dropout(p=dropout)
        if pe_mode == "concat":
            self.num_pe_neuron = num_pe_neuron
        elif pe_mode == "add":
            self.num_pe_neuron = d_model
        else:
            # Without this, self.num_pe_neuron would be undefined and the
            # next line would fail with an unrelated AttributeError.
            raise ValueError(
                f"Unknown pe_mode: {pe_mode!r} (expected 'concat' or 'add')"
            )
        pe = generate_ones_and_minus_ones_matrix(self.max_len, self.num_pe_neuron)
        # (max_len, width) -> (max_len, 1, width)
        pe = pe.unsqueeze(0).transpose(0, 1)
        print("pe.shape: ", pe.shape)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Attach the random code to ``x``.

        ``x`` is unpacked as ``(T, B, L, D)``. The ``T`` and ``L`` axes are
        flattened into a single position axis of length ``T * L`` (index
        ``t * L + l``), so that axis must not exceed ``self.max_len``.
        Returns a ``(T, B, L, D')`` tensor with ``D' = D + num_pe_neuron`` in
        "concat" mode and ``D' = D`` in "add" mode.
        """
        T, B, L, _ = x.shape
        # (T, B, L, D) -> (B, T, L, D) -> (B, T*L, D)
        x = x.permute(1, 0, 2, 3).flatten(1, 2)
        code = self.pe[: x.size(-2), :]  # (T*L, 1, width)
        if self.pe_mode == "concat":
            code = code.repeat(1, B, 1).transpose(0, 1)  # (B, T*L, width)
            x = torch.concat([x, code], dim=-1)
        elif self.pe_mode == "add":
            x = x + code.transpose(0, 1)  # (1, T*L, width) broadcasts over B
        # (B, T*L, D') -> (T*L, B, D') -> (T, L, B, D') -> (T, B, L, D')
        x = x.transpose(0, 1).reshape(T, L, B, -1).permute(0, 2, 1, 3)
        return self.dropout(x)
|
|
|
|
class NeuronPE(nn.Module):
    """Binary ("spiking") sinusoidal position code.

    The class builds a sin/cos table like :class:`StaticPE`, using
    ``neuron_pe_scale`` as the frequency base. It then thresholds the table
    with ``heaviside(value - 0.8, 1.0)``, so every entry is 0.0 or 1.0. In
    ``forward`` every (time step, position) pair of the input receives its
    own table row.

    Args:
        d_model: feature size of the input; used as the code width in "add" mode.
        pe_mode: "concat" appends ``num_pe_neuron`` code channels to the input;
            "add" adds a ``d_model``-wide code to it.
        num_pe_neuron: code width in "concat" mode (ignored in "add" mode).
        neuron_pe_scale: base of the geometric frequency progression.
        dropout: dropout probability applied to the output.
        num_steps: accepted for signature parity with the other encoders; unused.

    Raises:
        ValueError: if ``pe_mode`` is neither "concat" nor "add".
    """

    def __init__(
        self,
        d_model,
        pe_mode="concat",
        num_pe_neuron=10,
        neuron_pe_scale=10000.0,
        dropout=0.1,
        num_steps=4,
    ):
        super().__init__()
        self.max_len = 50000
        self.pe_mode = pe_mode
        self.neuron_pe_scale = neuron_pe_scale
        self.dropout = nn.Dropout(p=dropout)
        if pe_mode == "concat":
            self.num_pe_neuron = num_pe_neuron
        elif pe_mode == "add":
            self.num_pe_neuron = d_model
        else:
            # Without this, self.num_pe_neuron would be undefined and table
            # construction would fail with an unrelated AttributeError.
            raise ValueError(
                f"Unknown pe_mode: {pe_mode!r} (expected 'concat' or 'add')"
            )
        pe = torch.zeros(self.max_len, self.num_pe_neuron)
        position = torch.arange(0, self.max_len, dtype=torch.float).unsqueeze(1)
        exponent = -math.log(neuron_pe_scale) / self.num_pe_neuron
        # Frequencies for even (sin) columns and odd (cos) columns; the odd
        # range stops at width - 1 so odd widths get one fewer cos column.
        div_term = torch.exp(torch.arange(0, self.num_pe_neuron, 2).float() * exponent)
        div_term_single = torch.exp(
            torch.arange(0, self.num_pe_neuron - 1, 2).float() * exponent
        )
        # heaviside(v, 1.0): 1 where v > 0 (or v == 0), 0 where v < 0.
        pe[:, 0::2] = torch.heaviside(
            torch.sin(position * div_term) - 0.8, torch.tensor([1.0])
        )
        pe[:, 1::2] = torch.heaviside(
            torch.cos(position * div_term_single) - 0.8, torch.tensor([1.0])
        )
        # (max_len, width) -> (max_len, 1, width)
        pe = pe.unsqueeze(0).transpose(0, 1)
        print("pe.shape: ", pe.shape)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Attach the binary code to ``x``.

        ``x`` is unpacked as ``(T, B, L, D)``. The ``T`` and ``L`` axes are
        flattened into a single position axis of length ``T * L`` (index
        ``t * L + l``), so that axis must not exceed ``self.max_len``.
        Returns a ``(T, B, L, D')`` tensor with ``D' = D + num_pe_neuron`` in
        "concat" mode and ``D' = D`` in "add" mode.
        """
        T, B, L, _ = x.shape
        # (T, B, L, D) -> (B, T, L, D) -> (B, T*L, D)
        x = x.permute(1, 0, 2, 3).flatten(1, 2)
        code = self.pe[: x.size(-2), :]  # (T*L, 1, width)
        if self.pe_mode == "concat":
            code = code.repeat(1, B, 1).transpose(0, 1)  # (B, T*L, width)
            x = torch.concat([x, code], dim=-1)
        elif self.pe_mode == "add":
            x = x + code.transpose(0, 1)  # (1, T*L, width) broadcasts over B
        # (B, T*L, D') -> (T*L, B, D') -> (T, L, B, D') -> (T, B, L, D')
        x = x.transpose(0, 1).reshape(T, L, B, -1).permute(0, 2, 1, 3)
        return self.dropout(x)
|
|
|
|
class StaticPE(nn.Module):
    r"""Add a fixed sinusoidal position encoding to the input.

    The encoding has the same width as the embeddings, so the two can be
    summed. Even columns use sine and odd columns use cosine, at
    geometrically spaced frequencies:

    .. math::
        PE(pos, 2i)   = \sin\left(pos / 10000^{2i / d_{model}}\right)
        PE(pos, 2i+1) = \cos\left(pos / 10000^{2i / d_{model}}\right)

    where :math:`pos` is the token position and :math:`i` indexes the
    embedding dimension pairs.

    Args:
        d_model: embedding width.
        dropout: dropout probability applied after the addition.
        max_len: number of precomputed positions (rows of the ``pe`` buffer).
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        scale = -math.log(10000.0) / d_model
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        sin_freqs = torch.exp(torch.arange(0, d_model, 2).float() * scale)
        # One fewer cosine column than sine columns when d_model is odd.
        cos_freqs = torch.exp(torch.arange(0, d_model - 1, 2).float() * scale)
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * sin_freqs)
        table[:, 1::2] = torch.cos(positions * cos_freqs)
        # Stored as (max_len, 1, d_model) so it broadcasts over the second axis.
        self.register_buffer("pe", table.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``.

        ``x`` is indexed by position along dim 0 and must broadcast against a
        ``(seq_len, 1, d_model)`` slice, e.g. ``(seq_len, batch, d_model)``.
        """
        encoded = x + self.pe[: x.size(0), :]
        return self.dropout(encoded)
|
|
|
|
class ConvPE(nn.Module):
    """Convolutional position encoding gated by a multi-step LIF neuron.

    ``x`` is passed through Conv1d (kernel 3, same-length padding) and then
    BatchNorm1d along the sequence axis. The result goes through a LIF node
    over ``num_steps`` time steps and dropout. That output is added back to
    ``x`` as a residual.

    Args:
        d_model: channel width (conv in/out channels and BatchNorm features).
        dropout: dropout probability applied to the spike output.
        max_len: accepted for signature parity with the other encoders; unused.
        num_steps: number of time steps ``T`` folded into dim 1 of the input.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, num_steps=4):
        super().__init__()
        self.T = num_steps
        self.rpe_conv = nn.Conv1d(
            d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.rpe_bn = nn.BatchNorm1d(d_model)
        self.rpe_lif = neuron.LIFNode(
            step_mode="m",
            detach_reset=True,
            surrogate_function=surrogate.ATan(),
            v_threshold=1.0,
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return ``x`` plus its convolutional spike encoding.

        ``x`` is unpacked as ``(L, TB, D)``. The code assumes dim 1 is ``T``
        time steps times ``B`` batch items, time-major, and reshapes it to
        ``(T, B, ...)`` for the LIF node. TODO confirm against callers.
        """
        seq_len, steps_by_batch, channels = x.shape
        batch = steps_by_batch // self.T
        # (L, TB, D) -> (TB, D, L): Conv1d/BatchNorm1d run along the sequence axis.
        normed = self.rpe_bn(self.rpe_conv(x.permute(1, 2, 0)))
        normed = normed.reshape(self.T, batch, channels, seq_len).contiguous()
        spikes = self.rpe_lif(normed).flatten(0, 1)
        # (TB, D, L) -> (L, TB, D) to line up with x for the residual sum.
        spikes = self.dropout(spikes).permute(2, 0, 1)
        return x + spikes
|
|
|
|
class PositionEmbedding(nn.Module):
    """Dispatch wrapper over the position-encoding variants in this module.

    ``pe_type`` selects the implementation:
      * ``"learn"``  - trainable ``nn.Embedding(max_len, input_size)`` added
        along the position axis (dim -2 of the input).
      * ``"none"``   - the same ``nn.Embedding`` is created (so the parameter
        set matches ``"learn"``), but ``forward`` returns ``x`` unchanged.
      * ``"conv"``   - ``ConvPE``.
      * ``"static"`` - ``StaticPE`` (fixed sinusoidal).
      * ``"neuron"`` - ``NeuronPE`` (binary sinusoidal; uses ``pe_mode``).
      * ``"random"`` - ``RandomPE`` (random +/-1; uses ``pe_mode``).

    Raises:
        ValueError: for any other ``pe_type``.
    """

    def __init__(
        self,
        input_size: int,
        pe_type: str,
        max_len: int = 5000,
        pe_mode: str = "add",
        num_pe_neuron: int = 10,
        neuron_pe_scale: float = 1000.0,
        dropout=0.1,
        num_steps=4,
    ):
        super().__init__()
        self.emb_type = pe_type
        if pe_type in ["learn", "none"]:
            self.emb = nn.Embedding(max_len, input_size)
        elif pe_type == "conv":
            self.emb = ConvPE(
                d_model=input_size,
                max_len=max_len,
                dropout=dropout,
                num_steps=num_steps,
            )
        elif pe_type == "static":
            self.emb = StaticPE(d_model=input_size, max_len=max_len, dropout=dropout)
        elif pe_type == "neuron":
            self.emb = NeuronPE(
                d_model=input_size,
                pe_mode=pe_mode,
                num_pe_neuron=num_pe_neuron,
                neuron_pe_scale=neuron_pe_scale,
                dropout=dropout,
                num_steps=num_steps,
            )
        elif pe_type == "random":
            self.emb = RandomPE(
                d_model=input_size,
                pe_mode=pe_mode,
                num_pe_neuron=num_pe_neuron,
                neuron_pe_scale=neuron_pe_scale,
                dropout=dropout,
                num_steps=num_steps,
            )
        else:
            raise ValueError("Unknown embedding type: {}".format(pe_type))

    def forward(self, x):
        """Apply the selected encoding.

        The "conv", "static", "neuron" and "random" branches unpack ``x`` as
        ``(T, B, L, D)``. The "learn" branch accepts any tensor whose
        position axis is dim -2, e.g. ``(T, B, L, D)`` or ``(B, L, D)``.
        """
        if self.emb_type == "learn":
            # Index positions along dim -2 (L). The (L, D) table broadcasts
            # over every leading dim. The old code used dim 1, which is the
            # batch axis for the 4-D inputs this class receives.
            positions = torch.arange(x.size(-2), device=x.device)
            x = x + self.emb(positions)
        elif self.emb_type in ["static", "conv"]:
            T, B, L, _ = x.shape
            # (T, B, L, D) -> (T*B, L, D) -> (L, T*B, D) for the sequence-first encoders.
            x = x.flatten(0, 1)
            x = self.emb(x.transpose(0, 1)).transpose(0, 1)
            x = x.reshape(T, B, L, -1)
        elif self.emb_type in ["neuron", "random"]:
            T, B, L, _ = x.shape
            x = self.emb(x).reshape(T, B, L, -1)
        return x
|
|
|
|
class RepeatEncoder(nn.Module):
    """Spike encoder that feeds the same analog input to a LIF node at every step.

    Args:
        output_size: number of time steps the input is replicated over.
    """

    def __init__(self, output_size: int):
        super().__init__()
        self.out_size = output_size
        self.lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
        )

    def forward(self, inputs: torch.Tensor):
        """Repeat ``inputs`` ``out_size`` times along a new leading axis and
        swap its last two axes before the LIF node.

        For a 3-D input ``(B, L, C)`` (presumably batch, sequence, features,
        as fed by SpikeRNN_CPG; TODO confirm), the output is
        ``(out_size, B, C, L)``.
        """
        # (out_size, 1, 1, ...): prepend the time axis, keep every other dim as-is.
        repeats = (self.out_size,) + (1,) * inputs.dim()
        stacked = inputs.repeat(repeats)
        return self.lif(stacked.transpose(-2, -1))
|
|
|
|
class DeltaEncoder(nn.Module):
    """Spike encoder driven by first-order differences along dim 1 of the input.

    The differences are batch-normalized. Each scalar is projected to
    ``output_size`` values, and those become the time axis of a multi-step
    LIF node.

    Args:
        output_size: number of time steps produced by the ``Linear(1, output_size)`` lift.
    """

    def __init__(self, output_size: int):
        super().__init__()
        self.norm = nn.BatchNorm2d(1)
        self.enc = nn.Linear(1, output_size)
        self.lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
        )

    def forward(self, inputs: torch.Tensor):
        """Encode a 3-D ``(B, L, C)`` input into ``(output_size, B, C, L)`` spikes.

        Axis names are assumed (batch, sequence, features). TODO confirm.
        """
        # Step k holds inputs[k] - inputs[k-1]; step 0 has no predecessor and stays 0.
        diffs = torch.zeros_like(inputs)
        diffs[:, 1:] = inputs[:, 1:, :] - inputs[:, :-1, :]
        # (B, L, C) -> (B, 1, C, L): a single channel for BatchNorm2d.
        normed = self.norm(diffs.unsqueeze(1).transpose(2, 3))
        # (B, 1, C, L) -> (B, C, L, 1) -> Linear -> (B, C, L, output_size)
        lifted = self.enc(normed.permute(0, 2, 3, 1))
        # Move the projected axis to the front as the LIF time axis.
        return self.lif(lifted.permute(3, 0, 1, 2))
|
|
|
|
class ConvEncoder(nn.Module):
    """Spike encoder built from a (1 x kernel_size) Conv2d + BatchNorm2d and a LIF node.

    The conv's ``output_size`` output channels become the LIF time axis.
    Padding is ``kernel_size // 2`` along the convolved axis. That keeps the
    length unchanged for odd ``kernel_size``; an even ``kernel_size``
    yields one extra step.

    Args:
        output_size: number of conv output channels, i.e. time steps.
        kernel_size: width of the convolution along the last input axis.
    """

    def __init__(self, output_size: int, kernel_size: int = 3):
        super().__init__()
        conv = nn.Conv2d(
            in_channels=1,
            out_channels=output_size,
            kernel_size=(1, kernel_size),
            stride=1,
            padding=(0, kernel_size // 2),
        )
        self.encoder = nn.Sequential(conv, nn.BatchNorm2d(output_size))
        self.lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
        )

    def forward(self, inputs: torch.Tensor):
        """Encode a 3-D ``(B, L, C)`` input into ``(output_size, B, C, L)`` spikes.

        Axis names are assumed (batch, sequence, features). TODO confirm.
        """
        # (B, L, C) -> (B, C, L) -> (B, 1, C, L): one input channel, conv runs along L.
        image = inputs.transpose(1, 2).unsqueeze(1)
        feats = self.encoder(image)
        # (B, output_size, C, L) -> (output_size, B, C, L) for the multi-step LIF.
        return self.lif(feats.transpose(0, 1))
|
|
|
|
|
|
# Registry: backend name -> encoder name -> encoder class.
# Both backend keys map to the same spikingjelly-based encoder classes;
# each backend gets its own (separate) inner dict.
SpikeEncoder = {
    snn_backend: {
        "repeat": RepeatEncoder,
        "conv": ConvEncoder,
        "delta": DeltaEncoder,
    }
    for snn_backend in ("snntorch", "spikingjelly")
}
|
|
|
|
|
|
class SpikeRNNCell(nn.Module):
    """Linear projection followed by a multi-step LIF neuron layer.

    There is no explicit recurrent weight. The only state carried along the
    leading (time) axis is whatever the LIF node keeps internally.

    Args:
        input_size: feature width of the input (last axis).
        output_size: feature width produced by the linear projection.
    """

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self.input_size = input_size
        self.linear = nn.Linear(input_size, output_size)
        self.lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
        )

    def forward(self, x):
        """Map ``(T, B, L, input_size)`` to ``(T, B, L, output_size)`` spikes."""
        steps, batch, length, _ = x.shape
        projected = self.linear(x.flatten(0, 1)).reshape(steps, batch, length, -1)
        return self.lif(projected)
|
|
|
|
class SpikeRNN_CPG(nn.Module):
    """Spiking forecaster: spike encoder -> position embedding -> Linear + LIF
    -> stack of SpikeRNNCell -> time-averaged linear heads.

    Most sizes are read from ``args``: ``hidden_size``, ``T``,
    ``feature_size``, ``pre_length``, ``blocks``, ``seq_length`` and
    ``normalize``. Of the constructor arguments:

      * ``hidden_size`` and ``layers`` set the width (``self.dim``) and
        depth of ``self.net``. ``self.hidden_size`` / ``self.layers`` store
        ``args.hidden_size`` / ``args.blocks`` instead.
        NOTE(review): these may disagree; confirm which is intended.
      * ``input_size`` overrides the encoder input width. When it is None,
        ``args.feature_size`` is used.
      * ``num_steps`` and ``weight_file`` are not used (``args.T`` sets
        the step count).

    ``forward`` returns ``(preds, aux)``, where ``preds`` is
    ``(B, pre_length, feature_size)``. ``aux`` is
    ``{'gate_l0': tensor(0.0)}``.
    """

    def __init__(
        self,
        args,
        hidden_size: int,
        layers: int = 1,
        num_steps: int = 4,
        input_size: Optional[int] = None,
        max_length: Optional[int] = 5000,
        weight_file: Optional[Path] = None,
        encoder_type: Optional[str] = "conv",
        num_pe_neuron: int = 40,
        pe_type: str = "neuron",
        pe_mode: str = "concat",
        neuron_pe_scale: float = 10000.0,
    ):
        super().__init__()
        self._snn_backend = "spikingjelly"
        self.hidden_size = args.hidden_size
        self.num_steps = args.T
        self.input_size = args.feature_size
        self.pre_length = args.pre_length
        self.layers = args.blocks
        self.pe_type = pe_type
        self.pe_mode = pe_mode
        self.num_pe_neuron = num_pe_neuron
        self.neuron_pe_scale = neuron_pe_scale
        self.temporal_encoder = SpikeEncoder[self._snn_backend][encoder_type](self.num_steps)
        self.args = args

        self.pe = PositionEmbedding(
            pe_type=pe_type,
            pe_mode=pe_mode,
            neuron_pe_scale=neuron_pe_scale,
            input_size=self.input_size,
            max_len=max_length,
            num_pe_neuron=self.num_pe_neuron,
            dropout=0.1,
            num_steps=self.num_steps,
        )

        # NeuronPE and RandomPE both append num_pe_neuron channels in "concat"
        # mode, so both need the wider encoder input (and matching self.dim).
        pe_adds_channels = pe_type in ("neuron", "random") and pe_mode == "concat"
        extra_channels = num_pe_neuron if pe_adds_channels else 0
        self.dim = hidden_size + extra_channels

        # The temporal encoders keep the feature count, so fall back to
        # args.feature_size; the default input_size=None used to raise TypeError.
        in_features = self.input_size if input_size is None else input_size
        self.encoder = nn.Linear(in_features + extra_channels, self.dim)

        self.init_lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
            v_threshold=1.0,
            backend=backend,
        )

        self.net = nn.Sequential(
            *[SpikeRNNCell(input_size=self.dim, output_size=self.dim) for _ in range(layers)]
        )

        self.__output_size = self.dim
        self.fc1 = nn.Linear(self.__output_size, args.feature_size)
        # Projects along the sequence axis, so inputs must have args.seq_length steps.
        self.fc2 = nn.Linear(args.seq_length, self.pre_length)
        # Only move to the GPU when one exists: an unconditional
        # .to('cuda:0') raises on CPU-only machines.
        if torch.cuda.is_available():
            self.to("cuda:0")

    def forward(
        self,
        inputs: torch.Tensor,
    ):
        """Forecast from ``inputs``.

        ``inputs`` is presumably ``(B, seq_length, feature_size)``.
        Normalization statistics are taken over dim 1 (TODO confirm axis
        names against callers).
        """
        # Clear membrane state left over from the previous batch.
        functional.reset_net(self)
        if self.args.normalize:
            # Per-sample standardization over dim 1; undone on the outputs below.
            mean = inputs.mean(dim=1, keepdim=True).detach()
            inputs = inputs - mean
            std = torch.sqrt(torch.var(inputs, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
            inputs = inputs / std

        hiddens = self.temporal_encoder(inputs)
        # (T, B, C, L) -> (T, B, L, C)
        hiddens = hiddens.transpose(-2, -1)
        T, B, L, _ = hiddens.size()
        if self.pe_type != "none":
            hiddens = self.pe(hiddens)
        hiddens = self.encoder(hiddens.flatten(0, 1)).reshape(T, B, L, -1)
        hiddens = self.init_lif(hiddens)
        hiddens = self.net(hiddens)
        # Average spikes over the time-step axis: (B, L, dim).
        out = hiddens.mean(0)
        preds = self.fc1(out)
        # fc2 maps the sequence axis L -> pre_length.
        preds = self.fc2(preds.permute(0, 2, 1))
        preds = preds.permute(0, 2, 1).contiguous()

        if self.args.normalize:
            preds = preds * std + mean

        aux = {'gate_l0': torch.tensor(0.0, device=preds.device)}

        return preds, aux
|
|
|
|
|
|