# SpikF-GO / model / Spikformer_CPG.py
# Source: Hugging Face repository upload by bakhshaliyev
# ("Upload folder using huggingface_hub", commit d88ab6d, verified; 14.1 kB).
from typing import Optional
from pathlib import Path
import torch
from torch import nn
from spikingjelly.activation_based import surrogate, neuron, functional
import math
from dataclasses import dataclass
import warnings
# Shared spiking-neuron settings used by the LIF nodes in this module.
tau: float = 2.0  # membrane time constant; leak factor beta = 1 - 1/tau
backend: str = "torch"  # value passed as LIFNode(backend=...)
detach_reset: bool = True  # value passed as LIFNode(detach_reset=...)
@dataclass(eq=False)
class CPG(nn.Module):
    """Fixed binary "central pattern generator" (CPG) spike pattern.

    Builds an ``(l_max, num_neurons)`` float matrix of 0/1 values. Even
    columns are 1 where ``sin(pos * freq) - 0.8 >= 0``. Odd columns are 1
    where ``cos(pos * freq) - 0.8 >= 0``. The frequencies are laid out
    geometrically, ``freq_i = w_max ** (-2i / num_neurons)``, as in
    sinusoidal positional encodings.

    ``eq=False`` keeps ``nn.Module``'s identity-based ``__eq__``/``__hash__``.
    The dataclass default (``eq=True``) sets ``__hash__ = None``, and
    ``named_modules()``/``parameters()`` need hashable modules.
    """

    num_neurons: int = 40  # number of pattern neurons (columns)
    w_max: float = 10000.0  # base setting the slowest frequency
    l_max: int = 5000  # number of positions (rows) precomputed

    def __post_init__(self):
        # The dataclass-generated __init__ never calls nn.Module.__init__, so
        # without this the Module bookkeeping (_parameters, _modules, ...) is
        # missing and any Module API call raises AttributeError.
        super().__init__()
        self._cpg = torch.zeros(self.l_max, self.num_neurons)
        position = torch.arange(0, self.l_max, dtype=torch.float).unsqueeze(1)  # (l_max, 1)
        scale = -math.log(self.w_max) / self.num_neurons
        # ceil(num_neurons / 2) frequencies for the even (sin) columns.
        div_term = torch.exp(torch.arange(0, self.num_neurons, 2).float() * scale)
        # floor(num_neurons / 2) frequencies for the odd (cos) columns, so an
        # odd num_neurons still matches the number of 1::2 columns.
        div_term_single = torch.exp(
            torch.arange(0, self.num_neurons - 1, 2).float() * scale
        )
        fire_at_zero = torch.tensor([1.0])  # heaviside value where input == 0
        self._cpg[:, 0::2] = torch.heaviside(
            torch.sin(position * div_term) - 0.8, fire_at_zero
        )
        self._cpg[:, 1::2] = torch.heaviside(
            torch.cos(position * div_term_single) - 0.8, fire_at_zero
        )

    @property
    def cpg(self):
        """The precomputed ``(l_max, num_neurons)`` 0/1 pattern tensor."""
        return self._cpg
class CPGLinear(nn.Module):
    """Linear layer whose output is shifted by a learned projection of a CPG pattern.

    ``forward(x)`` returns ``inp_linear(dropout(x)) + cpg_linear(cpg[:n])``
    with ``n = x.size(-2)``. Row ``i`` of the CPG matrix is projected and
    added at position ``i`` of the second-to-last axis of ``x``.

    Args:
        input_size: size of the last axis of ``x``.
        output_size: size of the last axis of the result.
        cpg: pattern provider exposing ``.cpg`` (an ``(rows, num_neurons)``
            tensor) and ``.num_neurons``. Defaults to a fresh ``CPG()``.
        dropout: dropout probability applied to ``x`` (not to the CPG term).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        cpg: Optional["CPG"] = None,
        dropout: float = 0.1,
    ):
        super().__init__()
        if cpg is None:
            # Built per instance. A ``CPG()`` default argument would be created
            # once at import time, and every default-constructed layer would
            # wrap the same tensor storage.
            cpg = CPG()
        # Frozen (requires_grad=False) but kept as a Parameter so state_dict
        # keys and the parameters() listing stay as before.
        self.cpg = nn.Parameter(cpg.cpg, requires_grad=False)
        self.inp_linear = nn.Linear(input_size, output_size)
        self.cpg_linear = nn.Linear(cpg.num_neurons, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        # x: (..., N, input_size); callers in this file pass (B, T*L, D).
        n = x.size(-2)
        rows = self.cpg.size(0)
        if n > rows:
            raise ValueError(
                f"CPGLinear got {n} positions but its CPG pattern only has "
                f"{rows} rows; increase CPG.l_max."
            )
        cpg = self.cpg[:n]
        x = self.dropout(x)
        return self.inp_linear(x) + self.cpg_linear(cpg)
class RepeatEncoder(nn.Module):
    """Feed the same input to a multi-step LIF neuron at every time step.

    Args:
        output_size: number of time steps ``T`` the input is repeated for.
    """

    def __init__(self, output_size: int):
        super().__init__()
        self.out_size = output_size
        self.lif = neuron.LIFNode(
            tau=tau,
            detach_reset=detach_reset,
            step_mode="m",
            surrogate_function=surrogate.ATan(),
        )

    def forward(self, inputs: torch.Tensor):
        # inputs: (B, L, C) -> spikes (T, B, C, L)
        # Passing one more repeat factor than inputs has dims prepends a
        # time axis of length T; every existing axis is repeated once.
        ones = [1] * inputs.dim()
        tiled = inputs.repeat(self.out_size, *ones)  # (T, B, L, C)
        return self.lif(tiled.permute(0, 1, 3, 2))  # (T, B, C, L)
class DeltaEncoder(nn.Module):
    """Encode step-to-step changes of the input into spikes.

    Steps:
    - Take the first-order difference along the length axis; it is zero at
      the first position.
    - Batch-normalise it with a single-channel BatchNorm2d.
    - Lift it to ``output_size`` values with ``Linear(1, output_size)``.
    - Use those values as the time steps of a multi-step LIF neuron.

    Args:
        output_size: number of time steps ``T`` produced.
    """

    def __init__(self, output_size: int):
        super().__init__()
        self.norm = nn.BatchNorm2d(1)
        self.enc = nn.Linear(1, output_size)
        self.lif = neuron.LIFNode(
            tau=tau,
            detach_reset=detach_reset,
            step_mode="m",
            surrogate_function=surrogate.ATan(),
        )

    def forward(self, inputs: torch.Tensor):
        # inputs: (B, L, C) -> spikes (T, B, C, L)
        delta = torch.zeros_like(inputs)
        # Difference between consecutive positions; position 0 stays zero.
        delta[:, 1:] = inputs[:, 1:, :] - inputs[:, :-1, :]
        normed = self.norm(delta.unsqueeze(1).transpose(2, 3))  # (B, 1, C, L)
        lifted = self.enc(normed.permute(0, 2, 3, 1))  # (B, C, L, T)
        return self.lif(lifted.movedim(-1, 0))  # (T, B, C, L)
class ConvEncoder(nn.Module):
    """Convolutional spike encoder.

    The ``(B, L, C)`` input is viewed as ``(B, 1, C, L)``. A ``(1, kernel_size)``
    convolution along the length axis turns it into ``output_size`` feature
    maps. The maps are batch-normalised and used as the time steps of a
    multi-step LIF neuron.

    Args:
        output_size: number of time steps ``T`` produced.
        kernel_size: width of the convolution along the length axis.
    """

    def __init__(self, output_size: int, kernel_size: int = 3):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=output_size,
                kernel_size=(1, kernel_size),
                stride=1,
                # "same" keeps the output length equal to L for every kernel
                # size. The previous symmetric (0, kernel_size // 2) padding
                # equals this for odd kernels but produced L + 1 outputs for
                # even ones.
                padding="same",
            ),
            nn.BatchNorm2d(output_size),
        )
        self.lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
        )

    def forward(self, inputs: torch.Tensor):
        # inputs: (B, L, C)
        inputs = inputs.permute(0, 2, 1).unsqueeze(1)  # (B, 1, C, L)
        enc = self.encoder(inputs)  # (B, T, C, L)
        enc = enc.permute(1, 0, 2, 3)  # (T, B, C, L): time axis first for step_mode="m"
        spks = self.lif(enc)  # (T, B, C, L)
        return spks
# Registry of spike encoders: SpikeEncoder[snn_library][encoder_name] -> class.
# Each library key gets its own inner dict. NOTE(review): both keys map to the
# same spikingjelly-based classes in this file; there is no separate snntorch
# implementation here.
SpikeEncoder = {
    snn_lib: {"repeat": RepeatEncoder, "conv": ConvEncoder, "delta": DeltaEncoder}
    for snn_lib in ("snntorch", "spikingjelly")
}
class SSA(nn.Module):
    """Spiking self-attention over ``(T, B, L, D)`` spike tensors.

    How it works:
    - Q, K and V each go through Linear -> BatchNorm1d (over D) -> LIF.
    - Scores are ``(q @ k^T) * qk_scale`` per head, with no softmax, and are
      applied to ``v``.
    - The merged heads pass through a LIF whose threshold is half of
      ``common_thr``.
    - The result goes through an output Linear -> BatchNorm1d -> LIF.

    Args:
        length: accepted but unused.
        tau: membrane time constant for every LIF here.
        common_thr: LIF firing threshold (the attention LIF uses half of it).
        dim: model width ``D``; must be divisible by ``heads``.
        heads: number of attention heads.
        qkv_bias: accepted but NOT applied; the q/k/v Linears always have a
            bias.
        qk_scale: scalar multiplier on the attention scores.
    """

    def __init__(
        self, length, tau, common_thr, dim, heads=8, qkv_bias=False, qk_scale=0.25
    ):
        super().__init__()
        assert dim % heads == 0, f"dim {dim} should be divided by num_heads {heads}."
        self.dim = dim
        self.heads = heads
        self.qk_scale = qk_scale

        def make_lif(threshold):
            # Every LIF shares tau / reset / surrogate / backend; only the
            # threshold varies.
            return neuron.LIFNode(
                tau=tau,
                step_mode="m",
                detach_reset=detach_reset,
                surrogate_function=surrogate.ATan(),
                v_threshold=threshold,
                backend=backend,
            )

        # The construction order is kept as before: it fixes the state_dict
        # layout and the RNG draws made by each nn.Linear's initialisation.
        self.q_m = nn.Linear(dim, dim)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = make_lif(common_thr)
        self.k_m = nn.Linear(dim, dim)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = make_lif(common_thr)
        self.v_m = nn.Linear(dim, dim)
        self.v_bn = nn.BatchNorm1d(dim)
        self.v_lif = make_lif(common_thr)
        self.attn_lif = make_lif(common_thr / 2)
        self.last_m = nn.Linear(dim, dim)
        self.last_bn = nn.BatchNorm1d(dim)
        self.last_lif = make_lif(common_thr)

    def _spiking_heads(self, linear, bn, lif, flat, T, B, L, D):
        """Run Linear -> BN -> LIF on ``(T*B, L, D)`` and split into ``(T, B, heads, L, D // heads)``."""
        proj = linear(flat)
        proj = (
            bn(proj.transpose(-1, -2))
            .transpose(-1, -2)
            .reshape(T, B, L, D)
            .contiguous()
        )
        spikes = lif(proj)
        per_head = spikes.reshape(T, B, L, self.heads, D // self.heads)
        return per_head.permute(0, 1, 3, 2, 4).contiguous()

    def forward(self, x):
        # x: (T, B, L, D) -> (T, B, L, D)
        T, B, L, D = x.shape
        flat = x.flatten(0, 1)  # (T*B, L, D)
        q = self._spiking_heads(self.q_m, self.q_bn, self.q_lif, flat, T, B, L, D)
        k = self._spiking_heads(self.k_m, self.k_bn, self.k_lif, flat, T, B, L, D)
        v = self._spiking_heads(self.v_m, self.v_bn, self.v_lif, flat, T, B, L, D)
        scores = (q @ k.transpose(-2, -1)) * self.qk_scale  # (T, B, heads, L, L)
        mixed = (scores @ v).transpose(2, 3).reshape(T, B, L, D).contiguous()
        mixed = self.attn_lif(mixed)
        out = self.last_m(mixed.flatten(0, 1))  # (T*B, L, D)
        out = self.last_bn(out.transpose(-1, -2)).transpose(-1, -2)
        return self.last_lif(out.reshape(T, B, L, D).contiguous())
class MLP(nn.Module):
    """Spiking feed-forward block: CPGLinear -> BN -> LIF -> CPGLinear -> BN -> LIF.

    The time and length axes are flattened together before each CPGLinear,
    so the CPG pattern is indexed over the combined ``T*L`` positions.

    Args:
        length: accepted but unused.
        tau: membrane time constant for both LIF neurons.
        common_thr: LIF firing threshold.
        in_features: input width ``D``.
        hidden_features: hidden width ``H``. ``None`` means ``in_features``
            (previously ``None`` crashed inside ``nn.Linear``).
        out_features: output width; falsy means ``in_features``.
    """

    def __init__(
        self,
        length,
        tau,
        common_thr,
        in_features,
        hidden_features=None,
        out_features=None,
    ):
        super().__init__()
        out_features = out_features or in_features
        if hidden_features is None:
            hidden_features = in_features
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.fc1 = CPGLinear(in_features, hidden_features)
        self.bn1 = nn.BatchNorm1d(hidden_features)
        self.lif1 = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
            v_threshold=common_thr,
            backend=backend,
        )
        self.fc2 = CPGLinear(hidden_features, out_features)
        self.bn2 = nn.BatchNorm1d(out_features)
        self.lif2 = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
            v_threshold=common_thr,
            backend=backend,
        )

    def forward(self, x):
        # x: (T, B, L, D) -> (T, B, L, out_features)
        T, B, L, _ = x.shape
        x = x.transpose(0, 1).flatten(1, 2)  # (B, T*L, D)
        x = self.fc1(x)  # (B, T*L, H)
        x = (
            self.bn1(x.transpose(-1, -2))
            .transpose(-1, -2)
            .reshape(B, T, L, self.hidden_features)
            .contiguous()
        )  # (B, T, L, H)
        # step_mode="m" LIFs step over the leading axis, so T is moved to the
        # front for the LIF and moved back afterwards.
        x = self.lif1(x.transpose(0, 1)).transpose(0, 1)  # (B, T, L, H)
        x = x.flatten(1, 2)  # (B, T*L, H)
        x = self.fc2(x)  # (B, T*L, out_features)
        x = (
            self.bn2(x.transpose(-1, -2))
            .transpose(-1, -2)
            # Reshape to the fc2 output width; using the input width D broke
            # any configuration with out_features != in_features.
            .reshape(B, T, L, self.out_features)
            .contiguous()
        )  # (B, T, L, out_features)
        return self.lif2(x.transpose(0, 1))  # (T, B, L, out_features)
class Block(nn.Module):
    """One encoder layer: residual spiking self-attention, then a residual spiking MLP.

    Args:
        length, tau, common_thr: forwarded to both ``SSA`` and ``MLP``.
        dim: model width ``D``.
        d_ff: hidden width of the MLP.
        heads, qkv_bias, qk_scale: forwarded to ``SSA``.
    """

    def __init__(
        self,
        length,
        tau,
        common_thr,
        dim,
        d_ff,
        heads=8,
        qkv_bias=False,
        qk_scale=0.125,
    ):
        super().__init__()
        shared = dict(length=length, tau=tau, common_thr=common_thr)
        # Attention is built before the MLP, as before (same RNG draw order).
        self.attn = SSA(
            dim=dim, heads=heads, qkv_bias=qkv_bias, qk_scale=qk_scale, **shared
        )
        self.mlp = MLP(in_features=dim, hidden_features=d_ff, **shared)

    def forward(self, x):
        # x: (T, B, L, D). Both residual sums are out-of-place, so the caller's
        # tensor is not modified.
        attended = x + self.attn(x)
        return attended + self.mlp(attended)
class Spikformer_CPG(nn.Module):
    """Spikformer forecaster with CPG-augmented linear layers.

    Pipeline:
    1. Input ``(B, L, C)`` is optionally normalised per sample and channel
       over ``L``.
    2. The conv spike encoder turns it into ``(T, B, C, L)``.
    3. A ``CPGLinear`` projects it to width ``dim`` over the flattened
       ``T*L`` positions.
    4. The result passes through ``depths`` spiking ``Block`` layers.
    5. It is averaged over ``T`` and mapped by one Linear to
       ``(B, pre_length, feature_size)``.

    Args:
        args: config object. This class reads ``T``, ``blocks``,
            ``feature_size``, ``pre_length``, ``seq_length`` and
            ``normalize``.
        dim: model width ``D``.
        d_ff: MLP hidden width; ``None`` keeps the previous fixed 1024.
        num_pe_neuron: number of CPG neurons in the input encoder.
        neuron_pe_scale: CPG frequency base (``CPG.w_max``) of the input
            encoder.
        depths: number of ``Block`` layers built.
        common_thr: LIF firing threshold.
        max_length: CPG rows (``CPG.l_max``) of the input encoder. It must be
            at least ``num_steps * seq_length``.
        num_steps: number of spike time steps produced by the encoder.
        heads, qkv_bias, qk_scale: forwarded to every ``Block``.

    NOTE(review): The following are kept as-is so existing configurations
    build the same architecture; confirm the intent.
    - ``pe_type``, ``pe_mode``, ``input_size`` and ``weight_file`` are
      accepted but do not affect the model.
    - ``args.T`` and ``args.blocks`` are stored as ``self.T`` and
      ``self.depths``, but the model is built from ``num_steps`` and
      ``depths``.
    """

    def __init__(
        self,
        args,
        dim: int = 256,
        d_ff: Optional[int] = None,
        num_pe_neuron: int = 40,
        pe_type: str = "neuron",
        pe_mode: str = "concat",  # "add" or concat
        neuron_pe_scale: float = 10000.0,  # "100" or "1000" or "10000"
        depths: int = 2,
        common_thr: float = 1.0,
        max_length: int = 5000,
        num_steps: int = 4,
        heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: float = 0.125,
        input_size: Optional[int] = None,
        weight_file: Optional[Path] = None,
    ):
        super().__init__()
        # Previously hard-coded to 256 / 1024 regardless of the arguments.
        self.dim = dim
        self.d_ff = 1024 if d_ff is None else d_ff
        self.T = args.T
        self.depths = args.blocks
        self.pe_type = pe_type
        self.pe_mode = pe_mode
        self.num_pe_neuron = num_pe_neuron
        self.input_size = args.feature_size
        self.pre_length = args.pre_length
        self.args = args
        self._snn_backend = "spikingjelly"
        self.temporal_encoder = SpikeEncoder[self._snn_backend]["conv"](num_steps)
        # neuron_pe_scale / max_length were previously ignored. Their defaults
        # equal CPG's own defaults, so default construction is unchanged.
        self.encoder = CPGLinear(
            self.input_size,
            dim,
            CPG(num_neurons=num_pe_neuron, w_max=neuron_pe_scale, l_max=max_length),
        )
        self.init_lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
            v_threshold=common_thr,
            backend=backend,
        )
        self.blocks = nn.ModuleList(
            [
                Block(
                    length=max_length,
                    tau=tau,
                    common_thr=common_thr,
                    dim=dim,
                    d_ff=self.d_ff,
                    heads=heads,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                )
                for _ in range(depths)
            ]
        )
        # NOTE(review): apply() runs before self.fc exists, so fc keeps
        # PyTorch's default nn.Linear init rather than _init_weights. This is
        # preserved to avoid changing trained behaviour; confirm it is intended.
        self.apply(self._init_weights)
        self.fc = nn.Linear(args.seq_length * dim, args.pre_length * args.feature_size)

    def _init_weights(self, m):
        """Re-initialise Linear layers (N(0, 0.02) weights, zero bias) and affine LayerNorms."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.LayerNorm) and m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)

    def forward(self, x: torch.Tensor):
        """Forecast from an input window ``x`` of shape ``(B, L, C)``.

        Returns:
            ``(out, aux)``. ``out`` has shape ``(B, pre_length, feature_size)``.
            ``aux`` holds a zero ``'gate_l0'`` placeholder tensor on ``out``'s
            device.
        """
        # Reset the spikingjelly neuron state left over from any previous call.
        functional.reset_net(self)
        if self.args.normalize:
            # Per-sample, per-channel normalisation over L. The detached
            # statistics are reused below to undo the scaling on the forecast.
            mean = x.mean(dim=1, keepdim=True).detach()  # (B, 1, C)
            x = x - mean
            std = torch.sqrt(
                torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5
            ).detach()
            x = x / std
        x = self.temporal_encoder(x)  # (B, L, C) -> (T, B, C, L)
        T, B, _, L = x.shape
        x = x.permute(1, 0, 3, 2)  # (B, T, L, C)
        x = x.flatten(1, 2)  # (B, T*L, C)
        x = self.encoder(x)  # (B, T*L, D)
        x = x.reshape(B, T, L, -1).permute(1, 0, 2, 3)  # (T, B, L, D)
        x = self.init_lif(x)
        for blk in self.blocks:
            x = blk(x)  # (T, B, L, D)
        out = x.mean(0)  # average over time steps: (B, L, D)
        out = self.fc(out.flatten(-2, -1))  # (B, L*D) -> (B, pre_length * C)
        out = out.reshape(-1, self.pre_length, self.input_size)  # (B, pre_length, C)
        if self.args.normalize:
            out = out * std + mean  # denormalization
        aux = {'gate_l0': torch.tensor(0.0, device=out.device)}  # placeholder
        return out, aux