Upload folder using huggingface_hub
Browse files- .gitattributes +3 -0
- CITATION.cff +12 -0
- LICENSE +21 -0
- README.md +191 -0
- assets/architecture.pdf +3 -0
- assets/spikf-go-architecture.png +3 -0
- assets/supplementary.pdf +3 -0
- data/.gitkeep +0 -0
- data/data_loader.py +243 -0
- model/FourierGNN.py +168 -0
- model/SpikF.py +151 -0
- model/SpikF_GO.py +445 -0
- model/SpikF_GO_CPG.py +514 -0
- model/SpikeGRU.py +241 -0
- model/SpikeRNN_CPG.py +489 -0
- model/SpikeTCN_CPG.py +596 -0
- model/Spikformer_CPG.py +487 -0
- model/TS_Former.py +1365 -0
- model/TS_GRU.py +640 -0
- model/TS_TCN.py +1030 -0
- model/iSpikformer.py +129 -0
- requirements.txt +6 -0
- scripts/ecl.sh +232 -0
- train.py +545 -0
- utils/utils.py +252 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/architecture.pdf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/spikf-go-architecture.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/supplementary.pdf filter=lfs diff=lfs merge=lfs -text
|
CITATION.cff
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cff-version: 1.2.0
|
| 2 |
+
title: "SpikF-GO: Spiking Fourier Graph Operators for Multivariate Time Series Forecasting"
|
| 3 |
+
message: "If you use this code, please cite our ECML PKDD 2026 paper."
|
| 4 |
+
authors:
|
| 5 |
+
- family-names: Bakhshaliyev
|
| 6 |
+
given-names: Jafar
|
| 7 |
+
- family-names: Landwehr
|
| 8 |
+
given-names: Niels
|
| 9 |
+
year: 2026
|
| 10 |
+
conference: "ECML PKDD 2026"
|
| 11 |
+
repository-code: "https://github.com/jafarbakhshaliyev/SpikF-GO"
|
| 12 |
+
license: MIT
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 Jafar Bakhshaliyev
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,3 +1,194 @@
|
|
| 1 |
---
|
| 2 |
license: mit
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- time-series
|
| 5 |
+
- forecasting
|
| 6 |
+
- spiking-neural-networks
|
| 7 |
+
- graph-neural-networks
|
| 8 |
+
- multivariate-time-series
|
| 9 |
---
|
| 10 |
+
|
| 11 |
+
# SpikF-GO: Spiking Fourier Graph Operators for Multivariate Time Series Forecasting
|
| 12 |
+
|
| 13 |
+
[arXiv:2606.13901](https://arxiv.org/abs/2606.13901)
|
| 14 |
+
[Paper](https://arxiv.org/abs/2606.13901)
|
| 15 |
+
[License: MIT](LICENSE)
|
| 16 |
+
|
| 17 |
+
📄 **Paper (arXiv):** https://arxiv.org/abs/2606.13901
|
| 18 |
+
💻 **GitHub:** https://github.com/jafarbakhshaliyev/SpikF-GO
|
| 19 |
+
|
| 20 |
+
Official implementation of **SpikF-GO: Spiking Fourier Graph Operators for Multivariate Time Series Forecasting**, accepted to the **ECML PKDD 2026 Research Track**.
|
| 21 |
+
|
| 22 |
+

|
| 23 |
+
|
| 24 |
+
---
|
| 25 |
+
|
| 26 |
+
## Abstract
|
| 27 |
+
|
| 28 |
+
SpikF-GO is a spiking neural architecture for multivariate time series forecasting. It combines the hypervariate graph formulation of FourierGNN with spike-driven Fourier-domain graph processing, enabling joint modeling of intra-series temporal dependencies, inter-series dependencies, and time-varying cross-variable interactions. The model introduces sparse frequency selection and Complex LIF-based spectral gating to preserve event-driven computation in the Fourier domain. We also provide **SpikF-GO w/ CPG**, which incorporates Central Pattern Generator-based positional signals for improved long-range temporal modeling.
|
| 29 |
+
|
| 30 |
+
---
|
| 31 |
+
|
| 32 |
+
## Key Contributions
|
| 33 |
+
|
| 34 |
+
- **Graph-based SNN forecasting:** SpikF-GO brings hypervariate graph modeling into SNN-based multivariate time series forecasting.
|
| 35 |
+
- **Spike-driven Fourier graph operators:** The model combines sparse frequency gating with Complex LIF-based spectral processing to preserve event-driven computation in the Fourier domain.
|
| 36 |
+
- **Unified SNN benchmark:** We evaluate SpikF-GO against major SNN forecasting families under a common experimental protocol across eight benchmark datasets.
|
| 37 |
+
- **Energy-aware forecasting:** SpikF-GO achieves competitive-to-superior forecasting performance while reducing theoretical energy consumption relative to FourierGNN.
|
| 38 |
+
|
| 39 |
+
---
|
| 40 |
+
|
| 41 |
+
## Related Library: SpikingTSF
|
| 42 |
+
|
| 43 |
+
We also maintain **[SpikingTSF](https://github.com/spikora/SpikingTSF)**, a broader open-source benchmark library for spiking neural network-based time series forecasting. SpikingTSF unifies SNN forecasting architectures and ANN baselines under a common training and evaluation protocol across datasets, horizons, metrics, and random seeds.
|
| 44 |
+
|
| 45 |
+
> **Note:** SpikingTSF is a benchmarking library and may not reproduce all experiments from this repository directly.
|
| 46 |
+
|
| 47 |
+
---
|
| 48 |
+
|
| 49 |
+
## Repository Structure
|
| 50 |
+
|
| 51 |
+
```
|
| 52 |
+
SpikF-GO/
|
| 53 |
+
├── README.md
|
| 54 |
+
├── LICENSE
|
| 55 |
+
├── CITATION.cff
|
| 56 |
+
├── requirements.txt
|
| 57 |
+
├── train.py # main training & evaluation entry point
|
| 58 |
+
├── model/ # SpikF-GO + all baseline implementations
|
| 59 |
+
├── utils/ # shared utilities (metrics, helpers)
|
| 60 |
+
├── data/
|
| 61 |
+
│ └── data_loader.py # dataset loading (raw files placed here at runtime)
|
| 62 |
+
├── scripts/
|
| 63 |
+
│ ├── ecg.sh
|
| 64 |
+
│ ├── covid.sh
|
| 65 |
+
│ ├── solar.sh
|
| 66 |
+
│ ├── ecl.sh
|
| 67 |
+
│ ├── traffic.sh
|
| 68 |
+
│ ├── metr_la.sh
|
| 69 |
+
│ ├── pems_bay.sh
|
| 70 |
+
│ └── wiki.sh
|
| 71 |
+
└── assets/
|
| 72 |
+
├── spikf-go-architecture.png
|
| 73 |
+
└── supplementary.pdf
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
---
|
| 77 |
+
|
| 78 |
+
## Environment Setup
|
| 79 |
+
|
| 80 |
+
Create and activate a virtual environment:
|
| 81 |
+
|
| 82 |
+
**Linux / macOS**
|
| 83 |
+
```bash
|
| 84 |
+
python3 -m venv venv
|
| 85 |
+
source venv/bin/activate
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
**Windows**
|
| 89 |
+
```bat
|
| 90 |
+
python -m venv venv
|
| 91 |
+
venv\Scripts\activate
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
Install dependencies:
|
| 95 |
+
```bash
|
| 96 |
+
pip install -r requirements.txt
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
Experiments were run with **PyTorch 2.5.1** on a single **NVIDIA RTX 4090**.
|
| 100 |
+
|
| 101 |
+
---
|
| 102 |
+
|
| 103 |
+
## Dataset
|
| 104 |
+
|
| 105 |
+
Download the processed datasets from Figshare:
|
| 106 |
+
|
| 107 |
+
https://figshare.com/s/7617530bce306584fe95?file=62576929
|
| 108 |
+
|
| 109 |
+
Place all dataset files **directly** inside the `data/` folder (do **not** create subfolders):
|
| 110 |
+
|
| 111 |
+
```
|
| 112 |
+
SpikF-GO/
|
| 113 |
+
├── data/
|
| 114 |
+
│ ├── dataset_file_1
|
| 115 |
+
│ ├── dataset_file_2
|
| 116 |
+
│ └── ...
|
| 117 |
+
├── model/
|
| 118 |
+
├── scripts/
|
| 119 |
+
└── train.py
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
---
|
| 123 |
+
|
| 124 |
+
## Run Experiments
|
| 125 |
+
|
| 126 |
+
Scripts are in `scripts/`, one per dataset:
|
| 127 |
+
|
| 128 |
+
```bash
|
| 129 |
+
bash scripts/ecg.sh
|
| 130 |
+
bash scripts/covid.sh
|
| 131 |
+
bash scripts/solar.sh
|
| 132 |
+
bash scripts/ecl.sh
|
| 133 |
+
bash scripts/traffic.sh
|
| 134 |
+
bash scripts/metr_la.sh
|
| 135 |
+
bash scripts/pems_bay.sh
|
| 136 |
+
bash scripts/wiki.sh
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
Each script sets the exact hyperparameters used to produce the results reported in the paper.
|
| 140 |
+
|
| 141 |
+
---
|
| 142 |
+
|
| 143 |
+
## Supplementary Material
|
| 144 |
+
|
| 145 |
+
Available at [`assets/supplementary.pdf`](assets/supplementary.pdf).
|
| 146 |
+
|
| 147 |
+
---
|
| 148 |
+
|
| 149 |
+
## Citation
|
| 150 |
+
|
| 151 |
+
If you use this code or build on SpikF-GO, please cite our paper:
|
| 152 |
+
|
| 153 |
+
**arXiv preprint:**
|
| 154 |
+
```bibtex
|
| 155 |
+
@misc{bakhshaliyev2026spikfgo_arxiv,
|
| 156 |
+
title = {SpikF-GO: Spiking Fourier Graph Operators for Multivariate Time Series Forecasting},
|
| 157 |
+
author = {Bakhshaliyev, Jafar and Landwehr, Niels},
|
| 158 |
+
year = {2026},
|
| 159 |
+
eprint = {2606.13901},
|
| 160 |
+
archivePrefix= {arXiv},
|
| 161 |
+
primaryClass = {cs.LG},
|
| 162 |
+
url = {https://arxiv.org/abs/2606.13901}
|
| 163 |
+
}
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
**ECML PKDD 2026 proceedings:**
|
| 167 |
+
```bibtex
|
| 168 |
+
@inproceedings{bakhshaliyev2026spikfgo,
|
| 169 |
+
title = {SpikF-GO: Spiking Fourier Graph Operators for Multivariate Time Series Forecasting},
|
| 170 |
+
author = {Bakhshaliyev, Jafar and Landwehr, Niels},
|
| 171 |
+
booktitle = {ECML PKDD},
|
| 172 |
+
year = {2026}
|
| 173 |
+
}
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
See [`CITATION.cff`](CITATION.cff) for full citation metadata.
|
| 177 |
+
|
| 178 |
+
---
|
| 179 |
+
|
| 180 |
+
## Acknowledgements
|
| 181 |
+
|
| 182 |
+
The baselines in `model/` build on prior work. We thank the authors for releasing their code; original licenses are respected.
|
| 183 |
+
|
| 184 |
+
- **`SpikF.py`** — adapted from **SpikF** (Wu, Huo & Chen, *"SpikF: Spiking Fourier Network for Efficient Long-term Prediction"*, [ICML 2025 / PMLR v267](https://proceedings.mlr.press/v267/wu25m.html)).
|
| 185 |
+
- **`TS_Former.py`, `TS_GRU.py`, `TS_TCN.py`** — adapted from **TS-LIF** (Feng et al., *"TS-LIF: A Temporal Segment Spiking Neuron Network for Time Series Forecasting"*, [arXiv:2503.05108](https://arxiv.org/abs/2503.05108)).
|
| 186 |
+
- **`iSpikformer.py`, `SpikeGRU.py`** — adapted from **SeqSNN** (Lv et al., *"Efficient and Effective Time-Series Forecasting with Spiking Neural Networks"*, [arXiv:2402.01533](https://arxiv.org/abs/2402.01533)), [microsoft/SeqSNN](https://github.com/microsoft/SeqSNN).
|
| 187 |
+
- **`SpikeRNN_CPG.py`, `SpikeTCN_CPG.py`, `Spikformer_CPG.py`** — CPG variants build on [arXiv:2405.14362](https://arxiv.org/abs/2405.14362) / [microsoft/SeqSNN](https://github.com/microsoft/SeqSNN).
|
| 188 |
+
- **`FourierGNN.py`** — adapted from **FourierGNN**, [arXiv:2311.06190](https://arxiv.org/abs/2311.06190) / [aikunyi/FourierGNN](https://github.com/aikunyi/FourierGNN).
|
| 189 |
+
|
| 190 |
+
---
|
| 191 |
+
|
| 192 |
+
## License
|
| 193 |
+
|
| 194 |
+
This project is released under the [MIT License](LICENSE).
|
assets/architecture.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5b4851b633d30ef57fa79e8c1096ad660b31e8a9dad9376e7840b39ff3adc0a4
|
| 3 |
+
size 194717
|
assets/spikf-go-architecture.png
ADDED
|
Git LFS Details
|
assets/supplementary.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:353d376afbe7e500c8af0a7cf21f927a95eeac9feb781a324991c5b2b063dc4e
|
| 3 |
+
size 219947
|
data/.gitkeep
ADDED
|
File without changes
|
data/data_loader.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import datetime
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
|
| 6 |
+
from torch.utils.data import Dataset
|
| 7 |
+
from sklearn.preprocessing import StandardScaler
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _split_with_overlap(data: np.ndarray, train_ratio: float, val_ratio: float, seq_len: int):
    """
    Chronological train/val/test split in which val and test reach back
    ``seq_len`` rows so that their first windows have past context:
        train: [0 : train_end)
        val  : [train_end - seq_len : val_end)
        test : [val_end - seq_len : T)

    ``train_end`` and ``val_end`` are ``int(T * ratio)`` clamped into
    ``[0, T]`` with ``val_end >= train_end``; the start offsets are floored
    at 0.

    Returns:
        Tuple ``(train_data, val_data, test_data)`` of slices of ``data``.
    """
    total = len(data)

    # Clamp the raw cut points so the slices below are always well-formed.
    train_stop = min(max(int(total * train_ratio), 0), total)
    val_stop = max(train_stop, min(int(total * (train_ratio + val_ratio)), total))

    def _with_context(stop):
        # Step back seq_len rows, but never before the start of the series.
        return max(0, stop - seq_len)

    return (
        data[:train_stop],
        data[_with_context(train_stop):val_stop],
        data[_with_context(val_stop):],
    )
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _fit_transform_splits(train_data, val_data, test_data, type_flag: str, scaler=None):
    """
    Optionally standardize the three splits.

    When ``type_flag`` is the string ``"1"`` the splits are passed through
    ``scaler.transform``. If no ``scaler`` is supplied, a ``StandardScaler``
    is created and fitted on ``train_data`` only; a supplied scaler is used
    as given (it is not re-fitted here). For any other ``type_flag`` the
    splits are returned untouched.

    Returns:
        ``(train_data, val_data, test_data, scaler)`` where ``scaler`` is the
        scaler that was applied, or ``None`` when no scaling happened.
    """
    # Guard clause: scaling is opt-in via the exact string "1".
    if type_flag != "1":
        return train_data, val_data, test_data, None

    if scaler is None:
        scaler = StandardScaler()
        scaler.fit(train_data)

    scaled_train, scaled_val, scaled_test = (
        scaler.transform(part) for part in (train_data, val_data, test_data)
    )
    return scaled_train, scaled_val, scaled_test, scaler
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _to_float32(x: np.ndarray) -> np.ndarray:
    """Return ``x`` converted with ``np.asarray`` to a float32 array."""
    as_f32 = np.asarray(x, dtype=np.float32)
    return as_f32
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _clean_numeric_csv(df: pd.DataFrame) -> np.ndarray:
    """
    Reduce a loaded CSV frame to a clean float32 matrix.

    Steps, in order: drop columns whose name starts with "unnamed"
    (case-insensitive), keep only numeric-dtype columns, drop every row that
    contains a NaN, and convert the remaining values to float32.

    Raises:
        ValueError: if no numeric column is left after the column filtering.
    """
    junk_cols = [col for col in df.columns if str(col).lower().startswith("unnamed")]
    kept = df.drop(columns=junk_cols, errors="ignore") if junk_cols else df

    numeric = kept.select_dtypes(include=[np.number])
    if not numeric.shape[1]:
        raise ValueError("No numeric columns found in CSV after cleaning. Check your file format.")

    return numeric.dropna(axis=0, how="any").values.astype(np.float32)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class _BaseTimeSeriesDataset(Dataset):
    """
    Sliding-window dataset over ``self.split``.

    Subclasses are expected to assign ``self.split`` (the data of the chosen
    split) and may assign ``self.scaler``. Item ``i`` is the pair
    ``(split[i : i + seq_len], split[i + seq_len : i + seq_len + pre_len])``.
    """

    def __init__(self, flag, seq_len, pre_len):
        assert flag in ("train", "val", "test")
        self.flag = flag
        self.seq_len = int(seq_len)
        self.pre_len = int(pre_len)
        # Filled in by subclasses once the data has been loaded.
        self.scaler = None
        self.split = None

    def __getitem__(self, index):
        """Return ``(x, y)``: ``seq_len`` input rows followed by ``pre_len`` target rows."""
        history_stop = index + self.seq_len
        target_stop = history_stop + self.pre_len
        return self.split[index:history_stop], self.split[history_stop:target_stop]

    def __len__(self):
        """Return the number of windows exposed, or 0 when no split is loaded."""
        if self.split is None:
            return 0
        # NOTE(review): index len(split) - seq_len - pre_len would still yield a
        # full window, so this count leaves the final window out. Kept as-is
        # because changing it alters the number of samples per split.
        window_count = len(self.split) - self.seq_len - self.pre_len
        return window_count if window_count > 0 else 0
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class Dataset_Dhfm(_BaseTimeSeriesDataset):
    """
    Windowed dataset backed by a single array file read with ``np.load``.

    The loaded array is transposed before splitting, so the axis that was
    second on disk becomes the axis that is split and windowed.

    Args:
        root_path: path handed to ``np.load``.
        flag: one of "train", "val", "test"; selects which split is exposed.
        seq_len: input window length.
        pre_len: prediction window length.
        type: scaling switch forwarded to ``_fit_transform_splits``
            ("1" enables standard scaling).
        train_ratio, val_ratio: fractions used by ``_split_with_overlap``.
        scaler: optional scaler to reuse instead of fitting a new one.
    """

    def __init__(self, root_path, flag, seq_len, pre_len, type, train_ratio, val_ratio, scaler=None):
        super().__init__(flag, seq_len, pre_len)
        self.path = root_path

        series = _to_float32(np.array(np.load(root_path)).transpose())

        parts = _split_with_overlap(series, train_ratio, val_ratio, self.seq_len)
        train_part, val_part, test_part, self.scaler = _fit_transform_splits(*parts, type, scaler)

        # Any flag other than "train"/"val" (i.e. "test") selects the test split.
        self.split = {"train": train_part, "val": val_part}.get(self.flag, test_part)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class Dataset_ECG(_BaseTimeSeriesDataset):
    """
    Windowed dataset backed by a CSV file.

    The CSV is read with ``pd.read_csv`` and reduced to its numeric columns
    by ``_clean_numeric_csv`` before splitting.

    Args:
        root_path: path of the CSV file.
        flag: one of "train", "val", "test"; selects which split is exposed.
        seq_len: input window length.
        pre_len: prediction window length.
        type: scaling switch forwarded to ``_fit_transform_splits``
            ("1" enables standard scaling).
        train_ratio, val_ratio: fractions used by ``_split_with_overlap``.
        scaler: optional scaler to reuse instead of fitting a new one.
    """

    def __init__(self, root_path, flag, seq_len, pre_len, type, train_ratio, val_ratio, scaler=None):
        super().__init__(flag, seq_len, pre_len)
        self.path = root_path

        series = _clean_numeric_csv(pd.read_csv(root_path))

        parts = _split_with_overlap(series, train_ratio, val_ratio, self.seq_len)
        train_part, val_part, test_part, self.scaler = _fit_transform_splits(*parts, type, scaler)

        # Any flag other than "train"/"val" (i.e. "test") selects the test split.
        self.split = {"train": train_part, "val": val_part}.get(self.flag, test_part)
|
| 133 |
+
|
| 134 |
+
class Dataset_Solar(_BaseTimeSeriesDataset):
    """
    Windowed dataset assembled from a directory of ``DA_*`` CSV files.

    Every regular file in ``root_path`` whose name starts with "DA_" is read;
    its first column is treated as a timestamp string and the remaining
    column as one series. Rows are kept only when the timestamp hour is
    between 8 and 17 inclusive.

    Args:
        root_path: directory containing the ``DA_*`` files.
        flag: one of "train", "val", "test"; selects which split is exposed.
        seq_len: input window length.
        pre_len: prediction window length.
        type: scaling switch forwarded to ``_fit_transform_splits``
            ("1" enables standard scaling).
        train_ratio, val_ratio: fractions used by ``_split_with_overlap``.
        scaler: optional scaler to reuse instead of fitting a new one.

    Raises:
        ValueError: if no ``DA_*`` file is found in ``root_path``.
    """

    def __init__(self, root_path, flag, seq_len, pre_len, type, train_ratio, val_ratio, scaler=None):
        super().__init__(flag, seq_len, pre_len)
        self.path = root_path

        per_file_series = []
        timestamps = None

        # NOTE(review): os.listdir returns entries in arbitrary order, so the
        # column order below depends on the filesystem — confirm this is intended.
        for entry in os.listdir(root_path):
            entry_path = os.path.join(root_path, entry)
            if os.path.isdir(entry_path) or not entry.startswith("DA_"):
                continue
            table = pd.read_csv(entry_path).values
            if timestamps is None:
                # Timestamps are taken from the first matching file only.
                timestamps = table[:, 0:1]
            per_file_series.append(table[:, 1:table.shape[1]].transpose())

        if len(per_file_series) == 0 or timestamps is None:
            raise ValueError(f"No solar files found in {root_path} with prefix 'DA_'.")

        # squeeze(1) presumes exactly one value column per file — TODO confirm.
        values = np.array(per_file_series).squeeze(1).transpose()  # (T, N)
        stamped = np.concatenate((np.array(timestamps), values), axis=1)  # (T, 1+N)

        # NOTE(review): the slice stops one column short of the end, so the
        # last series is not included in the data — verify this is deliberate.
        last_col = stamped.shape[1] - 1
        daytime_rows = [
            row[1:last_col]
            for row in stamped
            if 8 <= datetime.datetime.strptime(row[0], "%m/%d/%y %H:%M").hour <= 17
        ]

        series = _to_float32(np.array(daytime_rows))

        parts = _split_with_overlap(series, train_ratio, val_ratio, self.seq_len)
        train_part, val_part, test_part, self.scaler = _fit_transform_splits(*parts, type, scaler)

        # Any flag other than "train"/"val" (i.e. "test") selects the test split.
        self.split = {"train": train_part, "val": val_part}.get(self.flag, test_part)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class Dataset_Wiki(_BaseTimeSeriesDataset):
    """
    Windowed dataset backed by a CSV whose first column is not a feature.

    The first column is discarded by position; the remaining columns are
    reduced to numeric data by ``_clean_numeric_csv`` before splitting.

    Args:
        root_path: path of the CSV file.
        flag: one of "train", "val", "test"; selects which split is exposed.
        seq_len: input window length.
        pre_len: prediction window length.
        type: scaling switch forwarded to ``_fit_transform_splits``
            ("1" enables standard scaling).
        train_ratio, val_ratio: fractions used by ``_split_with_overlap``.
        scaler: optional scaler to reuse instead of fitting a new one.

    Raises:
        ValueError: if the CSV has fewer than two columns.
    """

    def __init__(self, root_path, flag, seq_len, pre_len, type, train_ratio, val_ratio, scaler=None):
        super().__init__(flag, seq_len, pre_len)
        self.path = root_path

        frame = pd.read_csv(root_path)
        if frame.shape[1] < 2:
            raise ValueError("Wiki CSV must have at least 2 columns (time + features).")

        series = _clean_numeric_csv(frame.iloc[:, 1:])

        parts = _split_with_overlap(series, train_ratio, val_ratio, self.seq_len)
        train_part, val_part, test_part, self.scaler = _fit_transform_splits(*parts, type, scaler)

        # Any flag other than "train"/"val" (i.e. "test") selects the test split.
        self.split = {"train": train_part, "val": val_part}.get(self.flag, test_part)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class Dataset_PEMS_BAY(_BaseTimeSeriesDataset):
    """
    Windowed dataset backed by an HDF file read with ``pd.read_hdf``.

    Args:
        root_path: path of the HDF file.
        flag: one of "train", "val", "test"; selects which split is exposed.
        seq_len: input window length.
        pre_len: prediction window length.
        type: scaling switch forwarded to ``_fit_transform_splits``
            ("1" enables standard scaling).
        train_ratio, val_ratio: fractions used by ``_split_with_overlap``.
        scaler: optional scaler to reuse instead of fitting a new one.
        fillna: missing-value policy. "ffill" forward-fills and then replaces
            what is still missing with 0.0; "zero" replaces missing values
            with 0.0; "drop" removes rows containing any missing value;
            ``None`` leaves the data untouched.

    Raises:
        ValueError: if ``fillna`` is none of the values listed above.
    """

    def __init__(self, root_path, flag, seq_len, pre_len, type, train_ratio, val_ratio, scaler=None, fillna="ffill"):
        super().__init__(flag, seq_len, pre_len)
        self.path = root_path

        loaded = pd.read_hdf(root_path)

        # Normalize whatever read_hdf returned into a DataFrame.
        if isinstance(loaded, pd.Series):
            frame = loaded.to_frame()
        elif isinstance(loaded, pd.DataFrame):
            frame = loaded
        else:
            frame = pd.DataFrame(loaded)

        if fillna == "ffill":
            # Leading gaps have nothing to forward-fill from, hence the 0.0 pass.
            frame = frame.ffill().fillna(0.0)
        elif fillna == "zero":
            frame = frame.fillna(0.0)
        elif fillna == "drop":
            frame = frame.dropna(axis=0, how="any")
        elif fillna is not None:
            raise ValueError("fillna must be one of: 'ffill', 'zero', 'drop', or None")

        series = frame.values.astype(np.float32)

        parts = _split_with_overlap(series, train_ratio, val_ratio, self.seq_len)
        train_part, val_part, test_part, self.scaler = _fit_transform_splits(*parts, type, scaler)

        # Any flag other than "train"/"val" (i.e. "test") selects the test split.
        self.split = {"train": train_part, "val": val_part}.get(self.flag, test_part)
|
model/FourierGNN.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class FGN(nn.Module):
    """
    FourierGNN forecaster.

    ``forward`` expects a 3-D tensor ``(B, L, N)`` (it permutes to
    ``(B, N, L)``), flattens the last two axes into one sequence of length
    ``N * L``, embeds every scalar into ``embed_size`` channels, applies three
    complex-valued layers in the Fourier domain, and projects back to
    ``(B, pre_length, N)``. ``L`` must equal ``seq_length`` because the
    projection matrix ``embeddings_10`` has ``seq_length`` rows.

    Args:
        args: object exposing a boolean ``normalize`` attribute; when true the
            input is standardized per series over the time axis and the
            output is de-standardized with the same statistics.
        pre_length: number of steps produced by the output head.
        embed_size: number of embedding channels per scalar.
        feature_size: stored on the instance; not otherwise used in this class.
        seq_length: input window length.
        hidden_size: width of the second layer of the output MLP.
        hard_thresholding_fraction: stored on the instance; not otherwise
            used in this class.
        hidden_size_factor: width multiplier of the spectral weights. The
            ``'bli,ii->bli'`` contractions need square weights, so only 1
            works with the current ``fourierGC``.
        sparsity_threshold: ``lambd`` passed to ``F.softshrink``.
    """

    def __init__(self, args, pre_length, embed_size,
                 feature_size, seq_length, hidden_size, hard_thresholding_fraction=1, hidden_size_factor=1, sparsity_threshold=0.01):
        super().__init__()
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.number_frequency = 1
        self.pre_length = pre_length
        self.feature_size = feature_size
        self.seq_length = seq_length
        self.frequency_size = self.embed_size // self.number_frequency
        self.hidden_size_factor = hidden_size_factor
        self.sparsity_threshold = sparsity_threshold
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.scale = 0.02

        # Parameters are created in the original order so that a fixed RNG
        # seed yields the same initial weights as before.
        self.embeddings = nn.Parameter(torch.randn(1, self.embed_size))
        self.args = args

        wide = self.frequency_size * self.hidden_size_factor
        # Index 0 of the leading axis holds the real part, index 1 the imaginary part.
        self.w1 = nn.Parameter(self.scale * torch.randn(2, self.frequency_size, wide))
        self.b1 = nn.Parameter(self.scale * torch.randn(2, wide))
        self.w2 = nn.Parameter(self.scale * torch.randn(2, wide, self.frequency_size))
        self.b2 = nn.Parameter(self.scale * torch.randn(2, self.frequency_size))
        self.w3 = nn.Parameter(self.scale * torch.randn(2, self.frequency_size, wide))
        self.b3 = nn.Parameter(self.scale * torch.randn(2, wide))
        self.embeddings_10 = nn.Parameter(torch.randn(self.seq_length, 8))
        self.fc = nn.Sequential(
            nn.Linear(self.embed_size * 8, 64),
            nn.LeakyReLU(),
            nn.Linear(64, self.hidden_size),
            nn.LeakyReLU(),
            nn.Linear(self.hidden_size, self.pre_length)
        )

        # The module still places itself on cuda:0 when CUDA exists. Without
        # CUDA the unconditional move used to raise, so it is skipped and the
        # model stays on the CPU.
        if torch.cuda.is_available():
            self.to('cuda:0')

    def tokenEmb(self, x):
        """Scale each scalar of ``x`` (B, N*L) by the embedding row, giving (B, N*L, embed_size)."""
        x = x.unsqueeze(2)
        y = self.embeddings
        return x * y

    def _complex_linear_relu(self, real, imag, weight, bias):
        """Apply one complex layer; return ReLU of the real and imaginary parts separately."""
        out_real = F.relu(
            torch.einsum('bli,ii->bli', real, weight[0])
            - torch.einsum('bli,ii->bli', imag, weight[1])
            + bias[0]
        )
        out_imag = F.relu(
            torch.einsum('bli,ii->bli', imag, weight[0])
            + torch.einsum('bli,ii->bli', real, weight[1])
            + bias[1]
        )
        return out_real, out_imag

    # FourierGNN
    def fourierGC(self, x, B, N, L):
        """
        Run the three stacked Fourier-domain layers on the complex tensor ``x``.

        Each layer consumes the un-shrunk output of the previous one; the
        soft-shrunk outputs are accumulated as residuals. ``B``, ``N`` and
        ``L`` are accepted for interface compatibility and are not needed.

        Returns:
            Complex tensor with the same shape as ``x``.
        """
        # 1 layer
        o1_real, o1_imag = self._complex_linear_relu(x.real, x.imag, self.w1, self.b1)
        y = torch.stack([o1_real, o1_imag], dim=-1)
        y = F.softshrink(y, lambd=self.sparsity_threshold)

        # 2 layer
        o2_real, o2_imag = self._complex_linear_relu(o1_real, o1_imag, self.w2, self.b2)
        x = torch.stack([o2_real, o2_imag], dim=-1)
        x = F.softshrink(x, lambd=self.sparsity_threshold)
        x = x + y

        # 3 layer
        o3_real, o3_imag = self._complex_linear_relu(o2_real, o2_imag, self.w3, self.b3)
        z = torch.stack([o3_real, o3_imag], dim=-1)
        z = F.softshrink(z, lambd=self.sparsity_threshold)
        z = z + x
        z = torch.view_as_complex(z)
        return z

    def forward(self, x):
        """
        Forecast from ``x`` of shape (B, L, N).

        Returns:
            Tuple ``(prediction, aux)`` where ``prediction`` has shape
            (B, pre_length, N) and ``aux`` is ``{'gate_l0': tensor(0.0)}``,
            a placeholder entry.
        """
        if self.args.normalize:
            # Statistics are detached so no gradient flows through them.
            mean = x.mean(dim=1, keepdim=True).detach()
            x = x - mean

            std = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
            x = x / std

        x = x.permute(0, 2, 1).contiguous()
        B, N, L = x.shape
        # B*N*L ==> B*NL
        x = x.reshape(B, -1)
        # embedding B*NL ==> B*NL*D
        x = self.tokenEmb(x)

        # FFT B*NL*D ==> B*((N*L)//2+1)*D
        x = torch.fft.rfft(x, dim=1, norm='ortho')

        x = x.reshape(B, (N*L)//2+1, self.frequency_size)

        bias = x

        # FourierGNN
        x = self.fourierGC(x, B, N, L)

        # Residual connection around the spectral layers.
        x = x + bias

        x = x.reshape(B, (N*L)//2+1, self.embed_size)

        # ifft
        x = torch.fft.irfft(x, n=N*L, dim=1, norm="ortho")

        x = x.reshape(B, N, L, self.embed_size)
        x = x.permute(0, 1, 3, 2)  # B, N, D, L

        # projection
        x = torch.matmul(x, self.embeddings_10)
        x = x.reshape(B, N, -1)
        x = self.fc(x)
        x = x.permute(0, 2, 1)

        if self.args.normalize:
            x = x * std
            x = x + mean

        aux = {
            'gate_l0': torch.tensor(0.0, device=x.device)  # placeholder
        }

        return x, aux
|
| 168 |
+
|
model/SpikF.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from spikingjelly.clock_driven.neuron import MultiStepLIFNode
|
| 4 |
+
|
| 5 |
+
class SPE(nn.Module):
    """Patch the input series, project each patch and encode it as spikes.

    ``forward`` maps [B, L, D] to a tensor shaped
    [T, B, patch_dim, patch_num, D] produced by the encoder LIF node.
    """

    def __init__(self, input_len, patch_num, patch_dim, T, tau, D):
        super().__init__()
        patch_len = input_len // patch_num
        self.patch_projector = nn.Linear(patch_len, patch_dim)
        self.bn = nn.BatchNorm2d(patch_dim)
        self.encoder_lif = MultiStepLIFNode(tau=tau, detach_reset=False, backend='torch')

        self.D = D
        self.T = T
        self.patch_dim = patch_dim
        self.patch_num = patch_num

    def forward(self, x):
        batch, seq_len, n_vars = x.shape
        patch_len = seq_len // self.patch_num

        # [B, L, D] -> [B, pn, L/pn, D] -> [B, pn, D, L/pn]
        patches = x.view(batch, self.patch_num, patch_len, n_vars).contiguous()
        patches = patches.transpose(-1, -2).contiguous()

        # Project every patch, then replicate the result for each of the T steps.
        emb = self.patch_projector(patches)            # [B, pn, D, pd]
        emb = emb.repeat(self.T, 1, 1, 1, 1)           # [T, B, pn, D, pd]
        emb = emb.permute(0, 1, 4, 2, 3).contiguous()  # [T, B, pd, pn, D]

        # BatchNorm2d expects a 4-D input, so fold T into the batch axis.
        emb = self.bn(emb.flatten(0, 1))
        emb = emb.view(self.T, batch, self.patch_dim, self.patch_num, n_vars)

        return self.encoder_lif(emb)
|
| 31 |
+
|
| 32 |
+
class SFS(nn.Module):
    """Spiking frequency-selection block followed by two residual conv stages.

    Input and output are shaped [T, B, patch_dim, patch_num, D].

    Args:
        patch_num: number of patches (length of the axis that is Fourier
            transformed).
        D: accepted for interface compatibility; not used by this block.
        patch_dim: channel count for the convolutions and batch norms.
        tau: membrane time constant passed to every LIF node.
        alpha: accepted for interface compatibility; not used by this block.
    """

    def __init__(self, patch_num, D, patch_dim, tau, alpha):
        super().__init__()
        self.time2freq = nn.Linear(patch_num, patch_num // 2 + 1)

        self.intra_conv = nn.Conv2d(in_channels=patch_dim, out_channels=patch_dim, kernel_size=[5, 1], stride=[1, 1], padding=[2, 0])
        self.inter_conv = nn.Conv2d(in_channels=patch_dim, out_channels=patch_dim, kernel_size=[3, 1], stride=[1, 1], padding=[1, 0])

        self.generator_lif = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch', v_threshold=0.1)
        self.mp_lif = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')
        self.sfs_lif = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')
        self.intra_lif = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')
        self.inter_lif = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')

        self.bn1 = nn.BatchNorm2d(patch_dim)
        self.bn2 = nn.BatchNorm2d(patch_dim)
        self.bn3 = nn.BatchNorm2d(patch_dim)
        self.bn4 = nn.BatchNorm2d(patch_dim)

    def forward(self, x):
        res_x = x
        T, B, pd, pn, D = x.shape

        # Put the patch axis last so the FFT and time2freq act on it.
        x = x.transpose(-1, -2).contiguous()           # [T, B, pd, D, pn]
        freq_spec = torch.fft.rfft(x)                   # [T, B, pd, D, pn//2+1]

        # Binary frequency selector: spikes are summed over T, re-thresholded,
        # and the single resulting mask is shared by all T steps.
        selector = self.time2freq(x)
        selector = self.bn1(selector.flatten(0, 1))
        selector = selector.view(T, B, pd, D, -1)
        selector = self.generator_lif(selector)
        selector = selector.sum(dim=0, keepdim=True)
        selector = self.mp_lif(selector)
        selector = selector.repeat(T, 1, 1, 1, 1).float()
        selector = torch.complex(selector, torch.zeros_like(selector))

        remain_freq = selector * freq_spec

        # n=pn is required: without it irfft returns 2*(pn//2) samples, which
        # is pn-1 for odd pn and breaks the view back to [T, B, pd, pn, D].
        current = torch.fft.irfft(remain_freq, n=pn)
        current = current.transpose(-1, -2).contiguous()
        current = self.bn2(current.flatten(0, 1))
        current = current.view(T, B, pd, pn, D)

        x = self.sfs_lif(current) + res_x
        res_x = x

        # Convolution along the patch axis (kernel [5, 1]).
        x = self.intra_conv(x.flatten(0, 1))
        x = self.bn3(x)
        x = x.view(T, B, pd, pn, D)
        x = self.intra_lif(x) + res_x
        res_x = x

        # Swap T and the patch axis so the [3, 1] kernel runs along the steps.
        x = x.transpose(0, 3).contiguous()              # [pn, B, pd, T, D]
        x = self.inter_conv(x.flatten(0, 1))
        x = self.bn4(x)
        x = x.view(pn, B, pd, T, D)
        x = x.transpose(0, 3)                           # [T, B, pd, pn, D]
        x = self.inter_lif(x)
        x = x + res_x

        return x
|
| 98 |
+
|
| 99 |
+
class SpikF(nn.Module):
    """SpikF forecaster: spike patch encoder, a stack of SFS blocks, dense head.

    ``forward`` takes [B, L, D] and returns ``(prediction, aux)`` where the
    prediction is shaped [T, B, pred_len, D] and ``aux`` holds a zero
    ``'gate_l0'`` placeholder.
    """

    def __init__(self, args, input_len, patch_num, patch_dim, T, blocks, D, pred_len, tau, alpha, hidden_dim):
        super().__init__()
        self.SPE = SPE(input_len, patch_num, patch_dim, T, tau, D)
        self.args = args

        self.SFSs = nn.ModuleList(
            SFS(patch_num, D, patch_dim, tau, alpha) for _ in range(blocks)
        )

        self.dense1 = nn.Linear(patch_num * patch_dim, hidden_dim)
        self.dense2 = nn.Linear(hidden_dim, pred_len)

        self.bn = nn.BatchNorm1d(D)

        self.activ = nn.GELU()

    def forward(self, x):
        use_norm = self.args.normalize
        if use_norm:
            # Instance normalisation over the time axis with detached statistics.
            mean = x.mean(dim=1, keepdim=True).detach()
            x = x - mean
            std = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
            x = x / std

        spikes = self.SPE(x)
        T, B, pd, pn, D = spikes.shape

        for block in self.SFSs:
            spikes = block(spikes)

        # [T, B, pd, pn, D] -> [T, B, D, pd*pn] -> dense head per variable.
        feats = spikes.permute(0, 1, 4, 2, 3).contiguous().flatten(-2, -1)
        feats = self.dense1(feats).flatten(0, 1)        # [T*B, D, hidden]
        feats = self.activ(self.bn(feats))
        out = self.dense2(feats)                        # [T*B, D, pred_len]
        out = out.transpose(-1, -2).contiguous().view(T, B, -1, D)

        if use_norm:
            out = out * std
            out = out + mean.repeat(T, 1, 1, 1)

        aux = {'gate_l0': torch.tensor(0.0, device=out.device)}  # placeholder
        return out, aux
|
model/SpikF_GO.py
ADDED
|
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Dict, Tuple
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch.nn.utils import weight_norm
|
| 8 |
+
|
| 9 |
+
from spikingjelly.clock_driven.neuron import MultiStepLIFNode
|
| 10 |
+
from spikingjelly.activation_based import surrogate
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Affine(nn.Module):
    """Learnable per-channel scale and shift over the last axis of size ``D``."""

    def __init__(self, D: int):
        super().__init__()
        # Identity at initialisation: unit scale, zero shift.
        self.gamma = nn.Parameter(torch.ones(D))
        self.beta = nn.Parameter(torch.zeros(D))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scaled = x * self.gamma
        return scaled + self.beta
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class RMSNorm(nn.Module):
    """
    RMS normalisation for tokens shaped [B, M, E].

    The root-mean-square is taken over the M axis, separately for every
    sample and channel, and is followed by a learnable affine map.
    """
    def __init__(self, E: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.affine = Affine(E)

    def forward(self, tok: torch.Tensor) -> torch.Tensor:
        mean_sq = tok.pow(2).mean(dim=1, keepdim=True)   # [B,1,E]
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return self.affine(tok * inv_rms)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class SFFT(nn.Module):
    """
    S-FFT: real FFT / inverse FFT along the M axis of [T, B, M, E] tensors,
    computed with ``torch.fft``. For the theory behind the spiking FFT,
    refer to our paper and to the SpikF paper.
    """
    def __init__(self, M: int):
        super().__init__()
        self.M = M
        self.F = M // 2 + 1   # number of one-sided frequency bins

    def rfft(self, s_t: torch.Tensor) -> torch.Tensor:
        T, B, M, E = s_t.shape
        # Move the transformed axis last and fold the rest into rows.
        rows = s_t.permute(0, 1, 3, 2).contiguous().view(T * B * E, M)       # [T*B*E, M]
        spec = torch.fft.rfft(rows, n=self.M, dim=-1, norm="ortho")          # [T*B*E, F] complex
        spec = spec.view(T, B, E, self.F)
        return spec.permute(0, 1, 3, 2).contiguous()                         # [T,B,F,E]

    def irfft(self, Z_t: torch.Tensor) -> torch.Tensor:
        T, B, Freq, E = Z_t.shape
        rows = Z_t.permute(0, 1, 3, 2).contiguous().view(T * B * E, Freq)    # [T*B*E, F]
        sig = torch.fft.irfft(rows, n=self.M, dim=-1, norm="ortho")          # [T*B*E, M]
        sig = sig.view(T, B, E, self.M)
        return sig.permute(0, 1, 3, 2).contiguous()                          # [T,B,M,E]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class HardConcreteGate(nn.Module):
    """
    Learnable gate over frequency bins.

    Z: [T,B,F,E]
    mask m: [1,1,F,1] with values in [0,1]

    In training mode the mask is sampled with logistic noise and temperature
    ``tau``; in eval mode it is the deterministic sigmoid of ``log_alpha``.
    Both are stretched by ``s * 1.2 - 0.1`` and clamped to [0, 1].
    """
    def __init__(self, F_bins: int, init_logit: float = 2.0, eps: float = 1e-6):
        super().__init__()
        self.log_alpha = nn.Parameter(torch.full((F_bins,), float(init_logit)))
        self.eps = eps

    def _sample_u(self, shape, device):
        # Uniform noise kept away from 0 and 1 so both logs stay finite.
        noise = torch.empty(shape, device=device)
        return noise.uniform_(self.eps, 1.0 - self.eps)

    def _hard_concrete(self, training: bool, device, tau: float):
        logits = self.log_alpha
        if training:
            u = self._sample_u(logits.shape, device)
            logits = (torch.log(u) - torch.log(1 - u) + logits) / tau
        gate = torch.sigmoid(logits)
        stretched = gate * 1.2 - 0.1
        return stretched.clamp(0.0, 1.0)

    def forward(self, Z: torch.Tensor, tau: float) -> Tuple[torch.Tensor, torch.Tensor]:
        mask = self._hard_concrete(self.training, Z.device, tau=tau)   # [F]
        mask = mask.view(1, 1, -1, 1).to(Z.real.dtype)                 # [1,1,F,1]
        return Z * mask, mask

    def l0(self) -> torch.Tensor:
        """Mean of ``sigmoid(log_alpha)`` over all bins."""
        return torch.sigmoid(self.log_alpha).mean()
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class ComplexAffine(nn.Module):
    """Independent per-channel scale and shift for the real and imaginary parts."""

    def __init__(self, E: int):
        super().__init__()
        # Identity at initialisation for both components.
        self.gamma_r = nn.Parameter(torch.ones(E))
        self.beta_r = nn.Parameter(torch.zeros(E))
        self.gamma_i = nn.Parameter(torch.ones(E))
        self.beta_i = nn.Parameter(torch.zeros(E))

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        real_part = z.real * self.gamma_r + self.beta_r
        imag_part = z.imag * self.gamma_i + self.beta_i
        return torch.complex(real_part, imag_part)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class ComplexLinear(nn.Module):
    """Complex linear map ``x @ (Wr + i*Wi) + (br + i*bi)`` held as real tensors."""

    def __init__(self, E_in: int, E_out: int, init_scale: float = 0.02):
        super().__init__()
        # Real part is drawn before the imaginary part; biases start at zero.
        self.Wr = nn.Parameter(torch.randn(E_in, E_out) * init_scale)
        self.Wi = nn.Parameter(torch.randn(E_in, E_out) * init_scale)
        self.br = nn.Parameter(torch.zeros(E_out))
        self.bi = nn.Parameter(torch.zeros(E_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        re, im = x.real, x.imag
        # (re + i*im)(Wr + i*Wi) = (re Wr - im Wi) + i (im Wr + re Wi)
        out_re = re @ self.Wr - im @ self.Wi + self.br
        out_im = im @ self.Wr + re @ self.Wi + self.bi
        return torch.complex(out_re, out_im)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class ComplexLIFGate(nn.Module):
    """Binary gate that opens wherever the real or the imaginary LIF node fires."""

    def __init__(self, tau: float, v_th: float):
        super().__init__()
        # Separate neuron populations for the real and imaginary components.
        self.lif_r = MultiStepLIFNode(
            tau=tau,
            v_threshold=v_th,
            detach_reset=True,
            surrogate_function=surrogate.ATan(alpha=4.0),
            backend="torch",
        )
        self.lif_i = MultiStepLIFNode(
            tau=tau,
            v_threshold=v_th,
            detach_reset=True,
            surrogate_function=surrogate.ATan(alpha=4.0),
            backend="torch",
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        fired_real = self.lif_r(z.real) > 0
        fired_imag = self.lif_i(z.imag) > 0
        # 1.0 where either component spiked, 0.0 elsewhere, in z's real dtype.
        return torch.logical_or(fired_real, fired_imag).to(z.real.dtype)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class SFGO(nn.Module):
    """Spiking Fourier graph operator: three gated complex linear layers.

    ``forward`` takes a complex spectrum and returns ``(out, stats)``. The
    first two layers form a bottleneck (E -> H -> E) added to the input with
    weight ``r2``; the third (E -> E) is added with weight ``r3``. When
    ``args.affine`` is true, each linear layer is preceded by a complex affine
    map with its own spike gate.
    """

    def __init__(
        self,
        args,
        E: int,
        hidden_size_factor: int,
        tau: float = 2.0,
        v_th: float = 1.0,
        apply_gate_to_complex: bool = True,
    ):
        super().__init__()
        H = int(E * hidden_size_factor)

        self.args = args

        self.lin1 = ComplexLinear(E, H)
        self.lin2 = ComplexLinear(H, E)
        self.lin3 = ComplexLinear(E, E)

        self.g1 = ComplexLIFGate(tau=tau, v_th=v_th)
        self.g2 = ComplexLIFGate(tau=tau, v_th=v_th)
        self.g3 = ComplexLIFGate(tau=tau, v_th=v_th)

        self.apply_gate_to_complex = apply_gate_to_complex

        # Learnable residual weights, initialised to 0.1.
        self.r2 = nn.Parameter(torch.tensor(0.1))
        self.r3 = nn.Parameter(torch.tensor(0.1))

        if self.args.affine:
            self.a1 = ComplexAffine(E)
            self.a2 = ComplexAffine(H)
            self.a3 = ComplexAffine(E)

            self.ga1 = ComplexLIFGate(tau=tau, v_th=v_th)
            self.ga2 = ComplexLIFGate(tau=tau, v_th=v_th)
            self.ga3 = ComplexLIFGate(tau=tau, v_th=v_th)

    def _apply_gate(self, z: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        if self.apply_gate_to_complex:
            return z * g.to(z.real.dtype)
        return z

    def _pre_affine(self, z: torch.Tensor, stage: int) -> torch.Tensor:
        # Optional gated affine in front of linear layer `stage` (1, 2 or 3).
        if not self.args.affine:
            return z
        shifted = getattr(self, f"a{stage}")(z)
        gate = getattr(self, f"ga{stage}")(shifted)
        return self._apply_gate(shifted, gate)

    def _gated_linear(self, z: torch.Tensor, linear: nn.Module, gate_mod: nn.Module):
        # Linear layer followed by its spike gate; the gate is returned for stats.
        projected = linear(z)
        gate = gate_mod(projected)
        return self._apply_gate(projected, gate), gate

    def forward(self, Z: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        stats: Dict[str, torch.Tensor] = {}

        # Bottleneck branch: E -> H -> E.
        Y, G1 = self._gated_linear(self._pre_affine(Z, 1), self.lin1, self.g1)
        X, G2 = self._gated_linear(self._pre_affine(Y, 2), self.lin2, self.g2)
        Z2 = Z + self.r2 * X

        # Third layer on top of the first residual sum.
        W, G3 = self._gated_linear(self._pre_affine(Z2, 3), self.lin3, self.g3)
        out = Z2 + self.r3 * W

        with torch.no_grad():
            mag2 = out.real * out.real + out.imag * out.imag
            stats["freq_active_frac"] = (mag2 > 0).float().mean()

        stats["rezero_r2"] = self.r2.detach()
        stats["rezero_r3"] = self.r3.detach()

        stats["gate_lin_frac_1"] = G1.mean().detach()
        stats["gate_lin_frac_2"] = G2.mean().detach()
        stats["gate_lin_frac_3"] = G3.mean().detach()

        return out, stats
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class Decoder(nn.Module):
    """Spiking read-out head mapping [T, B, N, E, L] features to forecasts.

    The time axis is projected to ``proj_dim`` values, passed through a LIF
    node, reduced by a weight-normalised linear layer, averaged over the T
    steps and mapped to ``pred_len`` outputs per variable.

    Args:
        E: embedding size of the incoming features.
        L: length of the last input axis (input to ``time_proj``).
        pred_len: number of forecast steps.
        T: number of spiking steps (stored, not otherwise used here).
        tau: LIF membrane time constant.
        v_th: LIF firing threshold.
        proj_dim: output size of the time projection.
        reduced_dim: hidden size between the two output layers.

    ``forward`` returns ``(preds, stats)`` with ``preds`` shaped
    [B, pred_len, N] and ``stats['dec_spike_rate']`` the mean spike value.
    """

    def __init__(
        self,
        E: int,
        L: int,
        pred_len: int,
        T: int,
        tau: float,
        v_th: float,
        proj_dim: int = 4,
        reduced_dim: int = 64,
    ):
        super().__init__()
        self.E, self.L, self.P, self.T = E, L, pred_len, T
        self.proj_dim = int(proj_dim)

        self.time_proj = nn.Linear(L, self.proj_dim, bias=False)
        D_in = E * self.proj_dim
        self.reduced_dim = int(reduced_dim)

        self.lif = MultiStepLIFNode(
            tau=tau,
            v_threshold=v_th,
            detach_reset=True,
            surrogate_function=surrogate.ATan(alpha=4.0),
            backend="torch",
        )

        fc_reduce = nn.Linear(D_in, self.reduced_dim, bias=True)
        fc_out = nn.Linear(self.reduced_dim, pred_len, bias=True)

        nn.init.xavier_uniform_(self.time_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(fc_reduce.weight, gain=0.6)
        nn.init.xavier_uniform_(fc_out.weight, gain=0.2)
        nn.init.zeros_(fc_reduce.bias)
        nn.init.zeros_(fc_out.bias)

        # Initialise BEFORE wrapping: weight_norm replaces `weight` with a
        # tensor recomputed from weight_g / weight_v on every forward, so an
        # init applied to `.weight` afterwards is silently discarded.
        self.fc_reduce = weight_norm(fc_reduce)
        self.fc_out = weight_norm(fc_out)

    def forward(self, y_t: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        T, B, N, E, L = y_t.shape

        y_p = self.time_proj(y_t)                               # [T,B,N,E,p]
        x = y_p.reshape(T, B * N, E * self.proj_dim)            # [T,B*N,D]
        s = self.lif(x)                                         # [T,B*N,D] spikes
        h_t = self.fc_reduce(s.reshape(T * B * N, -1)).view(T, B * N, self.reduced_dim)

        # Average over the spiking steps before the non-linearity.
        h = h_t.mean(dim=0)                                     # [B*N,reduced_dim]
        h = F.gelu(h)
        out = self.fc_out(h)                                    # [B*N,pred_len]

        preds = out.view(B, N, self.P).permute(0, 2, 1).contiguous()
        stats = {"dec_spike_rate": s.mean().detach()}
        return preds, stats
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class SpikF_GO(nn.Module):
    """SpikF-GO forecaster.

    Pipeline of ``forward`` for an input shaped [B, L, N]: optional instance
    normalisation, scalar-to-vector node embedding with RMS norm, per-step
    modulation, LIF spike encoding, real FFT over the flattened N*L axis,
    hard-concrete frequency gating, the S-FGO operator, inverse FFT and the
    spiking decoder. Returns ``(preds, aux)`` where ``aux`` collects
    detached monitoring statistics.

    ``hard_thresholding_fraction`` and ``sparsity_threshold`` are accepted
    but not used by this class.
    """

    def __init__(
        self,
        args,
        pre_length: int,
        embed_size: int,
        feature_size: int,
        seq_length: int,
        hidden_size: int,
        hard_thresholding_fraction=1,
        hidden_size_factor: int = 1,
        sparsity_threshold: float = 0.01,
    ):
        super().__init__()
        self.args = args

        self.N = feature_size
        self.L = seq_length
        self.E = embed_size
        self.T = args.T
        self.M = self.N * self.L

        self.embeddings = nn.Parameter(torch.randn(1, self.E) * 0.02)
        self.node_aff = Affine(self.E)
        self.node_rms = RMSNorm(E=self.E, eps=1e-6)

        # step modulation
        self.step_gamma = nn.Parameter(torch.ones(self.T))
        self.step_beta = nn.Parameter(torch.zeros(self.T))
        step_ramp = torch.linspace(0, 1, steps=self.T).view(self.T, 1, 1, 1)
        self.register_buffer("step_scale", step_ramp)

        # Encoder LIF
        self.encoder_lif = MultiStepLIFNode(
            tau=args.tau,
            v_threshold=args.alpha,
            detach_reset=True,
            surrogate_function=surrogate.ATan(alpha=4.0),
            backend="torch",
        )

        self.sfft = SFFT(self.M)
        self.F_bins = self.sfft.F

        # frequency gate
        self.freq_gate = HardConcreteGate(self.F_bins, init_logit=2.0)
        self.register_buffer("gate_tau", torch.tensor(0.10))

        self.sfgo = SFGO(
            self.args,
            E=self.E,
            hidden_size_factor=hidden_size_factor,
            tau=args.tau,
            v_th=args.alpha,
            apply_gate_to_complex=True,
        )

        # decoder: hidden width is hidden_size // 4 clipped to [16, 128]
        self.decoder = Decoder(
            E=self.E,
            L=self.L,
            pred_len=pre_length,
            T=self.T,
            tau=args.tau,
            v_th=args.alpha,
            proj_dim=self.args.proj_dim,
            reduced_dim=max(16, min(128, hidden_size // 4)),
        )

    def node_embed(self, x: torch.Tensor) -> torch.Tensor:
        """Embed every scalar of ``x`` ([B,L,N]) as a vector: returns [B,M,E]."""
        B, L, N = x.shape
        flat = x.permute(0, 2, 1).contiguous().reshape(B, self.M)   # [B,M]
        tokens = flat.unsqueeze(-1) * self.embeddings               # [B,M,E]
        return self.node_aff(tokens)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        B, L, N = x.shape
        T = self.T

        # normalize
        use_norm = self.args.normalize
        mean, std = None, None
        x0 = x
        if use_norm:
            mean = x.mean(dim=1, keepdim=True).detach()
            x0 = x - mean
            std = torch.sqrt(torch.var(x0, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
            x0 = x0 / std

        tok = self.node_rms(self.node_embed(x0))                    # [B,M,E]

        # step modulation: learnable per-step scale/shift plus a fixed ramp
        gamma = self.step_gamma.view(T, 1, 1, 1)
        beta = self.step_beta.view(T, 1, 1, 1)
        cur_t = tok.unsqueeze(0) * gamma + beta                     # [T,B,M,E]
        cur_t = cur_t * (1.0 + 0.02 * self.step_scale.to(cur_t.dtype))

        # spikes
        s_t = self.encoder_lif(cur_t)
        enc_rate = s_t.mean()

        # FFT -> prune -> S-FGO -> iFFT
        Z_t = self.sfft.rfft(s_t)
        Z_t, m = self.freq_gate(Z_t, tau=float(self.gate_tau))
        Z_t, fb_stats = self.sfgo(Z_t)
        y_time_t = self.sfft.irfft(Z_t).to(tok.dtype)

        # [T,B,M,E] -> [T,B,N,L,E] -> [T,B,N,E,L]
        y_t = y_time_t.view(T, B, N, self.L, self.E).permute(0, 1, 2, 4, 3).contiguous()

        preds, dec_stats = self.decoder(y_t)

        if use_norm:
            preds = preds * std + mean  # denormalize

        aux = {
            "enc_rate": enc_rate.detach(),
            "rho_hat": self.freq_gate.l0().detach(),
            "freq_mask_mean": m.mean().detach(),
            "freq_mask_active": (m > 0.5).float().mean().detach(),
            **fb_stats,
            **dec_stats,
        }
        return preds, aux
|
model/SpikF_GO_CPG.py
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Tuple, Dict, Optional
|
| 4 |
+
import torch
|
| 5 |
+
import math
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch.nn.utils import weight_norm
|
| 9 |
+
|
| 10 |
+
from spikingjelly.clock_driven.neuron import MultiStepLIFNode
|
| 11 |
+
from spikingjelly.activation_based import surrogate
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Affine(nn.Module):
    """Learnable per-channel scale and shift over the last axis of size ``D``."""

    def __init__(self, D: int):
        super().__init__()
        # Identity at initialisation: unit scale, zero shift.
        self.gamma = nn.Parameter(torch.ones(D))
        self.beta = nn.Parameter(torch.zeros(D))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scaled = x * self.gamma
        return scaled + self.beta
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class RMSNorm(nn.Module):
    """
    RMS normalisation for tokens shaped [B, M, E].

    The root-mean-square is taken over the M axis, separately for every
    sample and channel, and is followed by a learnable affine map.
    """
    def __init__(self, E: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.affine = Affine(E)

    def forward(self, tok: torch.Tensor) -> torch.Tensor:
        mean_sq = tok.pow(2).mean(dim=1, keepdim=True)   # [B,1,E]
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return self.affine(tok * inv_rms)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class CPGSpikePE(nn.Module):
    """
    Spike-form positional encoding (CPG-PE).

    Produces 2*num_pairs binary channels (cosine channels first, then sine
    channels) by thresholding oscillators with geometrically spaced
    frequencies, evaluated at the flattened index t in [0, T*M).

    ``forward`` returns a [T, B, M, 2*num_pairs] tensor of 0/1 values that is
    identical for every batch element. The module has no learnable parameters.
    """
    def __init__(self,
                 num_pairs: int = 20,
                 tau: float = 10000.0,
                 eta: float = 1.0,
                 vthres: float = 0.8,
                 w_max: float = 10000.0):
        super().__init__()
        self.num_pairs = num_pairs
        self.tau = tau
        self.eta = eta
        self.vthres = vthres
        self.w_max = w_max

    def forward(self, T: int, B: int, M: int, device) -> torch.Tensor:
        n_pairs = self.num_pairs
        steps = torch.arange(T * M, device=device, dtype=torch.float32)        # [T*M]
        pair_idx = torch.arange(n_pairs, device=device, dtype=torch.float32)

        # freq_i = w_max ** (-i / num_pairs)
        log_w = torch.log(torch.tensor(self.w_max, device=device))
        freq = torch.exp(-log_w * (pair_idx / max(1, n_pairs)))                 # [N_pe]

        phase = self.eta * (steps[:, None] * freq[None, :] / self.tau)          # [T*M, N_pe]

        # A channel spikes wherever its oscillator exceeds the threshold.
        cos_spk = (torch.cos(phase) - self.vthres > 0).float()
        sin_spk = (torch.sin(phase) - self.vthres > 0).float()

        pe = torch.cat([cos_spk, sin_spk], dim=1).view(T, M, 2 * n_pairs)       # [T, M, 2*N_pe]
        return pe.unsqueeze(1).expand(-1, B, -1, -1).contiguous()               # [T, B, M, 2*N_pe]
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class SFFT(nn.Module):
    """
    S-FFT: real FFT / inverse FFT along the M axis of [T, B, M, E] tensors,
    computed with ``torch.fft``. For the theory behind the spiking FFT,
    refer to our paper and to the SpikF paper.
    """
    def __init__(self, M: int):
        super().__init__()
        self.M = M
        self.F = M // 2 + 1   # number of one-sided frequency bins

    def rfft(self, s_t: torch.Tensor) -> torch.Tensor:
        T, B, M, E = s_t.shape
        # Move the transformed axis last and fold the rest into rows.
        rows = s_t.permute(0, 1, 3, 2).contiguous().view(T * B * E, M)       # [T*B*E, M]
        spec = torch.fft.rfft(rows, n=self.M, dim=-1, norm="ortho")          # [T*B*E, F] complex
        spec = spec.view(T, B, E, self.F)
        return spec.permute(0, 1, 3, 2).contiguous()                         # [T,B,F,E]

    def irfft(self, Z_t: torch.Tensor) -> torch.Tensor:
        T, B, Freq, E = Z_t.shape
        rows = Z_t.permute(0, 1, 3, 2).contiguous().view(T * B * E, Freq)    # [T*B*E, F]
        sig = torch.fft.irfft(rows, n=self.M, dim=-1, norm="ortho")          # [T*B*E, M]
        sig = sig.view(T, B, E, self.M)
        return sig.permute(0, 1, 3, 2).contiguous()                          # [T,B,M,E]
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class HardConcreteGate(nn.Module):
    """Learnable soft mask over frequency bins.

    One logit (``log_alpha``) is kept per bin. In training mode the logit is
    perturbed with logistic noise and divided by the temperature ``tau``;
    in eval mode the plain logit is used. The sigmoid of the result is
    stretched to ``[-0.1, 1.1]`` and clamped to ``[0, 1]``.

    Input ``Z`` is ``[T, B, F, E]``; the mask is broadcast as ``[1, 1, F, 1]``.
    """

    def __init__(self, F_bins: int, init_logit: float = 2.0, eps: float = 1e-6):
        super().__init__()
        self.log_alpha = nn.Parameter(torch.full((F_bins,), float(init_logit)))
        self.eps = eps

    def _sample_u(self, shape, device):
        """Draw uniform noise strictly inside ``(0, 1)`` so the logs stay finite."""
        low, high = self.eps, 1.0 - self.eps
        return torch.empty(shape, device=device).uniform_(low, high)

    def _hard_concrete(self, training: bool, device, tau: float):
        """Return the clamped, stretched gate values, one per bin (``[F]``)."""
        logits = self.log_alpha
        if training:
            noise = self._sample_u(self.log_alpha.shape, device)
            logits = (torch.log(noise) - torch.log(1 - noise) + self.log_alpha) / tau
        stretched = torch.sigmoid(logits) * 1.2 - 0.1
        return stretched.clamp(0.0, 1.0)

    def forward(self, Z: torch.Tensor, tau: float) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale ``Z`` by the gate; return ``(gated Z, mask)``."""
        gate = self._hard_concrete(self.training, Z.device, tau=tau)  # [F]
        mask = gate.view(1, 1, -1, 1).to(Z.real.dtype)  # [1,1,F,1]
        return Z * mask, mask

    def l0(self) -> torch.Tensor:
        """Mean of ``sigmoid(log_alpha)`` across bins."""
        return torch.sigmoid(self.log_alpha).mean()
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class ComplexAffine(nn.Module):
    """Per-channel affine map applied separately to real and imaginary parts.

    Each part has its own scale (initialised to 1) and shift (initialised
    to 0) of length ``E``, broadcast over the last axis of the input.
    """

    def __init__(self, E: int):
        super().__init__()
        self.gamma_r = nn.Parameter(torch.ones(E))
        self.beta_r = nn.Parameter(torch.zeros(E))
        self.gamma_i = nn.Parameter(torch.ones(E))
        self.beta_i = nn.Parameter(torch.zeros(E))

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Return ``complex(gamma_r * Re z + beta_r, gamma_i * Im z + beta_i)``."""
        real_part = z.real * self.gamma_r + self.beta_r
        imag_part = z.imag * self.gamma_i + self.beta_i
        return torch.complex(real_part, imag_part)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class ComplexLinear(nn.Module):
    """Complex-valued linear layer stored as separate real and imaginary weights.

    Computes ``x @ (Wr + i*Wi) + (br + i*bi)`` using four real matrix
    products. Weights are drawn from a normal distribution scaled by
    ``init_scale``; biases start at zero.
    """

    def __init__(self, E_in: int, E_out: int, init_scale: float = 0.02):
        super().__init__()
        self.Wr = nn.Parameter(init_scale * torch.randn(E_in, E_out))
        self.Wi = nn.Parameter(init_scale * torch.randn(E_in, E_out))
        self.br = nn.Parameter(torch.zeros(E_out))
        self.bi = nn.Parameter(torch.zeros(E_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the complex linear map to the last axis of ``x``."""
        re, im = x.real, x.imag
        # (a + ib)(c + id) = (ac - bd) + i(bc + ad)
        out_re = torch.matmul(re, self.Wr) - torch.matmul(im, self.Wi) + self.br
        out_im = torch.matmul(im, self.Wr) + torch.matmul(re, self.Wi) + self.bi
        return torch.complex(out_re, out_im)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class ComplexLIFGate(nn.Module):
    """Binary gate produced by two LIF neurons, one per complex component.

    The real and imaginary parts of the input each drive their own
    ``MultiStepLIFNode``. The gate is 1 wherever either neuron emitted a
    spike and 0 elsewhere, returned in the dtype of ``z.real``. Because the
    gate is produced by a comparison, no gradient flows through it.
    """

    def __init__(self, tau: float, v_th: float):
        super().__init__()

        def build_lif():
            # Both neurons share the same hyper-parameters but keep separate state.
            return MultiStepLIFNode(
                tau=tau,
                v_threshold=v_th,
                detach_reset=True,
                surrogate_function=surrogate.ATan(alpha=4.0),
                backend="torch",
            )

        self.lif_r = build_lif()
        self.lif_i = build_lif()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Return the 0/1 gate with the same shape as ``z``."""
        fired_real = self.lif_r(z.real) > 0
        fired_imag = self.lif_i(z.imag) > 0
        return (fired_real | fired_imag).to(z.real.dtype)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class SFGO(nn.Module):
    """Three gated complex linear stages with two scaled residual connections.

    Data flow for an input ``Z``::

        Z  -> [affine+gate] -> lin1 -> gate -> [affine+gate] -> lin2 -> gate = X
        Z2 = Z  + r2 * X
        Z2 -> [affine+gate] -> lin3 -> gate = W
        out = Z2 + r3 * W

    The bracketed affine stages exist only when ``args.affine`` is truthy.
    ``r2`` and ``r3`` are learnable scalars initialised to 0.1. When
    ``apply_gate_to_complex`` is False the gates are still evaluated (and
    reported in the statistics) but do not modify the activations.
    """

    def __init__(
        self,
        args,
        E: int,
        hidden_size_factor: int,
        tau: float = 2.0,
        v_th: float = 1.0,
        apply_gate_to_complex: bool = True,
    ):
        super().__init__()
        self.args = args
        hidden = int(E * hidden_size_factor)

        def new_gate():
            return ComplexLIFGate(tau=tau, v_th=v_th)

        self.lin1 = ComplexLinear(E, hidden)
        self.lin2 = ComplexLinear(hidden, E)
        self.lin3 = ComplexLinear(E, E)

        self.g1 = new_gate()
        self.g2 = new_gate()
        self.g3 = new_gate()

        self.apply_gate_to_complex = apply_gate_to_complex

        # Learnable residual scales.
        self.r2 = nn.Parameter(torch.tensor(0.1))
        self.r3 = nn.Parameter(torch.tensor(0.1))

        if self.args.affine:
            self.a1 = ComplexAffine(E)
            self.a2 = ComplexAffine(hidden)
            self.a3 = ComplexAffine(E)

            self.ga1 = new_gate()
            self.ga2 = new_gate()
            self.ga3 = new_gate()

    def _apply_gate(self, z: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        """Multiply ``z`` by gate ``g`` unless gating of activations is disabled."""
        if self.apply_gate_to_complex:
            return z * g.to(z.real.dtype)
        return z

    def _affine_stage(self, z: torch.Tensor, stage: int) -> torch.Tensor:
        """Run the optional affine + gate pair for ``stage`` (identity if disabled)."""
        if not self.args.affine:
            return z
        scaled = getattr(self, f"a{stage}")(z)
        gate = getattr(self, f"ga{stage}")(scaled)
        return self._apply_gate(scaled, gate)

    def _linear_stage(self, z: torch.Tensor, stage: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run ``lin{stage}`` followed by ``g{stage}``; return (activation, gate)."""
        projected = getattr(self, f"lin{stage}")(z)
        gate = getattr(self, f"g{stage}")(projected)
        return self._apply_gate(projected, gate), gate

    def forward(self, Z: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Return the transformed spectrum and a dict of detached diagnostics."""
        Y, G1 = self._linear_stage(self._affine_stage(Z, 1), 1)
        X, G2 = self._linear_stage(self._affine_stage(Y, 2), 2)
        Z2 = Z + self.r2 * X

        W, G3 = self._linear_stage(self._affine_stage(Z2, 3), 3)
        out = Z2 + self.r3 * W

        stats: Dict[str, torch.Tensor] = {}
        with torch.no_grad():
            # Fraction of output entries with non-zero magnitude.
            mag2 = out.real * out.real + out.imag * out.imag
            stats["freq_active_frac"] = (mag2 > 0).float().mean()

            stats["rezero_r2"] = self.r2.detach()
            stats["rezero_r3"] = self.r3.detach()

            stats["gate_lin_frac_1"] = G1.mean().detach()
            stats["gate_lin_frac_2"] = G2.mean().detach()
            stats["gate_lin_frac_3"] = G3.mean().detach()

        return out, stats
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class Decoder(nn.Module):
    """Spiking read-out head mapping ``[T, B, N, E, L]`` features to forecasts.

    Pipeline: linear projection of the length-``L`` axis down to ``proj_dim``
    -> flatten ``(E, proj_dim)`` -> LIF spikes -> weight-normalised linear
    reduction -> mean over the ``T`` steps -> GELU -> weight-normalised
    linear output of size ``pred_len``.

    Args:
        E: embedding width of the incoming features.
        L: input sequence length (size of the last input axis).
        pred_len: number of forecast steps produced per series.
        T: number of spiking time steps (stored only).
        tau: LIF membrane time constant.
        v_th: LIF firing threshold.
        proj_dim: width the ``L`` axis is projected to.
        reduced_dim: hidden width between the two output layers.
    """

    def __init__(
        self,
        E: int,
        L: int,
        pred_len: int,
        T: int,
        tau: float,
        v_th: float,
        proj_dim: int = 4,
        reduced_dim: int = 64,
    ):
        super().__init__()
        self.E, self.L, self.P, self.T = E, L, pred_len, T
        self.proj_dim = int(proj_dim)
        self.reduced_dim = int(reduced_dim)

        self.time_proj = nn.Linear(L, self.proj_dim, bias=False)
        nn.init.xavier_uniform_(self.time_proj.weight, gain=0.5)

        self.lif = MultiStepLIFNode(
            tau=tau,
            v_threshold=v_th,
            detach_reset=True,
            surrogate_function=surrogate.ATan(alpha=4.0),
            backend="torch",
        )

        self.fc_reduce = self._weight_normed_linear(
            E * self.proj_dim, self.reduced_dim, gain=0.6
        )
        self.fc_out = self._weight_normed_linear(self.reduced_dim, pred_len, gain=0.2)

    @staticmethod
    def _weight_normed_linear(in_features: int, out_features: int, gain: float) -> nn.Module:
        """Build a Xavier-initialised, zero-bias linear layer wrapped in weight norm.

        The initialisation has to happen BEFORE ``weight_norm`` is applied:
        afterwards ``.weight`` is a tensor derived from the magnitude and
        direction parameters and is recomputed from them, so writing into it
        in place has no lasting effect. Wrapping an already-initialised layer
        makes weight norm derive its parameters from the intended weights.
        """
        layer = nn.Linear(in_features, out_features, bias=True)
        nn.init.xavier_uniform_(layer.weight, gain=gain)
        nn.init.zeros_(layer.bias)
        return weight_norm(layer)

    def forward(self, y_t: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Decode features into predictions.

        Args:
            y_t: tensor of shape ``[T, B, N, E, L]``.

        Returns:
            ``(preds, stats)`` where ``preds`` is ``[B, pred_len, N]`` and
            ``stats["dec_spike_rate"]`` is the detached mean of the LIF output.
        """
        T, B, N, E, L = y_t.shape

        y_p = self.time_proj(y_t)  # [T,B,N,E,p]
        x = y_p.reshape(T, B * N, E * self.proj_dim)  # [T,B*N,D]
        s = self.lif(x)  # [T,B*N,D] spikes
        h_t = self.fc_reduce(s.reshape(T * B * N, -1)).view(T, B * N, self.reduced_dim)

        # Average the per-step hidden states before the non-linearity.
        h = F.gelu(h_t.mean(dim=0))  # [B*N,reduced_dim]
        out = self.fc_out(h)  # [B*N,pred_len]

        preds = out.view(B, N, self.P).permute(0, 2, 1).contiguous()
        stats = {"dec_spike_rate": s.mean().detach()}
        return preds, stats
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
class SpikF_GO_CPG(nn.Module):
    """SpikF-GO forecaster with CPG spike positional encoding.

    Forward pipeline for an input ``x`` of shape ``[B, L, N]``:

    1. optional per-series normalisation over the ``L`` axis;
    2. flatten to ``M = N * L`` tokens and embed each scalar to width ``E``;
    3. replicate over ``T`` steps with a learnable per-step scale/shift;
    4. LIF encoding, then fusion with the CPG positional spikes;
    5. real FFT over tokens -> frequency gate -> ``SFGO`` -> inverse FFT;
    6. ``Decoder`` read-out and optional de-normalisation.

    Returns ``(preds, aux)`` where ``aux`` is a dict of detached diagnostics.
    ``hidden_size`` only determines the decoder's reduced width;
    ``hard_thresholding_fraction`` and ``sparsity_threshold`` are accepted
    but not used in this class.
    """

    def __init__(
        self,
        args,
        pre_length: int,
        embed_size: int,
        feature_size: int,
        seq_length: int,
        hidden_size: int,
        hard_thresholding_fraction=1,
        hidden_size_factor: int = 1,
        sparsity_threshold: float = 0.01,
    ):
        super().__init__()
        self.args = args

        self.N = feature_size
        self.L = seq_length
        self.E = embed_size
        self.T = args.T
        self.M = self.N * self.L

        # Fixed positional-encoding configuration.
        self.use_cpg_pe = True
        self.num_pe_pairs = 20
        self.pe_tau = 10000.0
        self.pe_eta = 1.0
        self.pe_vthres = 0.8
        self.pe_wmax = 10000.0

        if self.use_cpg_pe:
            self.cpg_pe = CPGSpikePE(
                num_pairs=self.num_pe_pairs,
                tau=self.pe_tau,
                eta=self.pe_eta,
                vthres=self.pe_vthres,
                w_max=self.pe_wmax,
            )
            self.pe_linear = nn.Linear(self.E + 2 * self.num_pe_pairs, self.E, bias=False)
            self.pe_bn = nn.BatchNorm1d(self.E)
            self.pe_lif = self._new_lif()

        self.embeddings = nn.Parameter(torch.randn(1, self.E) * 0.02)
        self.node_aff = Affine(self.E)
        self.node_rms = RMSNorm(E=self.E, eps=1e-6)

        # Per-step modulation.
        self.step_gamma = nn.Parameter(torch.ones(self.T))
        self.step_beta = nn.Parameter(torch.zeros(self.T))
        self.register_buffer(
            "step_scale", torch.linspace(0, 1, steps=self.T).view(self.T, 1, 1, 1)
        )

        self.encoder_lif = self._new_lif()

        self.sfft = SFFT(self.M)
        self.F_bins = self.sfft.F

        # Frequency gate and its temperature.
        self.freq_gate = HardConcreteGate(self.F_bins, init_logit=2.0)
        self.register_buffer("gate_tau", torch.tensor(0.10))

        self.sfgo = SFGO(
            self.args,
            E=self.E,
            hidden_size_factor=hidden_size_factor,
            tau=args.tau,
            v_th=args.alpha,
            apply_gate_to_complex=True,
        )

        # Decoder width is a quarter of hidden_size, clipped to [16, 128].
        self.decoder = Decoder(
            E=self.E,
            L=self.L,
            pred_len=pre_length,
            T=self.T,
            tau=args.tau,
            v_th=args.alpha,
            proj_dim=self.args.proj_dim,
            reduced_dim=max(16, min(128, hidden_size // 4)),
        )

    def _new_lif(self):
        """Create a LIF node configured from ``args.tau`` / ``args.alpha``."""
        return MultiStepLIFNode(
            tau=self.args.tau,
            v_threshold=self.args.alpha,
            detach_reset=True,
            surrogate_function=surrogate.ATan(alpha=4.0),
            backend="torch",
        )

    def _fuse_positional_spikes(self, s_t: torch.Tensor, B: int, device) -> torch.Tensor:
        """Concatenate CPG spikes to ``s_t`` and re-encode to width ``E``."""
        pe_spk = self.cpg_pe(T=self.T, B=B, M=self.M, device=device)  # [T,B,M,2*N_pe]
        fused = self.pe_linear(torch.cat([s_t, pe_spk], dim=-1))  # [T,B,M,E]
        # BatchNorm1d expects [rows, E]; flatten, normalise, restore.
        fused = self.pe_bn(fused.reshape(self.T * B * self.M, self.E))
        return self.pe_lif(fused.view(self.T, B, self.M, self.E))

    def node_embed(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` ``[B, L, N]`` as ``[B, M, E]`` tokens (series-major order)."""
        B, _, _ = x.shape
        flat = x.permute(0, 2, 1).contiguous().reshape(B, self.M)  # [B,M]
        tokens = flat.unsqueeze(-1) * self.embeddings  # [B,M,E]
        return self.node_aff(tokens)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Forecast from ``x`` ``[B, L, N]``; return ``(preds, aux)``."""
        B, _, N = x.shape
        normalize = self.args.normalize

        if normalize:
            mean = x.mean(dim=1, keepdim=True).detach()
            x0 = x - mean
            std = torch.sqrt(
                torch.var(x0, dim=1, keepdim=True, unbiased=False) + 1e-5
            ).detach()
            x0 = x0 / std
        else:
            mean = std = None
            x0 = x

        tok = self.node_rms(self.node_embed(x0))  # [B,M,E]

        # Replicate over T steps with a per-step scale/shift and a small ramp.
        per_step = (self.T, 1, 1, 1)
        cur_t = tok.unsqueeze(0).repeat(self.T, 1, 1, 1)
        cur_t = cur_t * self.step_gamma.view(*per_step) + self.step_beta.view(*per_step)
        cur_t = cur_t * (1.0 + 0.02 * self.step_scale.to(cur_t.dtype))

        s_t = self.encoder_lif(cur_t)
        if self.use_cpg_pe:
            s_t = self._fuse_positional_spikes(s_t, B, x.device)

        enc_rate = s_t.mean()

        # Token axis -> frequency domain, prune, mix, and back.
        Z_t = self.sfft.rfft(s_t)
        Z_t, m = self.freq_gate(Z_t, tau=float(self.gate_tau))
        Z_t, fb_stats = self.sfgo(Z_t)
        y_time_t = self.sfft.irfft(Z_t).to(tok.dtype)

        # [T,B,M,E] -> [T,B,N,L,E] -> [T,B,N,E,L]
        y_t = (
            y_time_t.view(self.T, B, N, self.L, self.E)
            .permute(0, 1, 2, 4, 3)
            .contiguous()
        )

        preds, dec_stats = self.decoder(y_t)

        if normalize:
            preds = preds * std + mean  # denormalize

        aux = {
            "enc_rate": enc_rate.detach(),
            "rho_hat": self.freq_gate.l0().detach(),
            "freq_mask_mean": m.mean().detach(),
            "freq_mask_active": (m > 0.5).float().mean().detach(),
        }
        aux.update(fb_stats)
        aux.update(dec_stats)
        return preds, aux
|
model/SpikeGRU.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from spikingjelly.activation_based import surrogate as sj_surrogate
|
| 5 |
+
from snntorch import utils
|
| 6 |
+
import snntorch as snn
|
| 7 |
+
from snntorch import surrogate
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class GRUCell(nn.Module):
    """GRU-style gating followed by a leaky integrate-and-fire spike generator.

    The gate pre-activations are computed from the input and from an
    all-zero hidden state, so the recurrent projection contributes only its
    bias and the ``z * h`` term is zero. The gated result is then presented
    to ``snn.Leaky`` for ``num_steps`` steps.

    Accepted input layouts:

    * ``(batch, input_size)``: the same current is fed at every step;
    * ``(batch, input_size, num_steps)``: one current slice per step.

    If ``input_size == num_steps`` a trailing-axis match is treated as the
    first layout.

    Args:
        input_size: feature width of the input.
        hidden_size: width of each gate and of the output.
        num_steps: number of spiking steps generated per call.
        grad_slope: accepted for interface compatibility; not used here.
        beta: decay factor handed to ``snn.Leaky``.
        output_mems: when True, ``forward`` also returns membrane potentials.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_steps: int = 4,
        grad_slope: float = 25.0,
        beta: float = 0.99,
        output_mems: bool = False,
    ):
        super().__init__()
        self.spike_grad = surrogate.atan(alpha=2.0)
        self.input_size = input_size
        self.num_steps = num_steps
        self.hidden_size = hidden_size
        self.beta = beta
        self.full_rec = output_mems
        self.lif = snn.Leaky(
            beta=self.beta,
            spike_grad=self.spike_grad,
            init_hidden=True,
            output=output_mems,
        )
        self.linear_ih = nn.Linear(input_size, 3 * hidden_size)
        self.linear_hh = nn.Linear(hidden_size, 3 * hidden_size)
        self.surrogate_function1 = sj_surrogate.ATan()

    def _gru_update(self, x_proj: torch.Tensor, h_proj: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        """Combine projected input/hidden activations into the new hidden value.

        ``x_proj`` and ``h_proj`` carry the three gate blocks stacked along
        axis 1 (reset, update, candidate), each ``hidden_size`` wide.
        """
        y_ih = torch.split(x_proj, self.hidden_size, dim=1)
        y_hh = torch.split(h_proj, self.hidden_size, dim=1)
        r = self.surrogate_function1(y_ih[0] + y_hh[0])
        z = self.surrogate_function1(y_ih[1] + y_hh[1])
        n = self.surrogate_function1(y_ih[2] + r * y_hh[2])
        return (1.0 - z) * n + z * h

    def forward(self, inputs):
        """Return spikes ``(batch, hidden_size, num_steps)``, plus membranes if enabled.

        Raises:
            ValueError: if ``inputs`` matches neither accepted layout.
        """
        if inputs.size(-1) == self.input_size:
            # Static input: one current reused at every step.
            static = True
            h = torch.zeros(
                size=[inputs.shape[0], self.hidden_size],
                dtype=torch.float,
                device=inputs.device,
            )
            cur = self._gru_update(self.linear_ih(inputs), self.linear_hh(h), h)
        elif inputs.size(-1) == self.num_steps and inputs.size(-2) == self.input_size:
            # Time-resolved input: apply the linear maps with features last,
            # then move the step axis back to the end.
            static = False
            inputs = inputs.transpose(-1, -2)  # BC, T, H
            h = torch.zeros(
                size=[inputs.shape[0], self.hidden_size, self.num_steps],
                dtype=torch.float,
                device=inputs.device,
            )
            cur = self._gru_update(
                self.linear_ih(inputs).transpose(-1, -2),
                self.linear_hh(h.transpose(-1, -2)).transpose(-1, -2),
                h,
            )
        else:
            raise ValueError(
                "Input size mismatch! "
                f"Got {inputs.size()} but expected (..., {self.input_size}, {self.num_steps}) or (..., {self.input_size})"
            )

        spk_rec = []
        mem_rec = []
        for i_step in range(self.num_steps):
            drive = cur if static else cur[:, :, i_step]
            if self.full_rec:
                spk, mem = self.lif(drive)
                mem_rec.append(mem)
            else:
                spk = self.lif(drive)
            spk_rec.append(spk)

        spks = torch.stack(spk_rec, dim=-1)
        if self.full_rec:
            return spks, torch.stack(mem_rec, dim=-1)
        return spks
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class DeltaEncoder(nn.Module):
    """Encode first differences of a series into ``output_size`` spike channels.

    Differences along the length axis (first position zero) are batch-normed
    as a single channel, expanded by a ``Linear(1, output_size)`` layer, and
    passed through a leaky integrate-and-fire node.
    """

    def __init__(self, output_size: int):
        super().__init__()
        self.norm = nn.BatchNorm2d(1)
        self.enc = nn.Linear(1, output_size)
        self.lif = snn.Leaky(
            beta=0.99, spike_grad=surrogate.atan(), init_hidden=True, output=False
        )

    def forward(self, inputs: torch.Tensor):
        """Map ``inputs`` ``[batch, L, C]`` to ``[batch, output_size, C, L]``."""
        diffs = torch.zeros_like(inputs)
        diffs[:, 1:] = inputs[:, 1:, :] - inputs[:, :-1, :]

        as_image = diffs.unsqueeze(1).permute(0, 1, 3, 2)  # batch, 1, C, L
        normed = self.norm(as_image).permute(0, 2, 3, 1)  # batch, C, L, 1
        features = self.enc(normed).permute(0, 3, 1, 2)  # batch, output_size, C, L
        return self.lif(features)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class ConvEncoder(nn.Module):
    """Encode a series into ``output_size`` spike channels via a 1-by-k convolution.

    The convolution slides along the length axis only (kernel
    ``(1, kernel_size)``, padding ``kernel_size // 2``), is followed by
    batch norm, and feeds a leaky integrate-and-fire node.
    """

    def __init__(self, output_size: int, kernel_size: int = 3):
        super().__init__()
        conv = nn.Conv2d(
            in_channels=1,
            out_channels=output_size,
            kernel_size=(1, kernel_size),
            stride=1,
            padding=(0, kernel_size // 2),
        )
        self.encoder = nn.Sequential(conv, nn.BatchNorm2d(output_size))
        self.lif = snn.Leaky(
            beta=0.99,
            spike_grad=surrogate.atan(alpha=2.0),
            init_hidden=True,
            output=False,
        )

    def forward(self, inputs: torch.Tensor):
        """Map ``inputs`` ``[batch, L, C]`` to ``[batch, output_size, C, L]``."""
        as_image = inputs.permute(0, 2, 1).unsqueeze(1)  # batch, 1, C, L
        features = self.encoder(as_image)  # batch, output_size, C, L
        return self.lif(features)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class SpikeGRU(nn.Module):
    """Spiking GRU forecaster: spike encoder -> stacked ``GRUCell`` -> linear head.

    Architecture hyper-parameters are read from ``args`` (``hidden_size``,
    ``T``, ``feature_size``, ``pre_length``, ``blocks``, ``normalize``). The
    constructor arguments ``hidden_size``, ``layers``, ``num_steps``,
    ``input_size``, ``max_length`` and ``weight_file`` are accepted for
    interface compatibility but are not used.

    Args:
        args: namespace providing the attributes listed above.
        grad_slope: forwarded to every ``GRUCell``.
        encoder_type: ``"conv"`` or ``"delta"``.

    Raises:
        ValueError: if ``encoder_type`` is not recognised.
    """

    def __init__(
        self,
        args,
        hidden_size: int,
        layers: int = 1,
        num_steps: int = 50,
        grad_slope: float = 25.0,
        input_size: Optional[int] = None,
        max_length: Optional[int] = None,
        weight_file: Optional[Path] = None,
        encoder_type: Optional[str] = "conv",
    ):
        super().__init__()
        self.args = args
        self.hidden_size = args.hidden_size
        self.num_steps = args.T
        self.input_size = args.feature_size
        self.pre_length = args.pre_length
        self.layers = args.blocks

        if encoder_type == "conv":
            self.encoder = ConvEncoder(self.hidden_size)
        elif encoder_type == "delta":
            self.encoder = DeltaEncoder(self.hidden_size)
        else:
            raise ValueError(f"Unknown encoder type {encoder_type}")

        # Only the last cell also returns membrane potentials.
        self.net = nn.Sequential(
            *[
                GRUCell(
                    self.hidden_size,
                    self.hidden_size,
                    num_steps=self.num_steps,
                    grad_slope=grad_slope,
                    output_mems=(i == self.layers - 1),
                )
                for i in range(self.layers)
            ]
        )

        self.__output_size = self.hidden_size
        self.fc = nn.Linear(self.__output_size, self.pre_length)

        # Keep the historical placement on the first GPU, but do not crash on
        # CPU-only hosts; callers can still move the module afterwards.
        if torch.cuda.is_available():
            self.to("cuda:0")

    def forward(
        self,
        inputs: torch.Tensor,
    ):
        """Forecast from ``inputs`` ``[B, L, C]``.

        Returns:
            ``(preds, aux)`` with ``preds`` of shape ``[B, pre_length, C]``
            and ``aux`` holding a zero ``gate_l0`` placeholder.
        """
        # Clear spiking state left over from the previous call.
        utils.reset(self.encoder)
        for layer in self.net:
            utils.reset(layer)

        bs, length, c_num = inputs.size()

        if self.args.normalize:
            mean = inputs.mean(dim=1, keepdim=True).detach()  # [B, 1, C]
            inputs = inputs - mean
            std = torch.sqrt(
                torch.var(inputs, dim=1, keepdim=True, unbiased=False) + 1e-5
            ).detach()
            inputs = inputs / std

        h = self.encoder(inputs)  # B, H, C, L
        hidden_size = h.size(1)
        h = h.permute(0, 2, 3, 1).reshape(bs * c_num, length, hidden_size)  # BC, L, H

        # Feed positions one by one; only the outputs of the final position
        # are kept for the read-out.
        for i in range(length):
            spks, mems = self.net(h[:, i, :])

        spks = spks.reshape(bs, c_num * hidden_size, -1)  # B, CH, time step
        spks = spks[:, :, -1]  # keep the last spiking step: B, CH

        # No squeeze here: with pre_length == 1 it would drop the horizon
        # axis and the 3-D permute below would fail.
        preds = self.fc(spks.view(bs, c_num, -1))  # B, C, O
        preds = preds.permute(0, 2, 1).contiguous()  # B, O, C

        if self.args.normalize:
            preds = preds * std + mean  # denormalize

        aux = {'gate_l0': torch.tensor(0.0, device=preds.device)}  # placeholder

        return preds, aux

    @property
    def output_size(self):
        """Width of the features fed to the final linear layer."""
        return self.__output_size
|
model/SpikeRNN_CPG.py
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from spikingjelly.activation_based import surrogate, neuron, functional
|
| 6 |
+
import math
|
| 7 |
+
import copy
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Module-level defaults. Presumably the LIF time constant, the spikingjelly
# backend name and the reset-detach flag passed to neuron constructors later
# in this file -- their use is not visible in this part of the file, confirm.
tau = 2.0
backend = "torch"
detach_reset = True
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def generate_ones_and_minus_ones_matrix(rows, cols):
    """Return a random ``rows x cols`` float matrix with entries -1.0 or +1.0.

    Each entry comes from one ``torch.randint(0, 2, ...)`` draw: 0 becomes
    -1.0 and 1 becomes +1.0.
    """
    coin_flips = torch.randint(0, 2, (rows, cols))
    # Affine map {0, 1} -> {-1, +1}.
    signs = 2 * coin_flips - 1
    return signs.float()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class RandomPE(nn.Module):
    """Random +/-1 positional code, concatenated to or added onto the features.

    A fixed table of shape ``[max_len, 1, num_pe_neuron]`` is drawn once at
    construction and stored as the buffer ``pe``. Positions index the
    flattened ``(T, L)`` axes of the input in step-major order.

    Args:
        d_model: feature width; used as the code width in ``"add"`` mode.
        pe_mode: ``"concat"`` appends ``num_pe_neuron`` channels,
            ``"add"`` sums a ``d_model``-wide code onto the features.
        num_pe_neuron: code width in ``"concat"`` mode.
        neuron_pe_scale: stored on the instance; not used by this class.
        dropout: dropout probability applied to the output.
        num_steps: accepted for interface compatibility; not used here.

    Raises:
        ValueError: if ``pe_mode`` is neither ``"concat"`` nor ``"add"``.
    """

    def __init__(
        self,
        d_model,
        pe_mode="concat",
        num_pe_neuron=10,
        neuron_pe_scale=1000.0,
        dropout=0.1,
        num_steps=4,
    ):
        super().__init__()
        # Fail early with a clear message; previously an unknown mode left
        # ``num_pe_neuron`` unset and surfaced as an opaque AttributeError.
        if pe_mode not in ("concat", "add"):
            raise ValueError(
                f"Unknown pe_mode {pe_mode!r}; expected 'concat' or 'add'"
            )
        self.max_len = 5000  # different from windows
        self.pe_mode = pe_mode
        self.neuron_pe_scale = neuron_pe_scale
        self.dropout = nn.Dropout(p=dropout)
        self.num_pe_neuron = num_pe_neuron if pe_mode == "concat" else d_model

        pe = generate_ones_and_minus_ones_matrix(
            self.max_len, self.num_pe_neuron
        )  # MaxL, Neur
        self.register_buffer("pe", pe.unsqueeze(1))  # MaxL, 1, Neur

    def forward(self, x):
        """Attach the positional code to ``x`` of shape ``[T, B, L, D]``.

        Returns a ``[T, B, L, D']`` tensor after dropout, where ``D'`` is
        ``D + num_pe_neuron`` in ``"concat"`` mode and ``D`` in ``"add"`` mode.

        Raises:
            ValueError: if ``T * L`` exceeds the length of the stored table.
        """
        T, B, L, _ = x.shape
        seq_len = T * L
        if seq_len > self.max_len:
            raise ValueError(
                f"Sequence of {seq_len} positions exceeds positional table length {self.max_len}"
            )

        x = x.permute(1, 0, 2, 3).flatten(1, 2)  # B, TL, D
        code = self.pe[:seq_len, :].transpose(0, 1)  # 1, TL, Neur
        if self.pe_mode == "concat":
            x = torch.concat([x, code.repeat(B, 1, 1)], dim=-1)  # B, TL, D'
        else:
            x = x + code  # B, TL, D (broadcast over the batch)

        x = x.transpose(0, 1)  # TL, B, D
        x = x.reshape(T, L, B, -1)  # T, L, B, D
        x = x.permute(0, 2, 1, 3)  # T, B, L, D
        return self.dropout(x)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class NeuronPE(nn.Module):
    """Binary (spike-like) positional code built from thresholded sinusoids.

    Even columns are 1 where ``sin(pos * freq) - 0.8 >= 0`` and odd columns are
    1 where ``cos(pos * freq) - 0.8 >= 0``; everything else is 0. The table has
    ``max_len`` (fixed at 50000) rows and is registered as the buffer ``pe``
    with shape (max_len, 1, num_pe_neuron).

    Args:
        d_model: feature width of the input; used as the code width in "add" mode.
        pe_mode: "concat" (append the code to the features) or "add" (sum it).
        num_pe_neuron: code width in "concat" mode.
        neuron_pe_scale: base of the geometric frequency progression.
        dropout: dropout probability applied to the output.
        num_steps: accepted but not used by this class.

    Raises:
        ValueError: if ``pe_mode`` is neither "concat" nor "add".
    """

    def __init__(
        self,
        d_model,
        pe_mode="concat",
        num_pe_neuron=10,
        neuron_pe_scale=10000.0,
        dropout=0.1,
        num_steps=4,
    ):
        super().__init__()
        self.max_len = 50000  # different from windows
        self.pe_mode = pe_mode
        self.neuron_pe_scale = neuron_pe_scale
        self.dropout = nn.Dropout(p=dropout)
        if self.pe_mode == "concat":
            self.num_pe_neuron = num_pe_neuron
        elif self.pe_mode == "add":
            self.num_pe_neuron = d_model
        else:
            # Fail early with a clear message instead of an AttributeError on
            # the missing ``num_pe_neuron`` a few lines below.
            raise ValueError("Unknown pe_mode: {}".format(pe_mode))
        pe = torch.zeros(self.max_len, self.num_pe_neuron)  # MaxL, Neur
        position = torch.arange(0, self.max_len, dtype=torch.float).unsqueeze(
            1
        )  # MaxL, 1
        # Two frequency vectors so that an odd code width still fills the
        # even (sin) and odd (cos) columns with matching lengths.
        div_term = torch.exp(
            torch.arange(0, self.num_pe_neuron, 2).float()
            * (-math.log(neuron_pe_scale) / self.num_pe_neuron)
        )
        div_term_single = torch.exp(
            torch.arange(0, self.num_pe_neuron - 1, 2).float()
            * (-math.log(neuron_pe_scale) / self.num_pe_neuron)
        )
        pe[:, 0::2] = torch.heaviside(
            torch.sin(position * div_term) - 0.8, torch.tensor([1.0])
        )
        pe[:, 1::2] = torch.heaviside(
            torch.cos(position * div_term_single) - 0.8, torch.tensor([1.0])
        )
        pe = pe.unsqueeze(1)  # MaxL, 1, Neur
        print("pe.shape: ", pe.shape)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Attach the positional code to ``x`` of shape (T, B, L, D).

        Time steps and sequence positions are flattened together, so row
        ``t * L + l`` of the table is used for step ``t`` at position ``l``.
        Returns a (T, B, L, D') tensor after dropout, where D' is
        D + num_pe_neuron in "concat" mode and D in "add" mode.
        """
        T, B, L, _ = x.shape
        x = x.permute(1, 0, 2, 3).flatten(1, 2)  # B, TL, D
        code = self.pe[: x.size(-2), :]  # TL, 1, Neur
        if self.pe_mode == "concat":
            # TL, 1, Neur -> TL, B, Neur -> B, TL, Neur
            x = torch.concat([x, code.repeat(1, B, 1).transpose(0, 1)], dim=-1)
        elif self.pe_mode == "add":
            # [B, TL, D] + [1, TL, Neur]
            x = x + code.transpose(0, 1)
        x = x.transpose(0, 1)  # TL, B, D'
        x = x.reshape(T, L, B, -1)  # T, L, B, D'
        x = x.permute(0, 2, 1, 3)  # T, B, L, D'
        return self.dropout(x)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class StaticPE(nn.Module):
    r"""Fixed sinusoidal positional encoding that is added to the input.

    Column ``2i`` of the table holds ``sin(pos / 10000^(2i / d_model))`` and
    column ``2i + 1`` holds ``cos(pos / 10000^(2i / d_model))``, where ``pos``
    is the position index. The table is registered as the buffer ``pe`` with
    shape (max_len, 1, d_model).

    Args:
        d_model: feature width of the input.
        dropout: dropout probability applied after the addition.
        max_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        decay = -math.log(10000.0) / d_model
        steps = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # MaxL, 1
        # Separate frequency vectors so an odd d_model still lines up with the
        # even (sin) and odd (cos) column slices.
        sin_freq = torch.exp(torch.arange(0, d_model, 2).float() * decay)
        cos_freq = torch.exp(torch.arange(0, d_model - 1, 2).float() * decay)
        table = torch.zeros(max_len, d_model)  # MaxL, D
        table[:, 0::2] = torch.sin(steps * sin_freq)
        table[:, 1::2] = torch.cos(steps * cos_freq)
        self.register_buffer("pe", table.unsqueeze(1))  # MaxL, 1, D

    def forward(self, x):
        """Add the first ``x.size(0)`` table rows to ``x`` (L, TB, D), then apply dropout."""
        return self.dropout(x + self.pe[: x.size(0), :])
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class ConvPE(nn.Module):
    """Convolutional positional term: conv -> batch norm -> LIF, added to the input.

    Args:
        d_model: channel count of the 1-D convolution and batch norm.
        dropout: dropout probability applied to the spiking branch.
        max_len: accepted but not used by this class.
        num_steps: number of time steps ``T`` folded into the TB axis of the input.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, num_steps=4):
        super().__init__()
        self.T = num_steps
        self.rpe_conv = nn.Conv1d(
            d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.rpe_bn = nn.BatchNorm1d(d_model)
        self.rpe_lif = neuron.LIFNode(
            step_mode="m",
            detach_reset=True,
            surrogate_function=surrogate.ATan(),
            v_threshold=1.0,
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return ``x`` plus the spiking convolutional branch; ``x`` is (L, TB, D)."""
        seq_len, tb, width = x.shape
        branch = self.rpe_bn(self.rpe_conv(x.permute(1, 2, 0)))  # TB, D, L
        # Split the fused TB axis so the multi-step LIF sees time as axis 0.
        branch = branch.reshape(
            self.T, int(tb / self.T), width, seq_len
        ).contiguous()  # T, B, D, L
        branch = self.rpe_lif(branch).flatten(0, 1)  # TB, D, L
        branch = self.dropout(branch).permute(2, 0, 1)  # L, TB, D
        return x + branch
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class PositionEmbedding(nn.Module):
    """Dispatch to one of the positional-encoding variants by ``pe_type``.

    Args:
        input_size: feature width of the tensors passed to ``forward``.
        pe_type: one of "learn", "none", "conv", "static", "neuron", "random".
        max_len: table size for the learned, conv and static variants.
        pe_mode: "add" or "concat"; forwarded to the neuron/random variants.
        num_pe_neuron: code width for the neuron/random variants in concat mode.
        neuron_pe_scale: forwarded to the neuron/random variants.
        dropout: dropout probability forwarded to the selected variant.
        num_steps: number of time steps forwarded to the selected variant.

    Raises:
        ValueError: if ``pe_type`` is not one of the supported names.
    """

    def __init__(
        self,
        input_size: int,
        pe_type: str,
        max_len: int = 5000,
        pe_mode: str = "add",
        num_pe_neuron: int = 10,
        neuron_pe_scale: float = 1000.0,
        dropout=0.1,
        num_steps=4,
    ):
        super().__init__()
        self.emb_type = pe_type
        if pe_type in ["learn", "none"]:
            self.emb = nn.Embedding(max_len, input_size)
        elif pe_type == "conv":
            self.emb = ConvPE(
                d_model=input_size,
                max_len=max_len,
                dropout=dropout,
                num_steps=num_steps,
            )
        elif pe_type == "static":
            self.emb = StaticPE(d_model=input_size, max_len=max_len, dropout=dropout)
        elif pe_type == "neuron":
            self.emb = NeuronPE(
                d_model=input_size,
                pe_mode=pe_mode,
                num_pe_neuron=num_pe_neuron,
                neuron_pe_scale=neuron_pe_scale,
                dropout=dropout,
                num_steps=num_steps,
            )
        elif pe_type == "random":
            self.emb = RandomPE(
                d_model=input_size,
                pe_mode=pe_mode,
                num_pe_neuron=num_pe_neuron,
                neuron_pe_scale=neuron_pe_scale,
                dropout=dropout,
                num_steps=num_steps,
            )
        else:
            raise ValueError("Unknown embedding type: {}".format(pe_type))

    def forward(self, x):
        """Apply the selected encoding.

        For "learn", ``x`` may be (TB, L, D) or (T, B, L, D); the other active
        types expect (T, B, L, D). With ``pe_type == "none"`` the input is
        returned unchanged.
        """
        if self.emb_type == "learn":
            # Index positions along the sequence axis (second to last) so the
            # (L, D) embedding broadcasts over every leading axis. Indexing
            # axis 1 is only right for 3-D input; for (T, B, L, D) it picks
            # the batch axis and the addition fails to broadcast.
            positions = torch.arange(
                end=x.size(-2), device=x.device
            )  # [0,1,2,...,L-1], shape: L
            x = x + self.emb(positions)  # (..., L, D) + (L, D)
        elif self.emb_type in ["static", "conv"]:
            T, B, L, _ = x.shape  # x: T, B, L, D
            x = x.flatten(0, 1)  # TB, L, D
            x = self.emb(x.transpose(0, 1)).transpose(0, 1)  # x: TB, L, D'
            x = x.reshape(T, B, L, -1)
        elif self.emb_type in ["neuron", "random"]:
            T, B, L, _ = x.shape  # x: T, B, L, D
            x = self.emb(x)
            x = x.reshape(T, B, L, -1)
        return x  # T, B, L, D'
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class RepeatEncoder(nn.Module):
    """Spike encoder that feeds the same input to a LIF neuron ``output_size`` times.

    Args:
        output_size: number of repetitions, used as the leading time axis.
    """

    def __init__(self, output_size: int):
        super().__init__()
        self.out_size = output_size
        self.lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
        )

    def forward(self, inputs: torch.Tensor):
        """Map ``inputs`` (B, L, C) to LIF outputs of shape (T, B, C, L)."""
        # A repeat spec one entry longer than the input rank prepends the
        # repeated axis: (B, L, C) -> (T, B, L, C).
        tiled = inputs.repeat(self.out_size, *([1] * inputs.dim()))
        return self.lif(tiled.permute(0, 1, 3, 2))  # T, B, C, L
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class DeltaEncoder(nn.Module):
    """Spike encoder driven by first differences along the sequence axis.

    Differences are batch-normalised, projected from one value to
    ``output_size`` values by a linear layer, and passed through a LIF neuron
    with the projected axis moved to the front.

    Args:
        output_size: width of the linear projection, used as the time axis.
    """

    def __init__(self, output_size: int):
        super().__init__()
        self.norm = nn.BatchNorm2d(1)
        self.enc = nn.Linear(1, output_size)
        self.lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
        )

    def forward(self, inputs: torch.Tensor):
        """Map ``inputs`` (B, L, C) to LIF outputs of shape (T, B, C, L)."""
        # The first position has no predecessor, so its difference stays zero.
        step_change = torch.zeros_like(inputs)
        step_change[:, 1:] = inputs[:, 1:, :] - inputs[:, :-1, :]
        normed = self.norm(step_change.transpose(1, 2).unsqueeze(1))  # B, 1, C, L
        projected = self.enc(normed.permute(0, 2, 3, 1))  # B, C, L, T
        return self.lif(projected.permute(3, 0, 1, 2))  # T, B, C, L
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
class ConvEncoder(nn.Module):
    """Spike encoder that expands one input plane into ``output_size`` channels.

    A (1, kernel_size) convolution followed by batch norm produces the
    channels, which are then moved to the front as the time axis of a LIF
    neuron.

    Args:
        output_size: number of convolution output channels (time steps).
        kernel_size: kernel width along the sequence axis.
    """

    def __init__(self, output_size: int, kernel_size: int = 3):
        super().__init__()
        conv = nn.Conv2d(
            in_channels=1,
            out_channels=output_size,
            kernel_size=(1, kernel_size),
            stride=1,
            padding=(0, kernel_size // 2),
        )
        self.encoder = nn.Sequential(conv, nn.BatchNorm2d(output_size))
        self.lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
        )

    def forward(self, inputs: torch.Tensor):
        """Map ``inputs`` (B, L, C) to LIF outputs of shape (T, B, C, L)."""
        plane = inputs.permute(0, 2, 1).unsqueeze(1)  # B, 1, C, L
        stepped = self.encoder(plane).permute(1, 0, 2, 3)  # T, B, C, L
        return self.lif(stepped)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
# Registry of spike encoders, keyed first by backend name and then by encoder
# name. Both backend entries point at the same three classes defined above.
SpikeEncoder = {
    "snntorch": {
        "repeat": RepeatEncoder,
        "conv": ConvEncoder,
        "delta": DeltaEncoder,
    },
    "spikingjelly": {
        "repeat": RepeatEncoder,
        "conv": ConvEncoder,
        "delta": DeltaEncoder,
    },
}
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class SpikeRNNCell(nn.Module):
    """One layer of the spiking stack: a linear projection followed by a LIF neuron.

    Args:
        input_size: feature width of the incoming tensor.
        output_size: feature width produced by the projection.
    """

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self.input_size = input_size
        self.linear = nn.Linear(input_size, output_size)
        self.lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
        )

    def forward(self, x):
        """Project ``x`` (T, B, L, C) and return the LIF output (T, B, L, C')."""
        steps, batch, length, _ = x.shape
        # Fold time into the batch for the projection, then restore it so the
        # multi-step LIF sees time as axis 0.
        projected = self.linear(x.flatten(0, 1)).reshape(steps, batch, length, -1)
        return self.lif(projected)
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
class SpikeRNN_CPG(nn.Module):
    """Spiking forecaster: spike encoder, positional encoding, LIF stack, linear heads.

    ``args`` supplies ``hidden_size``, ``T``, ``feature_size``, ``pre_length``,
    ``blocks``, ``seq_length`` and ``normalize``. The hidden width and depth of
    the stack come from the ``hidden_size`` and ``layers`` arguments, not from
    ``args.hidden_size`` / ``args.blocks``, which are only stored.

    Args:
        args: configuration namespace (fields listed above).
        hidden_size: base hidden width of the stack.
        layers: number of ``SpikeRNNCell`` layers.
        num_steps: accepted for compatibility; ``args.T`` is used instead.
        input_size: input width of the first linear layer; falls back to
            ``args.feature_size`` when ``None``.
        max_length: table size forwarded to the positional encoding.
        weight_file: accepted but not used here.
        encoder_type: key into ``SpikeEncoder`` ("repeat", "conv" or "delta").
        num_pe_neuron: positional code width in concat mode.
        pe_type: positional encoding type forwarded to ``PositionEmbedding``.
        pe_mode: "add" or "concat".
        neuron_pe_scale: forwarded to the positional encoding.
    """

    def __init__(
        self,
        args,
        hidden_size: int,
        layers: int = 1,
        num_steps: int = 4,
        input_size: Optional[int] = None,
        max_length: Optional[int] = 5000,
        weight_file: Optional[Path] = None,
        encoder_type: Optional[str] = "conv",
        num_pe_neuron: int = 40,
        pe_type: str = "neuron",
        pe_mode: str = "concat",  # "add" or concat
        neuron_pe_scale: float = 10000.0,  # "100" or "1000" or "10000"
    ):
        super().__init__()
        self._snn_backend = "spikingjelly"
        self.hidden_size = args.hidden_size
        self.num_steps = args.T
        self.input_size = args.feature_size
        self.pre_length = args.pre_length
        self.layers = args.blocks
        self.pe_type = pe_type
        self.pe_mode = pe_mode
        self.num_pe_neuron = num_pe_neuron
        self.neuron_pe_scale = neuron_pe_scale
        self.temporal_encoder = SpikeEncoder[self._snn_backend][encoder_type](self.num_steps)
        self.args = args

        self.pe = PositionEmbedding(
            pe_type=pe_type,
            pe_mode=pe_mode,
            neuron_pe_scale=neuron_pe_scale,
            input_size=self.input_size,
            max_len=max_length,
            num_pe_neuron=self.num_pe_neuron,
            dropout=0.1,
            num_steps=self.num_steps,
        )

        # Both the neuron and the random encodings append ``num_pe_neuron``
        # features in concat mode, so both must widen the first linear layer.
        appends_code = self.pe_type in ("neuron", "random") and self.pe_mode == "concat"
        # The encoder consumes features of width ``args.feature_size``; use it
        # when the caller leaves ``input_size`` at its ``None`` default.
        in_features = self.input_size if input_size is None else input_size

        if appends_code:
            self.dim = hidden_size + num_pe_neuron
            self.encoder = nn.Linear(in_features + num_pe_neuron, self.dim)
        else:
            self.dim = hidden_size
            self.encoder = nn.Linear(in_features, self.dim)

        self.init_lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
            v_threshold=1.0,
            backend=backend,
        )

        self.net = nn.Sequential(
            *[
                SpikeRNNCell(input_size=self.dim, output_size=self.dim)
                for _ in range(layers)
            ]
        )

        self.__output_size = self.dim
        self.fc1 = nn.Linear(self.__output_size, args.feature_size)
        self.fc2 = nn.Linear(args.seq_length, self.pre_length)
        # Keep the original default placement on the first GPU, but do not
        # crash at construction time on hosts without CUDA.
        if torch.cuda.is_available():
            self.to('cuda:0')

    def forward(
        self,
        inputs: torch.Tensor,
    ):
        """Forecast from ``inputs`` of shape (B, L, C).

        Returns:
            ``(preds, aux)`` where ``preds`` is (B, pre_length, C) and ``aux``
            is a dict holding a zero ``gate_l0`` placeholder.
        """
        # Clear neuron state left over from the previous call.
        functional.reset_net(self)
        if self.args.normalize:
            # Per-sample, per-channel standardisation over the sequence axis;
            # the statistics are detached and reused to de-normalise below.
            mean = inputs.mean(dim=1, keepdim=True).detach()  # shape [B, 1, D]
            inputs = inputs - mean

            std = torch.sqrt(torch.var(inputs, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
            inputs = inputs / std

        hiddens = self.temporal_encoder(inputs)  # T, B, C, L
        hiddens = hiddens.transpose(-2, -1)  # T, B, L, C
        T, B, L, _ = hiddens.size()  # T, B, L, D
        if self.pe_type != "none":
            hiddens = self.pe(hiddens)  # T B L C'
        hiddens = self.encoder(hiddens.flatten(0, 1)).reshape(T, B, L, -1)  # T B L D
        hiddens = self.init_lif(hiddens)
        hiddens = self.net(hiddens)  # T, B, L, D
        out = hiddens.mean(0)  # B, L, D
        preds = self.fc1(out)  # B, L, C
        preds = self.fc2(preds.permute(0, 2, 1))  # B, C, pre_length
        preds = preds.permute(0, 2, 1).contiguous()  # B, pre_length, C

        if self.args.normalize:
            preds = preds * std + mean  # denormalize

        aux = {'gate_l0': torch.tensor(0.0, device=preds.device)}  # placeholder

        return preds, aux
|
| 488 |
+
|
| 489 |
+
|
model/SpikeTCN_CPG.py
ADDED
|
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn.utils import weight_norm
|
| 6 |
+
import snntorch as snn
|
| 7 |
+
from snntorch import surrogate
|
| 8 |
+
from snntorch import utils
|
| 9 |
+
import copy
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
def generate_ones_and_minus_ones_matrix(rows, cols):
    """Return a (rows, cols) float tensor whose entries are randomly -1.0 or 1.0."""
    coin = torch.randint(0, 2, (rows, cols))
    # Map 0 -> -1 and 1 -> +1 arithmetically.
    signs = coin * 2 - 1
    return signs.float()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class RandomPE(nn.Module):
    """Positional code made of random +1/-1 values drawn once at construction.

    The table has ``max_len`` (fixed at 5000) rows and is registered as the
    buffer ``pe`` with shape (max_len, 1, num_pe_neuron). In ``"concat"`` mode
    the code is appended to the feature axis; in ``"add"`` mode it is summed
    with the input, so its width is set to ``d_model``.

    Args:
        d_model: feature width of the input; used as the code width in "add" mode.
        pe_mode: "concat" or "add".
        num_pe_neuron: code width in "concat" mode.
        neuron_pe_scale: stored on the module; not used to build the table.
        dropout: dropout probability applied to the output.
        num_steps: accepted but not used by this class.

    Raises:
        ValueError: if ``pe_mode`` is neither "concat" nor "add".
    """

    def __init__(
        self,
        d_model,
        pe_mode="concat",
        num_pe_neuron=10,
        neuron_pe_scale=1000.0,
        dropout=0.1,
        num_steps=4,
    ):
        super().__init__()
        self.max_len = 5000  # different from windows
        self.pe_mode = pe_mode
        self.neuron_pe_scale = neuron_pe_scale
        self.dropout = nn.Dropout(p=dropout)
        if self.pe_mode == "concat":
            self.num_pe_neuron = num_pe_neuron
        elif self.pe_mode == "add":
            self.num_pe_neuron = d_model
        else:
            # Fail early with a clear message instead of an AttributeError on
            # the missing ``num_pe_neuron`` a few lines below.
            raise ValueError("Unknown pe_mode: {}".format(pe_mode))
        pe = generate_ones_and_minus_ones_matrix(
            self.max_len, self.num_pe_neuron
        )  # MaxL, Neur
        pe = pe.unsqueeze(1)  # MaxL, 1, Neur
        print("pe.shape: ", pe.shape)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Attach the positional code to ``x`` of shape (T, B, L, D).

        Time steps and sequence positions are flattened together, so row
        ``t * L + l`` of the table is used for step ``t`` at position ``l``.
        Returns a (T, B, L, D') tensor after dropout, where D' is
        D + num_pe_neuron in "concat" mode and D in "add" mode.
        """
        T, B, L, _ = x.shape
        x = x.permute(1, 0, 2, 3).flatten(1, 2)  # B, TL, D
        code = self.pe[: x.size(-2), :]  # TL, 1, Neur
        if self.pe_mode == "concat":
            # TL, 1, Neur -> TL, B, Neur -> B, TL, Neur
            x = torch.concat([x, code.repeat(1, B, 1).transpose(0, 1)], dim=-1)
        elif self.pe_mode == "add":
            # [B, TL, D] + [1, TL, Neur]
            x = x + code.transpose(0, 1)
        x = x.transpose(0, 1)  # TL, B, D'
        x = x.reshape(T, L, B, -1)  # T, L, B, D'
        x = x.permute(0, 2, 1, 3)  # T, B, L, D'
        return self.dropout(x)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class NeuronPE(nn.Module):
    """Binary (spike-like) positional code built from thresholded sinusoids.

    Even columns are 1 where ``sin(pos * freq) - 0.8 >= 0`` and odd columns are
    1 where ``cos(pos * freq) - 0.8 >= 0``; everything else is 0. The table has
    ``max_len`` (fixed at 50000) rows and is registered as the buffer ``pe``
    with shape (max_len, 1, num_pe_neuron).

    Args:
        d_model: feature width of the input; used as the code width in "add" mode.
        pe_mode: "concat" (append the code to the features) or "add" (sum it).
        num_pe_neuron: code width in "concat" mode.
        neuron_pe_scale: base of the geometric frequency progression.
        dropout: dropout probability applied to the output.
        num_steps: accepted but not used by this class.

    Raises:
        ValueError: if ``pe_mode`` is neither "concat" nor "add".
    """

    def __init__(
        self,
        d_model,
        pe_mode="concat",
        num_pe_neuron=10,
        neuron_pe_scale=10000.0,
        dropout=0.1,
        num_steps=4,
    ):
        super().__init__()
        self.max_len = 50000  # different from windows
        self.pe_mode = pe_mode
        self.neuron_pe_scale = neuron_pe_scale
        self.dropout = nn.Dropout(p=dropout)
        if self.pe_mode == "concat":
            self.num_pe_neuron = num_pe_neuron
        elif self.pe_mode == "add":
            self.num_pe_neuron = d_model
        else:
            # Fail early with a clear message instead of an AttributeError on
            # the missing ``num_pe_neuron`` a few lines below.
            raise ValueError("Unknown pe_mode: {}".format(pe_mode))
        pe = torch.zeros(self.max_len, self.num_pe_neuron)  # MaxL, Neur
        position = torch.arange(0, self.max_len, dtype=torch.float).unsqueeze(
            1
        )  # MaxL, 1
        # Two frequency vectors so that an odd code width still fills the
        # even (sin) and odd (cos) columns with matching lengths.
        div_term = torch.exp(
            torch.arange(0, self.num_pe_neuron, 2).float()
            * (-math.log(neuron_pe_scale) / self.num_pe_neuron)
        )
        div_term_single = torch.exp(
            torch.arange(0, self.num_pe_neuron - 1, 2).float()
            * (-math.log(neuron_pe_scale) / self.num_pe_neuron)
        )
        pe[:, 0::2] = torch.heaviside(
            torch.sin(position * div_term) - 0.8, torch.tensor([1.0])
        )
        pe[:, 1::2] = torch.heaviside(
            torch.cos(position * div_term_single) - 0.8, torch.tensor([1.0])
        )
        pe = pe.unsqueeze(1)  # MaxL, 1, Neur
        print("pe.shape: ", pe.shape)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Attach the positional code to ``x`` of shape (T, B, L, D).

        Time steps and sequence positions are flattened together, so row
        ``t * L + l`` of the table is used for step ``t`` at position ``l``.
        Returns a (T, B, L, D') tensor after dropout, where D' is
        D + num_pe_neuron in "concat" mode and D in "add" mode.
        """
        T, B, L, _ = x.shape
        x = x.permute(1, 0, 2, 3).flatten(1, 2)  # B, TL, D
        code = self.pe[: x.size(-2), :]  # TL, 1, Neur
        if self.pe_mode == "concat":
            # TL, 1, Neur -> TL, B, Neur -> B, TL, Neur
            x = torch.concat([x, code.repeat(1, B, 1).transpose(0, 1)], dim=-1)
        elif self.pe_mode == "add":
            # [B, TL, D] + [1, TL, Neur]
            x = x + code.transpose(0, 1)
        x = x.transpose(0, 1)  # TL, B, D'
        x = x.reshape(T, L, B, -1)  # T, L, B, D'
        x = x.permute(0, 2, 1, 3)  # T, B, L, D'
        return self.dropout(x)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class StaticPE(nn.Module):
    r"""Fixed sinusoidal positional encoding that is added to the input.

    Column ``2i`` of the table holds ``sin(pos / 10000^(2i / d_model))`` and
    column ``2i + 1`` holds ``cos(pos / 10000^(2i / d_model))``, where ``pos``
    is the position index. The table is registered as the buffer ``pe`` with
    shape (max_len, 1, d_model).

    Args:
        d_model: feature width of the input.
        dropout: dropout probability applied after the addition.
        max_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        decay = -math.log(10000.0) / d_model
        steps = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # MaxL, 1
        # Separate frequency vectors so an odd d_model still lines up with the
        # even (sin) and odd (cos) column slices.
        sin_freq = torch.exp(torch.arange(0, d_model, 2).float() * decay)
        cos_freq = torch.exp(torch.arange(0, d_model - 1, 2).float() * decay)
        table = torch.zeros(max_len, d_model)  # MaxL, D
        table[:, 0::2] = torch.sin(steps * sin_freq)
        table[:, 1::2] = torch.cos(steps * cos_freq)
        self.register_buffer("pe", table.unsqueeze(1))  # MaxL, 1, D

    def forward(self, x):
        """Add the first ``x.size(0)`` table rows to ``x`` (L, TB, D), then apply dropout."""
        return self.dropout(x + self.pe[: x.size(0), :])
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class ConvPE(nn.Module):
    """Convolutional positional term: conv -> batch norm -> LIF, added to the input.

    Args:
        d_model: channel count of the 1-D convolution and batch norm.
        dropout: dropout probability applied to the spiking branch.
        max_len: accepted but not used by this class.
        num_steps: number of time steps ``T`` folded into the TB axis of the input.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, num_steps=4):
        super().__init__()
        self.T = num_steps
        self.rpe_conv = nn.Conv1d(
            d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.rpe_bn = nn.BatchNorm1d(d_model)
        # NOTE(review): `neuron` is not among the imports visible at the top
        # of this file (only snntorch is imported there) — confirm it is in
        # scope, otherwise building this class raises NameError.
        self.rpe_lif = neuron.LIFNode(
            step_mode="m",
            detach_reset=True,
            surrogate_function=surrogate.ATan(),
            v_threshold=1.0,
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return ``x`` plus the spiking convolutional branch; ``x`` is (L, TB, D)."""
        seq_len, tb, width = x.shape
        branch = self.rpe_bn(self.rpe_conv(x.permute(1, 2, 0)))  # TB, D, L
        # Split the fused TB axis so the LIF node sees time as axis 0.
        branch = branch.reshape(
            self.T, int(tb / self.T), width, seq_len
        ).contiguous()  # T, B, D, L
        branch = self.rpe_lif(branch).flatten(0, 1)  # TB, D, L
        branch = self.dropout(branch).permute(2, 0, 1)  # L, TB, D
        return x + branch
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class PositionEmbedding(nn.Module):
    """Dispatch to one of the positional-encoding variants by ``pe_type``.

    Args:
        input_size: feature width of the tensors passed to ``forward``.
        pe_type: one of "learn", "none", "conv", "static", "neuron", "random".
        max_len: table size for the learned, conv and static variants.
        pe_mode: "add" or "concat"; forwarded to the neuron/random variants.
        num_pe_neuron: code width for the neuron/random variants in concat mode.
        neuron_pe_scale: forwarded to the neuron/random variants.
        dropout: dropout probability forwarded to the selected variant.
        num_steps: number of time steps forwarded to the selected variant.

    Raises:
        ValueError: if ``pe_type`` is not one of the supported names.
    """

    def __init__(
        self,
        input_size: int,
        pe_type: str,
        max_len: int = 5000,
        pe_mode: str = "add",
        num_pe_neuron: int = 10,
        neuron_pe_scale: float = 1000.0,
        dropout=0.1,
        num_steps=4,
    ):
        super().__init__()
        self.emb_type = pe_type
        if pe_type in ["learn", "none"]:
            self.emb = nn.Embedding(max_len, input_size)
        elif pe_type == "conv":
            self.emb = ConvPE(
                d_model=input_size,
                max_len=max_len,
                dropout=dropout,
                num_steps=num_steps,
            )
        elif pe_type == "static":
            self.emb = StaticPE(d_model=input_size, max_len=max_len, dropout=dropout)
        elif pe_type == "neuron":
            self.emb = NeuronPE(
                d_model=input_size,
                pe_mode=pe_mode,
                num_pe_neuron=num_pe_neuron,
                neuron_pe_scale=neuron_pe_scale,
                dropout=dropout,
                num_steps=num_steps,
            )
        elif pe_type == "random":
            self.emb = RandomPE(
                d_model=input_size,
                pe_mode=pe_mode,
                num_pe_neuron=num_pe_neuron,
                neuron_pe_scale=neuron_pe_scale,
                dropout=dropout,
                num_steps=num_steps,
            )
        else:
            raise ValueError("Unknown embedding type: {}".format(pe_type))

    def forward(self, x):
        """Apply the selected encoding.

        For "learn", ``x`` may be (TB, L, D) or (T, B, L, D); the other active
        types expect (T, B, L, D). With ``pe_type == "none"`` the input is
        returned unchanged.
        """
        if self.emb_type == "learn":
            # Index positions along the sequence axis (second to last) so the
            # (L, D) embedding broadcasts over every leading axis. Indexing
            # axis 1 is only right for 3-D input; for (T, B, L, D) it picks
            # the batch axis and the addition fails to broadcast.
            positions = torch.arange(
                end=x.size(-2), device=x.device
            )  # [0,1,2,...,L-1], shape: L
            x = x + self.emb(positions)  # (..., L, D) + (L, D)
        elif self.emb_type in ["static", "conv"]:
            T, B, L, _ = x.shape  # x: T, B, L, D
            x = x.flatten(0, 1)  # TB, L, D
            x = self.emb(x.transpose(0, 1)).transpose(0, 1)  # x: TB, L, D'
            x = x.reshape(T, B, L, -1)
        elif self.emb_type in ["neuron", "random"]:
            T, B, L, _ = x.shape  # x: T, B, L, D
            x = self.emb(x)
            x = x.reshape(T, B, L, -1)
        return x  # T, B, L, D'
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class RepeatEncoder(nn.Module):
    """Spike encoder that feeds ``output_size`` copies of the input to a leaky neuron.

    Args:
        output_size: number of repetitions of the input.
    """

    def __init__(self, output_size: int):
        super().__init__()
        self.out_size = output_size
        self.lif = snn.Leaky(
            beta=0.99, spike_grad=surrogate.atan(), init_hidden=True, output=False
        )

    def forward(self, inputs: torch.Tensor):
        """Map ``inputs`` (batch, L, C) to neuron outputs of shape (batch, out_size, C, L)."""
        # A repeat spec one entry longer than the input rank prepends the
        # repeated axis: (batch, L, C) -> (out_size, batch, L, C).
        tiled = inputs.repeat(self.out_size, *([1] * inputs.dim()))
        return self.lif(tiled.permute(1, 0, 3, 2))  # batch, out_size, C, L
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class ConvEncoder(nn.Module):
    """Spike encoder that expands one input plane into ``output_size`` channels.

    A (1, kernel_size) convolution followed by batch norm produces the
    channels, which are then passed through a leaky neuron.

    Args:
        output_size: number of convolution output channels.
        kernel_size: kernel width along the sequence axis.
    """

    def __init__(self, output_size: int, kernel_size: int = 3):
        super().__init__()
        conv = nn.Conv2d(
            in_channels=1,
            out_channels=output_size,
            kernel_size=(1, kernel_size),
            stride=1,
            padding=(0, kernel_size // 2),
        )
        self.encoder = nn.Sequential(conv, nn.BatchNorm2d(output_size))
        self.lif = snn.Leaky(
            beta=0.99,
            spike_grad=surrogate.atan(alpha=2.0),
            init_hidden=True,
            output=False,
        )

    def forward(self, inputs: torch.Tensor):
        """Map ``inputs`` (batch, L, C) to neuron outputs of shape (batch, output_size, C, L)."""
        plane = inputs.permute(0, 2, 1).unsqueeze(1)  # batch, 1, C, L
        return self.lif(self.encoder(plane))  # batch, output_size, C, L
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class DeltaEncoder(nn.Module):
    """Spike encoder driven by first differences along the sequence axis.

    Differences are batch-normalised, projected from one value to
    ``output_size`` values by a linear layer, and passed through a leaky neuron
    with the projected axis placed right after the batch axis.

    Args:
        output_size: width of the linear projection.
    """

    def __init__(self, output_size: int):
        super().__init__()
        self.norm = nn.BatchNorm2d(1)
        self.enc = nn.Linear(1, output_size)
        self.lif = snn.Leaky(
            beta=0.99, spike_grad=surrogate.atan(), init_hidden=True, output=False
        )

    def forward(self, inputs: torch.Tensor):
        """Map ``inputs`` (batch, L, C) to neuron outputs of shape (batch, output_size, C, L)."""
        # The first position has no predecessor, so its difference stays zero.
        step_change = torch.zeros_like(inputs)
        step_change[:, 1:] = inputs[:, 1:, :] - inputs[:, :-1, :]
        normed = self.norm(step_change.transpose(1, 2).unsqueeze(1))  # batch, 1, C, L
        projected = self.enc(normed.permute(0, 2, 3, 1))  # batch, C, L, output_size
        return self.lif(projected.permute(0, 3, 1, 2))  # batch, output_size, C, L
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` entries of the third axis.

    Args:
        chomp_size: Number of trailing entries to remove. ``0`` leaves the
            input unchanged.
    """

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return a contiguous copy of ``x`` without its trailing entries."""
        # ``x[..., :-0]`` is ``x[..., :0]`` (an empty tensor), so a zero chomp
        # must be handled separately instead of wiping out the whole axis.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, : -self.chomp_size].contiguous()
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
class Chomp2d(nn.Module):
    """Drop the last ``chomp_size`` entries of the fourth axis.

    Args:
        chomp_size: Number of trailing entries to remove. ``0`` leaves the
            input unchanged.
    """

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return a contiguous copy of ``x`` without its trailing entries."""
        # ``x[..., :-0]`` is ``x[..., :0]`` (an empty tensor), so a zero chomp
        # (e.g. padding 0 from a kernel of size 1) must be handled separately.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :, : -self.chomp_size].contiguous()
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
# Encoder registry indexed as SpikeEncoder[backend][encoder_type].
# Both backend keys resolve to the encoder classes defined in this file.
SpikeEncoder = {
    backend_name: {
        "repeat": RepeatEncoder,
        "conv": ConvEncoder,
        "delta": DeltaEncoder,
    }
    for backend_name in ("snntorch", "spikingjelly")
}
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
class SpikeTemporalBlock2D(nn.Module):
    """Residual block of two dilated 1 x k convolutions with LIF activations.

    Each convolution is followed by batch norm, a ``Chomp2d`` that trims
    ``padding`` entries from the end of the last axis, and a LIF layer that is
    called ``num_steps`` times with the same input; the outputs of those
    calls are averaged. The residual sum goes through a third LIF layer in
    the same way.

    Args:
        n_inputs: Input channels.
        n_outputs: Output channels of both convolutions.
        kernel_size: Kernel width along the last axis.
        stride: Stride handed to both convolutions.
        dilation: Dilation along the last axis.
        padding: Padding along the last axis, also the amount chomped.
        num_steps: Number of LIF calls whose outputs are averaged.
    """

    def __init__(
        self,
        n_inputs,
        n_outputs,
        kernel_size,
        stride,
        dilation,
        padding,
        num_steps=4,
    ):
        super().__init__()
        self.num_steps = num_steps

        def build_conv(in_channels):
            # Weight-normalised convolution acting only along the last axis.
            return weight_norm(
                nn.Conv2d(
                    in_channels,
                    n_outputs,
                    (1, kernel_size),
                    stride=stride,
                    padding=(0, padding),
                    dilation=(1, dilation),
                )
            )

        def build_lif():
            return snn.Leaky(
                beta=0.99,
                spike_grad=surrogate.atan(alpha=2.0),
                init_hidden=True,
                threshold=1.0,
            )

        # Modules are created in the same order as before so parameter
        # ordering and random initialisation are unaffected.
        self.conv1 = build_conv(n_inputs)
        self.bn1 = nn.BatchNorm2d(n_outputs)
        self.chomp1 = Chomp2d(padding)
        self.lif1 = build_lif()

        self.conv2 = build_conv(n_outputs)
        self.bn2 = nn.BatchNorm2d(n_outputs)
        self.chomp2 = Chomp2d(padding)
        self.lif2 = build_lif()

        # A 1x1 convolution matches channel counts on the skip path when needed.
        if n_inputs != n_outputs:
            self.downsample = nn.Conv2d(n_inputs, n_outputs, (1, 1))
        else:
            self.downsample = None
        self.lif = build_lif()

    def init_weights(self):
        """Re-draw the convolution weights from N(0, 0.01)."""
        for conv in (self.conv1, self.conv2):
            conv.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def _mean_spikes(self, lif, current):
        """Call ``lif`` on ``current`` ``num_steps`` times and average the outputs."""
        spikes = [lif(current) for _ in range(self.num_steps)]
        stacked = torch.stack(spikes, dim=-1)  # B, H, C, L, T
        return stacked.mean(-1)  # B, H, C, L

    def forward(self, x):
        """Return the averaged output of the residual LIF layer."""
        first = self._mean_spikes(self.lif1, self.chomp1(self.bn1(self.conv1(x))))
        second = self._mean_spikes(
            self.lif2, self.chomp2(self.bn2(self.conv2(first)))
        )

        if torch.isnan(second).any() or torch.isinf(second).any():
            print("illegal value in TemporalBlock2D")

        skip = x if self.downsample is None else self.downsample(x)
        return self._mean_spikes(self.lif, second + skip)
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
class SpikeTCN_CPG(nn.Module):
    """Spiking temporal convolutional forecaster with a position embedding.

    The input is spike-encoded, optionally extended by ``PositionEmbedding``,
    passed through a stack of ``SpikeTemporalBlock2D`` layers whose last layer
    has a single channel, and mapped to the forecast by two linear layers
    (one over the variable axis, one over the time axis).
    """

    def __init__(
        self,
        args,
        num_levels: int = 3,
        channel: int = 16,
        dilation: int = 2,
        stride: int = 1,
        num_steps: int = 16,
        kernel_size: int = 2,
        dropout: float = 0.2,
        max_length: int = 100,
        input_size: Optional[int] = None,
        hidden_size: int = 128,
        encoder_type: Optional[str] = "conv",
        num_pe_neuron: int = 40,
        pe_type: str = "neuron",
        pe_mode: str = "concat",  # "add" or "concat"
        neuron_pe_scale: float = 10000.0,  # "100" or "1000" or "10000"
    ):
        """
        Args:
            args: Run configuration. The attributes read here are
                ``hidden_size``, ``T``, ``feature_size``, ``pre_length``,
                ``blocks``, ``kernel_size`` and ``seq_length``; ``forward``
                additionally reads ``normalize``.
            channel: Number of channels of every temporal block but the last.
            dilation: Base of the per-level dilation (``dilation ** level``).
            stride: Stride handed to every temporal block.
            max_length: ``max_len`` handed to ``PositionEmbedding``.
            encoder_type: Key into ``SpikeEncoder`` ("repeat", "conv", "delta").
            num_pe_neuron: Number of position-embedding neurons.
            pe_type: Position-embedding type; "none" skips the embedding.
            pe_mode: "add" or "concat".
            neuron_pe_scale: Scale handed to ``PositionEmbedding``.

        ``num_levels``, ``num_steps``, ``kernel_size``, ``dropout``,
        ``input_size`` and ``hidden_size`` are accepted for signature
        compatibility but not read; the corresponding values come from
        ``args`` (the position-embedding dropout is fixed at 0.1).
        """
        super().__init__()
        self._snn_backend = "snntorch"
        self.pe_type = pe_type
        self.pe_mode = pe_mode
        self.num_pe_neuron = num_pe_neuron
        self.hidden_size = args.hidden_size
        self.num_steps = args.T
        self.input_size = args.feature_size
        self.pre_length = args.pre_length
        self.num_levels = args.blocks
        self.kernel_size = args.kernel_size

        self.encoder = SpikeEncoder[self._snn_backend][encoder_type](self.hidden_size)
        self.args = args

        self.pe = PositionEmbedding(
            pe_type=pe_type,
            pe_mode=pe_mode,
            neuron_pe_scale=neuron_pe_scale,
            input_size=self.input_size,
            max_len=max_length,
            num_pe_neuron=self.num_pe_neuron,
            dropout=0.1,
            num_steps=self.num_steps,
        )

        # ``num_levels`` blocks of ``channel`` channels, then one block that
        # collapses to a single channel (squeezed away in ``forward``).
        num_channels = [channel] * self.num_levels
        num_channels.append(1)
        layers = []
        for i in range(self.num_levels + 1):
            dilation_size = dilation**i
            in_channels = self.hidden_size if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            layers.append(
                SpikeTemporalBlock2D(
                    in_channels,
                    out_channels,
                    self.kernel_size,
                    stride=stride,
                    dilation=dilation_size,
                    # Padding of (k - 1) * dilation is chomped again by the block.
                    padding=(self.kernel_size - 1) * dilation_size,
                    num_steps=self.num_steps,
                )
            )
        self.network = nn.Sequential(*layers)

        # ``fc1`` runs over the variable axis of the network output. That axis
        # has ``feature_size`` entries, plus ``num_pe_neuron`` when the
        # position embedding is concatenated.
        if self.pe_mode == "concat" and self.pe_type in ("neuron", "random"):
            self.__output_size = args.feature_size + num_pe_neuron
        else:
            # Previously ``args.seq_length``, which only matched the variable
            # axis when seq_length happened to equal feature_size.
            self.__output_size = args.feature_size

        self.fc1 = nn.Linear(self.__output_size, args.feature_size)
        self.fc2 = nn.Linear(args.seq_length, self.pre_length)

        # Keep the historical placement on the first GPU, but do not crash on
        # hosts without CUDA; callers can still move the module themselves.
        if torch.cuda.is_available():
            self.to("cuda:0")

    def forward(self, inputs: torch.Tensor):
        """Forecast ``pre_length`` steps.

        Args:
            inputs: Tensor laid out as (B, L, C).

        Returns:
            ``(preds, aux)`` where ``preds`` is the forecast and ``aux`` is a
            dict holding a zero ``gate_l0`` placeholder.
        """
        # Clear the hidden state of every spiking layer before a new batch.
        utils.reset(self.encoder)
        for layer in self.network:
            utils.reset(layer)

        if self.args.normalize:
            # Per-sample, per-variable standardisation over the time axis.
            mean = inputs.mean(dim=1, keepdim=True).detach()  # shape [B, 1, D]
            inputs = inputs - mean

            std = torch.sqrt(
                torch.var(inputs, dim=1, keepdim=True, unbiased=False) + 1e-5
            ).detach()
            inputs = inputs / std

        inputs = self.encoder(inputs)  # B, H, C, L
        if self.pe_type != "none":
            # B, H, C, L -> H B L C' -> B H C' L
            inputs = self.pe(inputs.permute(1, 0, 3, 2)).permute(1, 0, 3, 2)
        spks = self.network(inputs)
        spks = spks.squeeze(1)  # B, C', L

        preds = self.fc1(spks.permute(0, 2, 1))  # B, L, C
        preds = self.fc2(preds.permute(0, 2, 1))  # B, C, O
        preds = preds.permute(0, 2, 1).contiguous()  # B, O, C
        if self.args.normalize:
            preds = preds * std + mean  # denormalize
        aux = {'gate_l0': torch.tensor(0.0, device=preds.device)}  # placeholder

        return preds, aux

    @property
    def output_size(self):
        """Input width of ``fc1``."""
        return self.__output_size
|
| 595 |
+
|
| 596 |
+
|
model/Spikformer_CPG.py
ADDED
|
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from spikingjelly.activation_based import surrogate, neuron, functional
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
import warnings
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Time constant passed as ``tau=`` to the LIF nodes built in this file
# (beta = 1 - 1/tau).
tau: float = 2.0
# Value passed as ``backend=`` to the LIF nodes that accept it.
backend: str = "torch"
# Value passed as ``detach_reset=`` to every LIF node in this file.
detach_reset: bool = True
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# eq=False keeps nn.Module's identity-based __eq__/__hash__. The dataclass
# default (eq=True) sets __hash__ to None, which breaks module traversal
# (parameters(), state_dict(), .to(), registration as a submodule).
@dataclass(eq=False)
class CPG(nn.Module):
    """Precomputed binary position code built from thresholded sinusoids.

    Row ``p`` of the table holds, for position ``p``, a 0/1 value per neuron:
    even columns are 1 where ``sin(p * w) >= 0.8`` and odd columns are 1
    where ``cos(p * w) >= 0.8``, with frequencies ``w`` spaced geometrically
    by ``w_max``.
    """

    # Number of columns of the table.
    num_neurons: int = 40
    # Base of the geometric frequency spacing.
    w_max: float = 10000.0
    # Number of positions (rows) precomputed.
    l_max: int = 5000

    def __post_init__(self):
        # The dataclass-generated __init__ does not chain to nn.Module, so the
        # module machinery has to be initialised explicitly here.
        super().__init__()
        self._cpg = torch.zeros(self.l_max, self.num_neurons)
        position = torch.arange(0, self.l_max, dtype=torch.float).unsqueeze(
            1
        )  # MaxL, 1
        scale = -math.log(self.w_max) / self.num_neurons
        # Frequencies for the even (sine) columns.
        div_term = torch.exp(torch.arange(0, self.num_neurons, 2).float() * scale)
        # Frequencies for the odd (cosine) columns; one fewer entry when
        # num_neurons is odd.
        div_term_single = torch.exp(
            torch.arange(0, self.num_neurons - 1, 2).float() * scale
        )
        # heaviside(v - 0.8, 1.0) is 1 where v >= 0.8 and 0 elsewhere.
        self._cpg[:, 0::2] = torch.heaviside(
            torch.sin(position * div_term) - 0.8, torch.tensor([1.0])
        )
        self._cpg[:, 1::2] = torch.heaviside(
            torch.cos(position * div_term_single) - 0.8, torch.tensor([1.0])
        )

    @property
    def cpg(self):
        """The (l_max, num_neurons) table of 0/1 values."""
        return self._cpg
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class CPGLinear(nn.Module):
    """Linear layer whose output is offset by a projection of the CPG table.

    Args:
        input_size: Width of the incoming features.
        output_size: Width of the result.
        cpg: Source of the position table; its ``cpg`` tensor is stored as a
            non-trainable parameter.
        dropout: Dropout probability applied to the input.
    """

    def __init__(
        self, input_size: int, output_size: int, cpg: CPG = CPG(), dropout: float = 0.1
    ):
        super().__init__()
        # Frozen copy of the table so it follows the module across devices.
        self.cpg = nn.Parameter(cpg.cpg, requires_grad=False)
        self.inp_linear = nn.Linear(input_size, output_size)
        self.cpg_linear = nn.Linear(cpg.num_neurons, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        """Return ``inp_linear(dropout(x)) + cpg_linear(table rows)``."""
        # B TL D
        # One table row per entry of the second-to-last axis of ``x``.
        rows = x.size(-2)
        projected = self.inp_linear(self.dropout(x))
        positional = self.cpg_linear(self.cpg[:rows])
        return projected + positional
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class RepeatEncoder(nn.Module):
    """Encode a series by tiling it ``output_size`` times and feeding it to a LIF node.

    Args:
        output_size: Number of copies stacked along the new leading axis.
    """

    def __init__(self, output_size: int):
        super().__init__()
        self.out_size = output_size
        self.lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
        )

    def forward(self, inputs: torch.Tensor):
        """Return the LIF output for the tiled input.

        Args:
            inputs: Tensor laid out as (B, L, C).
        """
        # One repeat factor for the new leading axis, then 1 for each
        # existing axis.
        factors = (self.out_size, *([1] * inputs.dim()))
        tiled = inputs.repeat(factors)  # T B L C
        tiled = tiled.transpose(2, 3)  # T B C L
        return self.lif(tiled)  # T B C L
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class DeltaEncoder(nn.Module):
    """Encode a series from its first-order differences along the time axis.

    The differences are batch-normalised, expanded from one to ``output_size``
    features by a linear layer, moved to the leading axis and passed through
    a LIF node.

    Args:
        output_size: Size of the leading axis of the result.
    """

    def __init__(self, output_size: int):
        super().__init__()
        self.norm = nn.BatchNorm2d(1)
        self.enc = nn.Linear(1, output_size)
        self.lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
        )

    def forward(self, inputs: torch.Tensor):
        """Return the LIF output for the encoded differences.

        Args:
            inputs: Tensor laid out as (B, L, C).
        """
        # The first step has no predecessor, so its difference stays zero.
        diffs = torch.zeros_like(inputs)
        diffs[:, 1:] = inputs[:, 1:, :] - inputs[:, :-1, :]

        diffs = diffs.transpose(1, 2).unsqueeze(1)  # B, 1, C, L
        diffs = self.norm(diffs)
        diffs = diffs.permute(0, 2, 3, 1)  # B, C, L, 1
        encoded = self.enc(diffs)  # B, C, L, T
        encoded = encoded.permute(3, 0, 1, 2)  # T, B, C, L
        return self.lif(encoded)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class ConvEncoder(nn.Module):
    """Encode a series with a 1 x ``kernel_size`` convolution, batch norm and a LIF node.

    The convolution's output channels become the leading axis of the result.

    Args:
        output_size: Number of output channels of the convolution.
        kernel_size: Width of the kernel along the last (time) axis.
    """

    def __init__(self, output_size: int, kernel_size: int = 3):
        super().__init__()
        # The kernel only spans the time axis; each variable is convolved on its own.
        temporal_conv = nn.Conv2d(
            in_channels=1,
            out_channels=output_size,
            kernel_size=(1, kernel_size),
            stride=1,
            padding=(0, kernel_size // 2),
        )
        self.encoder = nn.Sequential(temporal_conv, nn.BatchNorm2d(output_size))
        self.lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
        )

    def forward(self, inputs: torch.Tensor):
        """Return the LIF output for the convolved input.

        Args:
            inputs: Tensor laid out as (B, L, C).
        """
        image = inputs.transpose(1, 2).unsqueeze(1)  # B, 1, C, L
        features = self.encoder(image)  # B, T, C, L
        features = features.transpose(0, 1)  # T, B, C, L
        return self.lif(features)  # T, B, C, L
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# Encoder registry indexed as SpikeEncoder[backend][encoder_type].
# Both backend keys resolve to the encoder classes defined in this file,
# which are built on spikingjelly's LIFNode.
SpikeEncoder = {
    backend_name: {
        "repeat": RepeatEncoder,
        "conv": ConvEncoder,
        "delta": DeltaEncoder,
    }
    for backend_name in ("snntorch", "spikingjelly")
}
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class SSA(nn.Module):
    """Multi-head self-attention over spike tensors, without softmax.

    Queries, keys and values are each produced by linear -> batch norm -> LIF;
    the scaled ``q @ k^T @ v`` product goes through a LIF node and a final
    linear -> batch norm -> LIF stage.

    Args:
        length: Accepted but not read by this class.
        tau: Time constant of every LIF node.
        common_thr: Firing threshold; the attention LIF uses half of it.
        dim: Feature width; must be divisible by ``heads``.
        heads: Number of attention heads.
        qkv_bias: Accepted but not read by this class.
        qk_scale: Factor applied to ``q @ k^T``.
    """

    def __init__(
        self, length, tau, common_thr, dim, heads=8, qkv_bias=False, qk_scale=0.25
    ):
        super().__init__()
        assert dim % heads == 0, f"dim {dim} should be divided by num_heads {heads}."

        self.dim = dim
        self.heads = heads
        self.qk_scale = qk_scale

        def build_lif(threshold):
            return neuron.LIFNode(
                tau=tau,
                step_mode="m",
                detach_reset=detach_reset,
                surrogate_function=surrogate.ATan(),
                v_threshold=threshold,
                backend=backend,
            )

        # Creation order is kept so parameter ordering and random
        # initialisation match the previous layout.
        self.q_m = nn.Linear(dim, dim)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = build_lif(common_thr)

        self.k_m = nn.Linear(dim, dim)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = build_lif(common_thr)

        self.v_m = nn.Linear(dim, dim)
        self.v_bn = nn.BatchNorm1d(dim)
        self.v_lif = build_lif(common_thr)

        self.attn_lif = build_lif(common_thr / 2)

        self.last_m = nn.Linear(dim, dim)
        self.last_bn = nn.BatchNorm1d(dim)
        self.last_lif = build_lif(common_thr)

    def _spiking_heads(self, tokens, linear, norm, lif, shape):
        """Apply linear -> batch norm -> LIF and split the result into heads."""
        T, B, L, D = shape
        projected = linear(tokens)  # TB L D
        # BatchNorm1d normalises axis 1, so the feature axis is moved there.
        projected = norm(projected.transpose(-1, -2)).transpose(-1, -2)
        spikes = lif(projected.reshape(T, B, L, D).contiguous())
        per_head = spikes.reshape(T, B, L, self.heads, D // self.heads)
        return per_head.permute(0, 1, 3, 2, 4).contiguous()  # T B heads L D//heads

    def forward(self, x):
        """Return the attention output with the same (T, B, L, D) layout as ``x``."""
        T, B, L, D = x.shape
        shape = (T, B, L, D)
        tokens = x.flatten(0, 1)  # TB L D

        q = self._spiking_heads(tokens, self.q_m, self.q_bn, self.q_lif, shape)
        k = self._spiking_heads(tokens, self.k_m, self.k_bn, self.k_lif, shape)
        v = self._spiking_heads(tokens, self.v_m, self.v_bn, self.v_lif, shape)

        scores = (q @ k.transpose(-2, -1)) * self.qk_scale
        mixed = scores @ v  # T B heads L D//heads

        # Merge the heads back into the feature axis.
        mixed = mixed.transpose(2, 3).reshape(T, B, L, D).contiguous()
        mixed = self.attn_lif(mixed)

        out = self.last_m(mixed.flatten(0, 1))
        out = self.last_bn(out.transpose(-1, -2)).transpose(-1, -2)
        return self.last_lif(out.reshape(T, B, L, D).contiguous())
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class MLP(nn.Module):
    """Two-layer spiking feed-forward block built from ``CPGLinear`` layers.

    Each layer is ``CPGLinear`` -> batch norm -> LIF.

    Args:
        length: Accepted but not read by this class.
        tau: Time constant of both LIF nodes.
        common_thr: Firing threshold of both LIF nodes.
        in_features: Width of the input.
        hidden_features: Width of the hidden layer; defaults to ``in_features``.
        out_features: Width of the output; defaults to ``in_features``.
    """

    def __init__(
        self,
        length,
        tau,
        common_thr,
        in_features,
        hidden_features=None,
        out_features=None,
    ):
        super().__init__()
        out_features = out_features or in_features
        # The declared default of None used to reach CPGLinear/BatchNorm1d
        # unchanged and fail there; fall back the same way out_features does.
        hidden_features = hidden_features or in_features
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features

        self.fc1 = CPGLinear(in_features, hidden_features)
        self.bn1 = nn.BatchNorm1d(hidden_features)
        self.lif1 = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
            v_threshold=common_thr,
            backend=backend,
        )

        self.fc2 = CPGLinear(hidden_features, out_features)
        self.bn2 = nn.BatchNorm1d(out_features)
        self.lif2 = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
            v_threshold=common_thr,
            backend=backend,
        )

    def forward(self, x):
        """Map a (T, B, L, in_features) tensor to (T, B, L, out_features)."""
        T, B, L, _ = x.shape
        x = x.transpose(0, 1).flatten(1, 2)  # B TL D
        x = self.fc1(x)  # B TL H
        x = (
            self.bn1(x.transpose(-1, -2))
            .transpose(-1, -2)
            .reshape(B, T, L, self.hidden_features)
            .contiguous()
        )  # B T L H
        # The LIF node is called with the T axis leading, so swap in and out.
        x = self.lif1(x.transpose(0, 1)).transpose(0, 1)  # B T L H
        x = x.flatten(1, 2)  # B TL H
        x = self.fc2(x)  # B TL O
        x = (
            self.bn2(x.transpose(-1, -2))
            .transpose(-1, -2)
            # fc2 produces out_features columns, so reshape with that width
            # rather than the input width.
            .reshape(B, T, L, self.out_features)
            .contiguous()
        )  # B T L O
        x = self.lif2(x.transpose(0, 1))  # T B L O
        return x
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
class Block(nn.Module):
    """Residual pair of spiking self-attention (``SSA``) and spiking ``MLP``.

    Args:
        length: Forwarded to ``SSA`` and ``MLP``.
        tau: LIF time constant forwarded to both sub-modules.
        common_thr: LIF threshold forwarded to both sub-modules.
        dim: Feature width.
        d_ff: Hidden width of the MLP.
        heads: Number of attention heads.
        qkv_bias: Forwarded to ``SSA``.
        qk_scale: Forwarded to ``SSA``.
    """

    def __init__(
        self,
        length,
        tau,
        common_thr,
        dim,
        d_ff,
        heads=8,
        qkv_bias=False,
        qk_scale=0.125,
    ):
        super().__init__()
        # Arguments both sub-modules have in common.
        shared = dict(length=length, tau=tau, common_thr=common_thr)
        self.attn = SSA(
            dim=dim,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            **shared,
        )
        self.mlp = MLP(in_features=dim, hidden_features=d_ff, **shared)

    def forward(self, x):
        """Apply attention then the MLP, each with a skip connection."""
        # T B L D
        attended = x + self.attn(x)
        return attended + self.mlp(attended)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
class Spikformer_CPG(nn.Module):
    """Spiking transformer forecaster with CPG position information.

    The input is spike-encoded by a convolutional encoder, projected to
    ``dim`` features by ``CPGLinear``, passed through ``depths`` residual
    ``Block`` layers, averaged over the leading axis and mapped to the
    forecast by one linear layer.

    Args:
        args: Run configuration. The attributes read here are ``T``,
            ``blocks``, ``feature_size``, ``pre_length`` and ``seq_length``;
            ``forward`` additionally reads ``normalize``.
        dim: Feature width of the transformer.
        d_ff: Hidden width of each block's MLP; 1024 when omitted.
        num_pe_neuron: Number of CPG neurons of the input projection.
        pe_type: Stored on the instance; not otherwise read here.
        pe_mode: Stored on the instance; not otherwise read here.
        neuron_pe_scale: Accepted but not read by this class.
        depths: Number of ``Block`` layers.
        common_thr: LIF firing threshold.
        max_length: ``length`` handed to every ``Block``.
        num_steps: ``output_size`` of the temporal encoder.
        heads: Number of attention heads.
        qkv_bias: Forwarded to every ``Block``.
        qk_scale: Forwarded to every ``Block``.
        input_size: Accepted but not read; ``args.feature_size`` is used.
        weight_file: Accepted but not read by this class.
    """

    def __init__(
        self,
        args,
        dim: int = 256,
        d_ff: Optional[int] = None,
        num_pe_neuron: int = 40,
        pe_type: str = "neuron",
        pe_mode: str = "concat",  # "add" or concat
        neuron_pe_scale: float = 10000.0,  # "100" or "1000" or "10000"
        depths: int = 2,
        common_thr: float = 1.0,
        max_length: int = 5000,
        num_steps: int = 4,
        heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: float = 0.125,
        input_size: Optional[int] = None,
        weight_file: Optional[Path] = None,
    ):
        super().__init__()
        # These used to be hard-coded to 256 / 1024 regardless of the
        # arguments, while the layers below were built from ``dim``.
        self.dim = dim
        self.d_ff = 1024 if d_ff is None else d_ff
        # NOTE(review): T and depths are recorded from ``args`` but the layers
        # below are sized by ``num_steps`` and ``depths``; confirm which
        # source is intended before relying on these attributes.
        self.T = args.T
        self.depths = args.blocks
        self.pe_type = pe_type
        self.pe_mode = pe_mode
        self.num_pe_neuron = num_pe_neuron
        self.input_size = args.feature_size
        self.pre_length = args.pre_length
        self.args = args

        self._snn_backend = "spikingjelly"

        self.temporal_encoder = SpikeEncoder[self._snn_backend]["conv"](num_steps)
        self.encoder = CPGLinear(self.input_size, dim, CPG(num_neurons=num_pe_neuron))

        self.init_lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
            v_threshold=common_thr,
            backend=backend,
        )

        self.blocks = nn.ModuleList(
            [
                Block(
                    length=max_length,
                    tau=tau,
                    common_thr=common_thr,
                    dim=dim,
                    d_ff=self.d_ff,
                    heads=heads,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                )
                for _ in range(depths)
            ]
        )

        # Applied before ``fc`` exists, so the output head keeps PyTorch's
        # default initialisation.
        self.apply(self._init_weights)

        self.fc = nn.Linear(args.seq_length * dim, args.pre_length * args.feature_size)

    def _init_weights(self, m):
        """Initialise linear layers with N(0, 0.02) weights and zero bias."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)

    def forward(self, x: torch.Tensor):
        """Forecast ``pre_length`` steps.

        Args:
            x: Tensor laid out as (B, L, C).

        Returns:
            ``(out, aux)`` where ``out`` has shape (B, pre_length,
            feature_size) and ``aux`` is a dict holding a zero ``gate_l0``
            placeholder.
        """
        # Clear the state of every spiking neuron before a new batch.
        functional.reset_net(self)

        if self.args.normalize:
            # Per-sample, per-variable standardisation over the time axis.
            mean = x.mean(dim=1, keepdim=True).detach()  # shape [B, 1, D]
            x = x - mean

            std = torch.sqrt(
                torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5
            ).detach()
            x = x / std

        x = self.temporal_encoder(x)  # B L C -> T B C L
        T, B, _, L = x.shape
        x = x.permute(1, 0, 3, 2)  # B T L C
        x = x.flatten(1, 2)  # B TL C
        x = self.encoder(x)  # B TL D
        x = x.reshape(B, T, L, -1).permute(1, 0, 2, 3)  # T B L D
        x = self.init_lif(x)

        for blk in self.blocks:
            x = blk(x)  # T B L D
        out = x.mean(0)  # B L D
        # Flatten (L, D) per sample, project, then unflatten to the forecast.
        out = self.fc(out.flatten(-2, -1)).reshape(
            -1, self.pre_length, self.input_size
        )  # B, pre_length, feature_size
        if self.args.normalize:
            out = out * std + mean  # denormalization
        aux = {'gate_l0': torch.tensor(0.0, device=out.device)}  # placeholder
        return out, aux
|
| 487 |
+
|
model/TS_Former.py
ADDED
|
@@ -0,0 +1,1365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Callable
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from spikingjelly.activation_based import surrogate, neuron, functional
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
import copy
|
| 10 |
+
from spikingjelly.activation_based import surrogate, neuron
|
| 11 |
+
from abc import abstractmethod
|
| 12 |
+
import snntorch as snn
|
| 13 |
+
from snntorch import utils
|
| 14 |
+
import warnings
|
| 15 |
+
|
| 16 |
+
# Monkey-patch the ``ATan`` factory on the imported ``surrogate`` module: the
# ``alpha`` argument is accepted but ignored, and the surrogate returned is always
# ``SG.apply``. ``SG`` is defined further down in this file; it is only looked up
# when the lambda is called, so the forward reference is safe.
surrogate.ATan = lambda alpha=2.0: SG.apply
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def generate_ones_and_minus_ones_matrix(rows, cols):
    """Return a random ``(rows, cols)`` float tensor whose entries are -1 or +1.

    Each entry is drawn independently with ``torch.randint``: a random bit is
    mapped 0 -> -1 and 1 -> +1.

    Args:
        rows: Number of rows of the result.
        cols: Number of columns of the result.

    Returns:
        A ``torch.float32`` tensor of shape ``(rows, cols)``.
    """
    random_bits = torch.randint(0, 2, (rows, cols))
    # Affine map {0, 1} -> {-1, +1}; same values as selecting with torch.where.
    return (2 * random_bits - 1).float()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class RandomPE(nn.Module):
    """Positional encoding built from a fixed random table of -1/+1 codes.

    A ``(max_len, 1, num_pe_neuron)`` table is drawn once at construction and
    stored as the buffer ``pe``. Position ``t * L + l`` of the table is used for
    time step ``t`` and sequence index ``l``.

    Args:
        d_model: Feature size of the input; also the code width in ``"add"`` mode.
        pe_mode: ``"concat"`` appends the code to the features, ``"add"`` sums it
            onto them.
        num_pe_neuron: Code width in ``"concat"`` mode.
        neuron_pe_scale: Stored on the module; not used by this class.
        dropout: Dropout probability applied to the output.
        num_steps: Accepted for signature parity with the other encodings; unused.

    Raises:
        ValueError: If ``pe_mode`` is neither ``"concat"`` nor ``"add"``.
    """

    def __init__(
        self,
        d_model,
        pe_mode="concat",
        num_pe_neuron=10,
        neuron_pe_scale=1000.0,
        dropout=0.1,
        num_steps=4,
    ):
        super().__init__()
        self.max_len = 5000  # different from windows
        self.pe_mode = pe_mode
        self.neuron_pe_scale = neuron_pe_scale
        self.dropout = nn.Dropout(p=dropout)
        if pe_mode == "concat":
            self.num_pe_neuron = num_pe_neuron
        elif pe_mode == "add":
            self.num_pe_neuron = d_model
        else:
            # Previously this fell through and failed later with an AttributeError.
            raise ValueError("Unknown pe_mode: {}".format(pe_mode))
        pe = generate_ones_and_minus_ones_matrix(
            self.max_len, self.num_pe_neuron
        )  # MaxL, Neur
        self.register_buffer("pe", pe.unsqueeze(1))  # MaxL, 1, Neur

    def forward(self, x):
        """Attach the positional code to ``x``.

        Args:
            x: Tensor unpacked as ``(T, B, L, D)``.

        Returns:
            ``(T, B, L, D + num_pe_neuron)`` in ``"concat"`` mode or
            ``(T, B, L, D)`` in ``"add"`` mode, after dropout.

        Raises:
            ValueError: If ``T * L`` exceeds the size of the code table.
        """
        T, B, L, _ = x.shape
        num_pos = T * L
        if num_pos > self.max_len:
            raise ValueError(
                "T * L = {} exceeds the positional table size {}".format(
                    num_pos, self.max_len
                )
            )
        # Row t * L + l of the table belongs to (t, l); broadcast it over the batch.
        code = self.pe[:num_pos, 0].reshape(T, 1, L, -1)  # T, 1, L, Neur
        if self.pe_mode == "concat":
            x = torch.cat([x, code.expand(-1, B, -1, -1)], dim=-1)  # T, B, L, D'
        else:  # "add"
            x = x + code  # T, B, L, D
        return self.dropout(x)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class NeuronPE(nn.Module):
    """Binary (spike-like) sinusoidal positional encoding.

    A sine/cosine table is thresholded at 0.8 with ``torch.heaviside`` so every
    entry is 0 or 1. The table has shape ``(max_len, 1, num_pe_neuron)`` and is
    stored as the buffer ``pe``. Position ``t * L + l`` of the table is used for
    time step ``t`` and sequence index ``l``.

    Args:
        d_model: Feature size of the input; also the code width in ``"add"`` mode.
        pe_mode: ``"concat"`` appends the code to the features, ``"add"`` sums it
            onto them.
        num_pe_neuron: Code width in ``"concat"`` mode.
        neuron_pe_scale: Base of the geometric frequency progression.
        dropout: Dropout probability applied to the output.
        num_steps: Accepted for signature parity with the other encodings; unused.

    Raises:
        ValueError: If ``pe_mode`` is neither ``"concat"`` nor ``"add"``.
    """

    def __init__(
        self,
        d_model,
        pe_mode="concat",
        num_pe_neuron=10,
        neuron_pe_scale=10000.0,
        dropout=0.1,
        num_steps=4,
    ):
        super().__init__()
        self.max_len = 50000  # different from windows
        self.pe_mode = pe_mode
        self.neuron_pe_scale = neuron_pe_scale
        self.dropout = nn.Dropout(p=dropout)
        if pe_mode == "concat":
            self.num_pe_neuron = num_pe_neuron
        elif pe_mode == "add":
            self.num_pe_neuron = d_model
        else:
            # Previously this fell through and failed later with an AttributeError.
            raise ValueError("Unknown pe_mode: {}".format(pe_mode))

        pe = torch.zeros(self.max_len, self.num_pe_neuron)  # MaxL, Neur
        position = torch.arange(0, self.max_len, dtype=torch.float).unsqueeze(1)  # MaxL, 1
        decay = -math.log(neuron_pe_scale) / self.num_pe_neuron
        div_term = torch.exp(torch.arange(0, self.num_pe_neuron, 2).float() * decay)
        # One fewer frequency when the width is odd, so the odd columns line up.
        div_term_single = torch.exp(
            torch.arange(0, self.num_pe_neuron - 1, 2).float() * decay
        )
        at_zero = torch.tensor([1.0])  # heaviside value used when its input is exactly 0
        pe[:, 0::2] = torch.heaviside(torch.sin(position * div_term) - 0.8, at_zero)
        pe[:, 1::2] = torch.heaviside(
            torch.cos(position * div_term_single) - 0.8, at_zero
        )
        self.register_buffer("pe", pe.unsqueeze(1))  # MaxL, 1, Neur

    def forward(self, x):
        """Attach the positional code to ``x``.

        Args:
            x: Tensor unpacked as ``(T, B, L, D)``.

        Returns:
            ``(T, B, L, D + num_pe_neuron)`` in ``"concat"`` mode or
            ``(T, B, L, D)`` in ``"add"`` mode, after dropout.

        Raises:
            ValueError: If ``T * L`` exceeds the size of the code table.
        """
        T, B, L, _ = x.shape
        num_pos = T * L
        if num_pos > self.max_len:
            raise ValueError(
                "T * L = {} exceeds the positional table size {}".format(
                    num_pos, self.max_len
                )
            )
        # Row t * L + l of the table belongs to (t, l); broadcast it over the batch.
        code = self.pe[:num_pos, 0].reshape(T, 1, L, -1)  # T, 1, L, Neur
        if self.pe_mode == "concat":
            x = torch.cat([x, code.expand(-1, B, -1, -1)], dim=-1)  # T, B, L, D'
        else:  # "add"
            x = x + code  # T, B, L, D
        return self.dropout(x)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class StaticPE(nn.Module):
    r"""Fixed sinusoidal positional encoding that is added to the input.

    The encoding has the same width as the embeddings so the two can be summed.
    Sine and cosine functions of different frequencies are used:

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d\_model})

        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d\_model})

    where ``pos`` is the position and ``i`` is the embedding index.

    Args:
        d_model: Width of the embeddings (odd widths are supported).
        dropout: Dropout probability applied after the addition.
        max_len: Number of positions in the precomputed table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)  # MaxL, D
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # MaxL, 1
        decay = -math.log(10000.0) / d_model
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * decay)
        # One fewer frequency when d_model is odd, so the cosine columns line up.
        div_term_single = torch.exp(torch.arange(0, d_model - 1, 2).float() * decay)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term_single)
        self.register_buffer("pe", pe.unsqueeze(1))  # MaxL, 1, D

    def forward(self, x):
        """Add the encoding of the first ``x.size(0)`` positions to ``x``.

        Args:
            x: Tensor laid out as ``(L, TB, D)`` with positions on the first axis.

        Returns:
            Tensor of the same shape as ``x``, after dropout.

        Raises:
            ValueError: If ``x.size(0)`` exceeds the table length.
        """
        length = x.size(0)
        if length > self.pe.size(0):
            raise ValueError(
                "sequence length {} exceeds max_len {}".format(length, self.pe.size(0))
            )
        x = x + self.pe[:length, :]
        return self.dropout(x)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class ConvPE(nn.Module):
    """Convolutional positional encoding: conv -> batch norm -> LIF, added residually.

    Args:
        d_model: Number of feature channels.
        dropout: Dropout probability applied to the spiking branch.
        max_len: Accepted for signature parity with the other encodings; unused.
        num_steps: Number of time steps ``T`` folded into the ``TB`` axis of the input.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, num_steps=4):
        super().__init__()
        self.T = num_steps
        self.rpe_conv = nn.Conv1d(
            d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.rpe_bn = nn.BatchNorm1d(d_model)
        self.rpe_lif = neuron.LIFNode(
            step_mode="m",
            detach_reset=True,
            surrogate_function=surrogate.ATan(),
            v_threshold=1.0,
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Add the convolutional spike features to ``x``.

        Args:
            x: Tensor laid out as ``(L, TB, D)`` where ``TB`` is a multiple of
                ``num_steps``.

        Returns:
            Tensor of shape ``(L, TB, D)``.

        Raises:
            ValueError: If ``TB`` is not divisible by ``num_steps``.
        """
        L, TB, D = x.shape
        if TB % self.T != 0:
            raise ValueError(
                "second dimension {} is not divisible by num_steps {}".format(
                    TB, self.T
                )
            )
        x_feat = x.permute(1, 2, 0)  # TB, D, L
        x_feat = self.rpe_conv(x_feat)  # TB, D, L
        # Unfold the time axis so the multi-step LIF node sees T as the leading dim.
        x_feat = (
            self.rpe_bn(x_feat).reshape(self.T, TB // self.T, D, L).contiguous()
        )  # T, B, D, L
        x_feat = self.rpe_lif(x_feat)
        x_feat = x_feat.flatten(0, 1)  # TB, D, L
        x_feat = self.dropout(x_feat)  # TB, D, L
        x_feat = x_feat.permute(2, 0, 1)  # L, TB, D
        return x + x_feat
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class PositionEmbedding(nn.Module):
    """Dispatch wrapper that selects one positional-encoding implementation.

    Args:
        input_size: Feature width of the input.
        pe_type: One of ``"learn"``, ``"none"``, ``"conv"``, ``"static"``,
            ``"neuron"`` or ``"random"``.
        max_len: Table size for the ``"learn"``/``"none"``, ``"conv"`` and
            ``"static"`` variants.
        pe_mode: ``"add"`` or ``"concat"``; forwarded to the ``"neuron"`` and
            ``"random"`` variants.
        num_pe_neuron: Forwarded to the ``"neuron"`` and ``"random"`` variants.
        neuron_pe_scale: Forwarded to the ``"neuron"`` and ``"random"`` variants.
        dropout: Dropout probability forwarded to the selected variant.
        num_steps: Number of time steps forwarded to the selected variant.

    Raises:
        ValueError: If ``pe_type`` is not one of the supported names.
    """

    def __init__(
        self,
        input_size: int,
        pe_type: str,
        max_len: int = 5000,
        pe_mode: str = "add",
        num_pe_neuron: int = 10,
        neuron_pe_scale: float = 1000.0,
        dropout=0.1,
        num_steps=4,
    ):
        super().__init__()
        self.emb_type = pe_type
        if pe_type in ["learn", "none"]:
            # "none" still allocates the table (kept so existing checkpoints load),
            # but forward() never reads it.
            self.emb = nn.Embedding(max_len, input_size)
        elif pe_type == "conv":
            self.emb = ConvPE(
                d_model=input_size,
                max_len=max_len,
                dropout=dropout,
                num_steps=num_steps,
            )
        elif pe_type == "static":
            self.emb = StaticPE(d_model=input_size, max_len=max_len, dropout=dropout)
        elif pe_type == "neuron":
            self.emb = NeuronPE(
                d_model=input_size,
                pe_mode=pe_mode,
                num_pe_neuron=num_pe_neuron,
                neuron_pe_scale=neuron_pe_scale,
                dropout=dropout,
                num_steps=num_steps,
            )
        elif pe_type == "random":
            self.emb = RandomPE(
                d_model=input_size,
                pe_mode=pe_mode,
                num_pe_neuron=num_pe_neuron,
                neuron_pe_scale=neuron_pe_scale,
                dropout=dropout,
                num_steps=num_steps,
            )
        else:
            raise ValueError("Unknown embedding type: {}".format(pe_type))

    def forward(self, x):
        """Apply the selected positional encoding.

        Args:
            x: For ``"learn"`` the position axis is taken to be dimension 1
                (the inline notes describe ``(TB, L, D)`` -- TODO confirm with
                callers). For every other type ``x`` is unpacked as
                ``(T, B, L, D)``.

        Returns:
            The encoded tensor; ``x`` itself, unchanged, when the type is ``"none"``.
        """
        if self.emb_type == "learn":
            positions = torch.arange(
                end=x.size()[1], device=x.device
            )  # [0, 1, ..., L-1]
            embedding = self.emb(positions)  # L, D
            embedding = embedding.repeat([x.size()[0], 1, 1])  # TB, L, D
            x = x + embedding
        elif self.emb_type in ["static", "conv"]:
            T, B, L, _ = x.shape  # x: T, B, L, D
            x = x.flatten(0, 1)  # TB, L, D
            # These variants expect positions on the leading axis.
            x = self.emb(x.transpose(0, 1)).transpose(0, 1)  # TB, L, D'
            x = x.reshape(T, B, L, -1)
        elif self.emb_type in ["neuron", "random"]:
            T, B, L, _ = x.shape  # x: T, B, L, D
            x = self.emb(x)
            x = x.reshape(T, B, L, -1)
        return x  # T, B, L, D'
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
# Module-level defaults read by the spike encoder classes in this file.
tau = 2.0  # LIF time constant; the equivalent decay factor is beta = 1 - 1/tau
backend = "torch"
detach_reset = True
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class RepeatEncoder(nn.Module):
    """Spike encoder that repeats the input along a new leading time axis.

    Args:
        output_size: Number of repetitions, i.e. the length ``T`` of the time axis.
    """

    def __init__(self, output_size: int):
        super().__init__()
        self.out_size = output_size
        self.lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
        )

    def forward(self, inputs: torch.Tensor):
        """Encode ``inputs`` into spikes.

        Args:
            inputs: Tensor laid out as ``(B, L, C)``.

        Returns:
            Output of the LIF node for the ``(T, B, C, L)`` repeated input.
        """
        # One extra leading repeat dimension; every existing axis is kept once.
        repeats = (self.out_size,) + (1,) * inputs.dim()
        inputs = inputs.repeat(repeats)  # T, B, L, C
        inputs = inputs.permute(0, 1, 3, 2)  # T, B, C, L
        return self.lif(inputs)  # T, B, C, L
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class DeltaEncoder(nn.Module):
    """Spike encoder driven by first-order differences along the sequence axis.

    The differences are batch-normalised, expanded to ``output_size`` channels
    by a linear layer, and fed to a LIF node with those channels as time steps.

    Args:
        output_size: Number of time steps ``T`` produced by the linear layer.
    """

    def __init__(self, output_size: int):
        super().__init__()
        self.norm = nn.BatchNorm2d(1)
        self.enc = nn.Linear(1, output_size)
        self.lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
        )

    def forward(self, inputs: torch.Tensor):
        """Encode ``inputs`` into spikes.

        Args:
            inputs: Tensor laid out as ``(B, L, C)``.

        Returns:
            Output of the LIF node for the ``(T, B, C, L)`` encoded differences.
        """
        # The first position has no predecessor, so its delta stays zero.
        delta = torch.zeros_like(inputs)
        delta[:, 1:] = inputs[:, 1:, :] - inputs[:, :-1, :]
        delta = delta.unsqueeze(1).permute(0, 1, 3, 2)  # B, 1, C, L
        delta = self.norm(delta)
        delta = delta.permute(0, 2, 3, 1)  # B, C, L, 1
        enc = self.enc(delta)  # B, C, L, T
        enc = enc.permute(3, 0, 1, 2)  # T, B, C, L
        return self.lif(enc)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
class ConvEncoder(nn.Module):
    """Spike encoder that uses a 1 x k convolution to create the time axis.

    Args:
        output_size: Number of output channels, used as the time steps ``T``.
        kernel_size: Width of the kernel along the sequence axis. The padding is
            ``kernel_size // 2``, which keeps the length only for odd sizes.
    """

    def __init__(self, output_size: int, kernel_size: int = 3):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=output_size,
                kernel_size=(1, kernel_size),
                stride=1,
                padding=(0, kernel_size // 2),
            ),
            nn.BatchNorm2d(output_size),
        )
        self.lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
        )

    def forward(self, inputs: torch.Tensor):
        """Encode ``inputs`` into spikes.

        Args:
            inputs: Tensor laid out as ``(B, L, C)``.

        Returns:
            Output of the LIF node for the ``(T, B, C, L)`` convolved input.
        """
        inputs = inputs.permute(0, 2, 1).unsqueeze(1)  # B, 1, C, L
        enc = self.encoder(inputs)  # B, T, C, L
        enc = enc.permute(1, 0, 2, 3)  # T, B, C, L
        return self.lif(enc)  # T, B, C, L
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
# Registry of spike encoders, keyed first by backend name and then by encoder name.
# Both backend keys currently map to the same three classes.
SpikeEncoder = {
    "snntorch": {
        "repeat": RepeatEncoder,
        "conv": ConvEncoder,
        "delta": DeltaEncoder,
    },
    "spikingjelly": {
        "repeat": RepeatEncoder,
        "conv": ConvEncoder,
        "delta": DeltaEncoder,
    },
}
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
class SSA(nn.Module):
    """Spiking self-attention: spike-valued Q/K/V with a scaled ``Q K^T V`` product.

    No softmax is applied; the product is scaled by ``qk_scale`` and passed
    through a spiking node before the output projection.

    Args:
        length: Accepted for interface compatibility; unused.
        tau: Accepted for interface compatibility; unused.
        common_thr: Accepted for interface compatibility; unused.
        dim: Feature width ``D``; must be divisible by ``heads``.
        heads: Number of attention heads.
        qkv_bias: Accepted for interface compatibility; unused.
        qk_scale: Factor applied to ``Q K^T``.
    """

    def __init__(
        self, length, tau, common_thr, dim, heads=8, qkv_bias=False, qk_scale=0.25
    ):
        super().__init__()
        assert dim % heads == 0, f"dim {dim} should be divided by num_heads {heads}."

        self.dim = dim
        self.heads = heads
        self.qk_scale = qk_scale

        self.q_m = nn.Linear(dim, dim)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_tslif = TSLIFNode(
            surrogate_function=SG.apply,
        )

        self.k_m = nn.Linear(dim, dim)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_tslif = TSLIFNode(
            surrogate_function=SG.apply,
        )

        self.v_m = nn.Linear(dim, dim)
        self.v_bn = nn.BatchNorm1d(dim)
        self.v_tslif = TSLIFNode(
            surrogate_function=SG.apply,
        )

        self.attn_tslif = TSLIFNode(
            v_threshold=0.7,
            surrogate_function=SG.apply,
        )

        self.last_m = nn.Linear(dim, dim)
        self.last_bn = nn.BatchNorm1d(dim)
        self.last_tslif = TSLIFNode(
            surrogate_function=SG.apply,
        )

    def _spike_heads(self, x_flat, linear, bn, lif, shape):
        """Project, normalise and spike ``x_flat``, then split it into heads.

        Returns a contiguous tensor of shape ``(B, T, heads, L, D // heads)``.
        """
        B, T, L, D = shape
        out = linear(x_flat)  # BT, L, D
        # BatchNorm1d normalises dim 1, so move the features there and back.
        out = (
            bn(out.transpose(-1, -2))
            .transpose(-1, -2)
            .reshape(B, T, L, D)
            .contiguous()
        )
        out = lif(out)
        return (
            out.reshape(B, T, L, self.heads, D // self.heads)
            .permute(0, 1, 3, 2, 4)
            .contiguous()
        )

    def forward(self, x):
        """Apply spiking self-attention.

        Args:
            x: Tensor unpacked as ``(B, T, L, D)``.

        Returns:
            Output of the last spiking node for a ``(B, T, L, D)`` input.
        """
        # Reset the state of every spiking node before processing a new input.
        for node in (
            self.q_tslif,
            self.k_tslif,
            self.v_tslif,
            self.attn_tslif,
            self.last_tslif,
        ):
            utils.reset(node)

        B, T, L, D = x.shape
        shape = (B, T, L, D)
        x_for_qkv = x.flatten(0, 1)  # BT, L, D

        q = self._spike_heads(x_for_qkv, self.q_m, self.q_bn, self.q_tslif, shape)
        k = self._spike_heads(x_for_qkv, self.k_m, self.k_bn, self.k_tslif, shape)
        v = self._spike_heads(x_for_qkv, self.v_m, self.v_bn, self.v_tslif, shape)

        attn = (q @ k.transpose(-2, -1)) * self.qk_scale  # B, T, heads, L, L
        x = attn @ v  # B, T, heads, L, D // heads

        x = x.transpose(2, 3).reshape(B, T, L, D).contiguous()  # merge the heads
        x = self.attn_tslif(x)
        x = x.flatten(0, 1)  # BT, L, D
        x = self.last_m(x)
        x = self.last_bn(x.transpose(-1, -2)).transpose(-1, -2)
        x = self.last_tslif(x.reshape(B, T, L, D).contiguous())
        return x
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
class MLP(nn.Module):
    """Two-layer spiking feed-forward block (linear -> batch norm -> spiking node, twice).

    Args:
        length: Accepted for interface compatibility; unused.
        tau: Accepted for interface compatibility; unused.
        common_thr: Accepted for interface compatibility; unused.
        in_features: Input feature width.
        hidden_features: Hidden width; defaults to ``in_features`` when omitted.
        out_features: Output width; defaults to ``in_features`` when omitted.
            ``forward`` reshapes the output back to the input width, so the two
            must match.
    """

    def __init__(
        self,
        length,
        tau,
        common_thr,
        in_features,
        hidden_features=None,
        out_features=None,
    ):
        super().__init__()
        out_features = out_features or in_features
        # The None default used to reach nn.Linear unchanged and crash there.
        hidden_features = hidden_features or in_features
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.bn1 = nn.BatchNorm1d(hidden_features)
        self.mlp_tclif1 = TCLIFNode2(
            surrogate_function=SG.apply,
        )

        self.fc2 = nn.Linear(hidden_features, out_features)
        self.bn2 = nn.BatchNorm1d(out_features)
        self.mlp_tclif2 = TCLIFNode(
            surrogate_function=SG.apply,
        )

    def forward(self, x):
        """Apply the feed-forward block.

        Args:
            x: Tensor unpacked as ``(B, T, L, D)``.

        Returns:
            Output of the second spiking node for a ``(B, T, L, D)`` input.
        """
        # Reset the state of both spiking nodes before processing a new input.
        utils.reset(self.mlp_tclif1)
        utils.reset(self.mlp_tclif2)

        B, T, L, D = x.shape
        x = x.flatten(0, 1)  # BT, L, D
        x = self.fc1(x)  # BT, L, H
        # BatchNorm1d normalises dim 1, so move the features there and back.
        x = (
            self.bn1(x.transpose(-1, -2))
            .transpose(-1, -2)
            .reshape(B, T, L, self.hidden_features)
            .contiguous()
        )
        x = self.mlp_tclif1(x)
        x = x.flatten(0, 1)  # BT, L, H
        x = self.fc2(x)  # BT, L, D
        x = (
            self.bn2(x.transpose(-1, -2))
            .transpose(-1, -2)
            .reshape(B, T, L, D)
            .contiguous()
        )
        x = self.mlp_tclif2(x)
        return x
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
class Block(nn.Module):
    """Transformer-style block: spiking self-attention then spiking MLP, both residual.

    Args:
        length: Forwarded to ``SSA`` and ``MLP``.
        tau: Forwarded to ``SSA`` and ``MLP``.
        common_thr: Forwarded to ``SSA`` and ``MLP``.
        dim: Feature width of the block.
        d_ff: Hidden width of the MLP.
        heads: Number of attention heads.
        qkv_bias: Forwarded to ``SSA``.
        qk_scale: Scale applied to the attention product.
    """

    def __init__(
        self,
        length,
        tau,
        common_thr,
        dim,
        d_ff,
        heads=8,
        qkv_bias=False,
        qk_scale=0.125,
    ):
        super().__init__()
        self.attn = SSA(
            length=length,
            tau=tau,
            common_thr=common_thr,
            dim=dim,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
        )
        self.mlp = MLP(
            length=length,
            tau=tau,
            common_thr=common_thr,
            in_features=dim,
            hidden_features=d_ff,
        )

    def forward(self, x):
        """Return ``x`` after the residual attention and residual MLP steps."""
        x = x + self.attn(x)
        x = x + self.mlp(x)
        return x
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
@torch.jit.script
def heaviside(x: torch.Tensor):
    """Return 1 where ``x >= 0`` and 0 elsewhere, with the dtype and device of ``x``."""
    return (x >= 0).to(x)
|
| 589 |
+
|
| 590 |
+
@torch.jit.script
def atan_backward(grad_output: torch.Tensor, x: torch.Tensor, alpha: float):
    """Scale ``grad_output`` by the arctan surrogate derivative evaluated at ``x``.

    The factor is ``alpha / 2 / (1 + (pi / 2 * alpha * x) ** 2)``.

    Returns:
        A pair ``(grad, None)``. The second slot is always ``None``.
    """
    # pow_ squares the freshly created temporary in place; x itself is not modified.
    return alpha / 2 / (1 + (math.pi / 2 * alpha * x).pow_(2)) * grad_output, None
|
| 594 |
+
#
|
| 595 |
+
|
| 596 |
+
class SG(torch.autograd.Function):
    """Heaviside step in the forward pass with an arctan surrogate gradient."""

    @staticmethod
    def forward(ctx, x, alpha=2.0):
        """Return ``heaviside(x)``; remember ``x`` and ``alpha`` for the backward pass."""
        if x.requires_grad:
            # Only keep the input when a backward pass can actually happen.
            ctx.save_for_backward(x)
            ctx.alpha = alpha
        return heaviside(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the surrogate gradient for ``x`` and ``None`` for ``alpha``."""
        return atan_backward(grad_output, ctx.saved_tensors[0], ctx.alpha)
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
class MemoryModule(nn.Module):
|
| 611 |
+
def __init__(self):
|
| 612 |
+
"""
|
| 613 |
+
* :ref:`API in English <MemoryModule.__init__-en>`
|
| 614 |
+
|
| 615 |
+
.. _MemoryModule.__init__-cn:
|
| 616 |
+
|
| 617 |
+
``MemoryModule`` 是SpikingJelly中所有有状态(记忆)模块的基类。
|
| 618 |
+
|
| 619 |
+
* :ref:`中文API <MemoryModule.__init__-cn>`
|
| 620 |
+
|
| 621 |
+
.. _MemoryModule.__init__-en:
|
| 622 |
+
|
| 623 |
+
``MemoryModule`` is the base class of all stateful modules in SpikingJelly.
|
| 624 |
+
|
| 625 |
+
"""
|
| 626 |
+
super().__init__()
|
| 627 |
+
self._memories = {}
|
| 628 |
+
self._memories_rv = {}
|
| 629 |
+
|
| 630 |
+
def register_memory(self, name: str, value):
|
| 631 |
+
"""
|
| 632 |
+
* :ref:`API in English <MemoryModule.register_memory-en>`
|
| 633 |
+
|
| 634 |
+
.. _MemoryModule.register_memory-cn:
|
| 635 |
+
|
| 636 |
+
:param name: 变量的名字
|
| 637 |
+
:type name: str
|
| 638 |
+
:param value: 变量的值
|
| 639 |
+
:type value: any
|
| 640 |
+
|
| 641 |
+
将变量存入用于保存有状态变量(例如脉冲神经元的膜电位)的字典中。这个变量的重置值会被设置为 ``value``。每次调用 ``self.reset()``
|
| 642 |
+
函数后, ``self.name`` 都会被重置为 ``value``。
|
| 643 |
+
|
| 644 |
+
* :ref:`中文API <MemoryModule.register_memory-cn>`
|
| 645 |
+
|
| 646 |
+
.. _MemoryModule.register_memory-en:
|
| 647 |
+
|
| 648 |
+
:param name: variable's name
|
| 649 |
+
:type name: str
|
| 650 |
+
:param value: variable's value
|
| 651 |
+
:type value: any
|
| 652 |
+
|
| 653 |
+
Register the variable to memory dict, which saves stateful variables (e.g., the membrane potential of a
|
| 654 |
+
spiking neuron). The reset value of this variable will be ``value``. ``self.name`` will be set to ``value`` after
|
| 655 |
+
each calling of ``self.reset()``.
|
| 656 |
+
|
| 657 |
+
"""
|
| 658 |
+
assert not hasattr(self, name), f'{name} has been set as a member variable!'
|
| 659 |
+
self._memories[name] = value
|
| 660 |
+
self.set_reset_value(name, value)
|
| 661 |
+
|
| 662 |
+
def reset(self):
|
| 663 |
+
"""
|
| 664 |
+
* :ref:`API in English <MemoryModule.reset-en>`
|
| 665 |
+
|
| 666 |
+
.. _MemoryModule.reset-cn:
|
| 667 |
+
|
| 668 |
+
重置所有有状态变量为默认值。
|
| 669 |
+
|
| 670 |
+
* :ref:`中文API <MemoryModule.reset-cn>`
|
| 671 |
+
|
| 672 |
+
.. _MemoryModule.reset-en:
|
| 673 |
+
|
| 674 |
+
Reset all stateful variables to their default values.
|
| 675 |
+
"""
|
| 676 |
+
for key in self._memories.keys():
|
| 677 |
+
self._memories[key] = copy.deepcopy(self._memories_rv[key])
|
| 678 |
+
|
| 679 |
+
def set_reset_value(self, name: str, value):
|
| 680 |
+
self._memories_rv[name] = copy.deepcopy(value)
|
| 681 |
+
|
| 682 |
+
def __getattr__(self, name: str):
|
| 683 |
+
if '_memories' in self.__dict__:
|
| 684 |
+
memories = self.__dict__['_memories']
|
| 685 |
+
if name in memories:
|
| 686 |
+
return memories[name]
|
| 687 |
+
|
| 688 |
+
return super().__getattr__(name)
|
| 689 |
+
|
| 690 |
+
def __setattr__(self, name: str, value) -> None:
|
| 691 |
+
_memories = self.__dict__.get('_memories')
|
| 692 |
+
if _memories is not None and name in _memories:
|
| 693 |
+
_memories[name] = value
|
| 694 |
+
else:
|
| 695 |
+
super().__setattr__(name, value)
|
| 696 |
+
|
| 697 |
+
def __delattr__(self, name):
|
| 698 |
+
if name in self._memories:
|
| 699 |
+
del self._memories[name]
|
| 700 |
+
del self._memories_rv[name]
|
| 701 |
+
else:
|
| 702 |
+
return super().__delattr__(name)
|
| 703 |
+
|
| 704 |
+
def __dir__(self):
|
| 705 |
+
module_attrs = dir(self.__class__)
|
| 706 |
+
attrs = list(self.__dict__.keys())
|
| 707 |
+
parameters = list(self._parameters.keys())
|
| 708 |
+
modules = list(self._modules.keys())
|
| 709 |
+
buffers = list(self._buffers.keys())
|
| 710 |
+
memories = list(self._memories.keys())
|
| 711 |
+
keys = module_attrs + attrs + parameters + modules + buffers + memories
|
| 712 |
+
|
| 713 |
+
# Eliminate attrs that are not legal Python variable names
|
| 714 |
+
keys = [key for key in keys if not key[0].isdigit()]
|
| 715 |
+
|
| 716 |
+
return sorted(keys)
|
| 717 |
+
|
| 718 |
+
def memories(self):
|
| 719 |
+
"""
|
| 720 |
+
* :ref:`API in English <MemoryModule.memories-en>`
|
| 721 |
+
|
| 722 |
+
.. _MemoryModule.memories-cn:
|
| 723 |
+
|
| 724 |
+
:return: 返回一个所有状态变量的迭代器
|
| 725 |
+
:rtype: Iterator
|
| 726 |
+
|
| 727 |
+
* :ref:`中文API <MemoryModule.memories-cn>`
|
| 728 |
+
|
| 729 |
+
.. _MemoryModule.memories-en:
|
| 730 |
+
|
| 731 |
+
:return: an iterator over all stateful variables
|
| 732 |
+
:rtype: Iterator
|
| 733 |
+
"""
|
| 734 |
+
for name, value in self._memories.items():
|
| 735 |
+
yield value
|
| 736 |
+
|
| 737 |
+
def named_memories(self):
|
| 738 |
+
"""
|
| 739 |
+
* :ref:`API in English <MemoryModule.named_memories-en>`
|
| 740 |
+
|
| 741 |
+
.. _MemoryModule.named_memories-cn:
|
| 742 |
+
|
| 743 |
+
:return: 返回一个所有状态变量及其名称的迭代器
|
| 744 |
+
:rtype: Iterator
|
| 745 |
+
|
| 746 |
+
* :ref:`中文API <MemoryModule.named_memories-cn>`
|
| 747 |
+
|
| 748 |
+
.. _MemoryModule.named_memories-en:
|
| 749 |
+
|
| 750 |
+
:return: an iterator over all stateful variables and their names
|
| 751 |
+
:rtype: Iterator
|
| 752 |
+
"""
|
| 753 |
+
|
| 754 |
+
for name, value in self._memories.items():
|
| 755 |
+
yield name, value
|
| 756 |
+
|
| 757 |
+
def detach(self):
|
| 758 |
+
"""
|
| 759 |
+
* :ref:`API in English <MemoryModule.detach-en>`
|
| 760 |
+
|
| 761 |
+
.. _MemoryModule.detach-cn:
|
| 762 |
+
|
| 763 |
+
从计算图中分离所有有状态变量。
|
| 764 |
+
|
| 765 |
+
.. tip::
|
| 766 |
+
|
| 767 |
+
可以使用这个函数实现TBPTT(Truncated Back Propagation Through Time)。
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
* :ref:`中文API <MemoryModule.detach-cn>`
|
| 771 |
+
|
| 772 |
+
.. _MemoryModule.detach-en:
|
| 773 |
+
|
| 774 |
+
Detach all stateful variables.
|
| 775 |
+
|
| 776 |
+
.. admonition:: Tip
|
| 777 |
+
:class: tip
|
| 778 |
+
|
| 779 |
+
We can use this function to implement TBPTT(Truncated Back Propagation Through Time).
|
| 780 |
+
|
| 781 |
+
"""
|
| 782 |
+
|
| 783 |
+
for key in self._memories.keys():
|
| 784 |
+
if isinstance(self._memories[key], torch.Tensor):
|
| 785 |
+
self._memories[key].detach_()
|
| 786 |
+
|
| 787 |
+
def _apply(self, fn, *args, **kwargs):
    """
    Apply ``fn`` to every tensor memory, then to parameters and buffers.

    Extra positional/keyword arguments (for example ``recurse`` in recent
    PyTorch versions) are forwarded unchanged to ``nn.Module._apply`` so this
    override keeps working when the base signature carries more than ``fn``.

    :param fn: function applied to each tensor (e.g. a device or dtype cast)
    :return: whatever ``nn.Module._apply`` returns
    """
    for key, value in self._memories.items():
        if isinstance(value, torch.Tensor):
            # Re-assigning an existing key does not resize the dict, so this
            # is safe while iterating over it.
            self._memories[key] = fn(value)
    # The reset values in ``_memories_rv`` are deliberately not transformed.
    return super()._apply(fn, *args, **kwargs)
|
| 796 |
+
|
| 797 |
+
def _replicate_for_data_parallel(self):
    """
    Build a data-parallel replica that owns its own memory dict.

    The replica receives a shallow copy of ``_memories`` so that rebinding a
    state on the replica does not rebind it on this module.
    """
    clone = super()._replicate_for_data_parallel()
    clone._memories = dict(self._memories)
    return clone
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
class StepModule:
    """Mixin exposing a validated ``step_mode`` attribute."""

    def supported_step_mode(self):
        """
        Return the step modes supported by this module.

        :return: a tuple that contains the supported step modes
        :rtype: tuple[str]
        """
        return ('s', 'm')

    @property
    def step_mode(self):
        """
        :return: the current step mode of this module
        :rtype: str
        """
        return self._step_mode

    @step_mode.setter
    def step_mode(self, value: str):
        """
        Set the step mode of this module to ``value``.

        :param value: the step mode
        :type value: str
        :raises ValueError: if ``value`` is not in ``supported_step_mode()``
        """
        if value in self.supported_step_mode():
            self._step_mode = value
            return
        raise ValueError(f'step_mode can only be {self.supported_step_mode()}, but got "{value}"!')
|
| 849 |
+
|
| 850 |
+
|
| 851 |
+
|
| 852 |
+
class BaseNode(MemoryModule):
    """
    Base class for neurons that fire from two membrane states.

    Subclasses implement ``neuronal_charge`` (which must set ``self.v`` and
    ``self.v_s``) and ``neuronal_reset``. One forward step charges, fires a
    spike from each state, mixes the two spikes with the learnable scalars
    ``alpha_s`` / ``alpha_l`` and finally resets the states.
    """

    def __init__(self,
                 v_threshold: float = 1.,
                 v_reset: float = 0.,
                 surrogate_function: Callable = None,
                 detach_reset: bool = False,
                 step_mode='s', backend='torch',
                 store_v_seq: bool = True):
        """
        :param v_threshold: firing threshold compared against both states
        :param v_reset: reset value of ``v``; ``None`` makes ``v`` start at 0.
        :param surrogate_function: callable invoked as ``f(x, 2.0)`` to produce spikes
        :param detach_reset: stored on the instance; not read inside this class
        :param step_mode: stored on the instance and shown in ``extra_repr``
        :param backend: stored on the instance and shown in ``extra_repr``
        :param store_v_seq: whether ``multi_step_forward`` keeps ``v`` of every step
        """
        assert isinstance(v_reset, float) or v_reset is None
        assert isinstance(v_threshold, float)
        assert isinstance(detach_reset, bool)
        super().__init__()

        self.register_memory('v', 0. if v_reset is None else v_reset)
        # Always track 'v_s' as a memory. Previously it was only registered
        # when v_reset was None, so with the default v_reset it lived as a
        # plain attribute that reset()/detach()/_apply() never touched.
        self.register_memory('v_s', 0.)

        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.detach_reset = detach_reset
        self.surrogate_function = surrogate_function

        self.step_mode = step_mode
        self.backend = backend

        self.store_v_seq = store_v_seq

        # Learnable scalar weights used to mix the two spike trains.
        self.alpha_s = torch.nn.Parameter(torch.tensor(0.5, dtype=torch.float))
        self.alpha_l = torch.nn.Parameter(torch.tensor(0.5, dtype=torch.float))

    @property
    def store_v_seq(self):
        """Whether ``multi_step_forward`` records ``v`` at every step."""
        return self._store_v_seq

    @store_v_seq.setter
    def store_v_seq(self, value: bool):
        self._store_v_seq = value
        # Lazily create the 'v_seq' memory the first time recording is enabled.
        if value and not hasattr(self, 'v_seq'):
            self.register_memory('v_seq', None)

    @staticmethod
    @torch.jit.script
    def jit_hard_reset(v: torch.Tensor, spike: torch.Tensor, v_reset: float):
        # Where spike == 1 the state becomes v_reset, elsewhere it is kept.
        v = (1. - spike) * v + spike * v_reset
        return v

    @staticmethod
    @torch.jit.script
    def jit_soft_reset(v: torch.Tensor, spike: torch.Tensor, v_threshold: float):
        # Subtract v_threshold from the state wherever a spike occurred.
        v = v - spike * v_threshold
        return v

    @abstractmethod
    def neuronal_charge(self, x: torch.Tensor):
        """Update the states from input ``x``; must set ``self.v`` and ``self.v_s``."""
        raise NotImplementedError

    @abstractmethod
    def neuronal_reset(self, spike_s, spike_l):
        """Reset the states after firing; called by ``single_step_forward``."""
        raise NotImplementedError

    def neuronal_fire(self):
        """Return the spike obtained from ``v`` only."""
        return self.surrogate_function(self.v - self.v_threshold, 2.0)

    def sl_neuronal_fire(self):
        """
        Fire from both states against the same threshold.

        :return: ``(s_s, s_l)``: spikes derived from ``v`` and from ``v_s``
        """
        s_s = self.surrogate_function(self.v - self.v_threshold, 2.0)
        s_l = self.surrogate_function(self.v_s - self.v_threshold, 2.0)
        return s_s, s_l

    def extra_repr(self):
        return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, step_mode={self.step_mode}, backend={self.backend}'

    def single_step_forward(self, x: torch.Tensor):
        """
        Run one step: charge, fire, mix the two spikes, reset.

        :param x: input of the current step
        :return: ``alpha_s * s_s + alpha_l * s_l``
        """
        self.v_float_to_tensor(x)
        self.neuronal_charge(x)
        s_s, s_l = self.sl_neuronal_fire()
        spike = self.alpha_s * s_s + self.alpha_l * s_l
        self.neuronal_reset(s_s, s_l)
        return spike

    def multi_step_forward(self, x_seq: torch.Tensor):
        """
        Apply ``single_step_forward`` to ``x_seq[:, t]`` for every ``t``.

        The number of steps is ``x_seq.shape[-1]`` and the stacked outputs are
        permuted with ``(1, 0)``, so each step output must be 1-D.
        NOTE(review): this presumably expects a 2-D ``(batch, time)`` input -- confirm.

        :param x_seq: input sequence
        :return: step outputs with the step axis second
        """
        T = x_seq.shape[-1]
        y_seq = []
        if self.store_v_seq:
            v_seq = []
        for t in range(T):
            y = self.single_step_forward(x_seq[:, t])
            y_seq.append(y)
            if self.store_v_seq:
                v_seq.append(self.v)
        if self.store_v_seq:
            self.v_seq = torch.stack(v_seq)

        outputs = torch.stack(y_seq, dim=0).permute(1, 0)
        return outputs

    def v_float_to_tensor(self, x: torch.Tensor):
        """Replace a float-valued ``v`` by a tensor shaped like ``x`` filled with that value."""
        if isinstance(self.v, float):
            v_init = self.v
            self.v = torch.full_like(x.data, v_init)
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
class TSLIFNode(BaseNode):
    """
    Neuron with two coupled states ``v1`` and ``v2``.

    ``v2`` is exposed as ``v`` and ``v1`` as ``v_s``; each is thresholded
    separately by the base class to give two spike trains.
    """

    def __init__(self,
                 v_threshold=1.0,
                 v_reset=0.,
                 surrogate_function: Callable = None,
                 detach_reset=False,
                 hard_reset=False,
                 step_mode='s',
                 k=2,
                 decay_factor: torch.Tensor = torch.tensor([0.8, 0.2, 0.3, 0.7], dtype=torch.float),
                 gamma: float = 0.5):
        """
        :param hard_reset: use hard reset instead of the default soft reset
        :param k: number of ``v<i>`` states registered (``v1`` .. ``vk``)
        :param decay_factor: four initial coefficients of the charge equations
        :param gamma: amount subtracted from ``v1`` on a soft reset

        The remaining arguments are forwarded to ``BaseNode``.
        """
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode)
        self.k = k
        for i in range(1, self.k + 1):
            self.register_memory(f'v{i}', 0.)

        # Alias of the memory dict, used for the indexed v<i> accesses below.
        self.names = self._memories
        self.hard_reset = hard_reset
        self.gamma = gamma
        # Clone before wrapping: nn.Parameter shares storage with the tensor it
        # is given, and the default tensor is created once at definition time,
        # so without the clone every default-built instance would alias the
        # same coefficients.
        self.decay_factor = torch.nn.Parameter(decay_factor.detach().clone())
        # Learnable cross-coupling coefficients between the two states.
        self.kk = torch.nn.Parameter(torch.tensor([0.8], dtype=torch.float))
        self.yy = torch.nn.Parameter(torch.tensor([0.1], dtype=torch.float))

    @property
    def supported_backends(self):
        """Backends available for the current ``step_mode``."""
        if self.step_mode == 's':
            return ('torch',)
        elif self.step_mode == 'm':
            return ('torch', 'cupy')
        else:
            raise ValueError(self.step_mode)

    def neuronal_charge(self, x: torch.Tensor):
        """
        Update ``v1`` then ``v2`` from input ``x``.

        Order matters: the ``v2`` update reads the already-updated ``v1``.
        """
        self.names['v1'] = self.decay_factor[0] * self.names['v1'] + self.decay_factor[1] * x - self.yy * self.names['v2']
        self.names['v2'] = self.decay_factor[2] * self.names['v2'] + self.decay_factor[3] * x - self.kk * self.names['v1']
        self.v = self.names['v2']
        self.v_s = self.names['v1']

    def neuronal_reset(self, spike_s, spike_l):
        """
        Reset the states after firing.

        Soft reset subtracts ``gamma`` from ``v1`` where ``spike_l`` fired and
        ``v_threshold`` from ``v2`` where ``spike_s`` fired. Hard reset sets
        ``v2`` .. ``vk`` to ``v_reset`` where ``spike_s`` fired and leaves
        ``v1`` untouched.
        """
        if not self.hard_reset:
            self.names['v1'] = self.jit_soft_reset(self.names['v1'], spike_l, self.gamma)
            self.names['v2'] = self.jit_soft_reset(self.names['v2'], spike_s, self.v_threshold)
        else:
            for i in range(2, self.k + 1):
                key = f'v{i}'
                self.names[key] = self.jit_hard_reset(self.names[key], spike_s, self.v_reset)

    def forward(self, x: torch.Tensor):
        """Always run a single step, regardless of ``step_mode``."""
        return super().single_step_forward(x)

    def extra_repr(self):
        return f"v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, " \
               f"hard_reset={self.hard_reset}, " \
               f"gamma={self.gamma}, k={self.k}, step_mode={self.step_mode}, backend={self.backend}"
|
| 1018 |
+
|
| 1019 |
+
|
| 1020 |
+
|
| 1021 |
+
|
| 1022 |
+
|
| 1023 |
+
class BaseNode1(MemoryModule):
    """
    Variant of ``BaseNode`` whose ``multi_step_forward`` handles 4-D input.

    Subclasses implement ``neuronal_charge`` (which must set ``self.v`` and
    ``self.v_s``) and ``neuronal_reset``.
    """

    def __init__(self,
                 v_threshold: float = 1.,
                 v_reset: float = 0.,
                 surrogate_function: Callable = None,
                 detach_reset: bool = False,
                 step_mode='s', backend='torch',
                 store_v_seq: bool = True):
        """
        :param v_threshold: firing threshold compared against both states
        :param v_reset: reset value of ``v``; ``None`` makes ``v`` start at 0.
        :param surrogate_function: callable invoked as ``f(x, 2.0)`` to produce spikes
        :param detach_reset: stored on the instance; not read inside this class
        :param step_mode: stored on the instance and shown in ``extra_repr``
        :param backend: stored on the instance and shown in ``extra_repr``
        :param store_v_seq: whether ``multi_step_forward`` keeps ``v`` of every step
        """
        assert isinstance(v_reset, float) or v_reset is None
        assert isinstance(v_threshold, float)
        assert isinstance(detach_reset, bool)
        super().__init__()

        self.register_memory('v', 0. if v_reset is None else v_reset)
        # Always track 'v_s' as a memory. Previously it was only registered
        # when v_reset was None, so with the default v_reset it lived as a
        # plain attribute that reset()/detach()/_apply() never touched.
        self.register_memory('v_s', 0.)

        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.detach_reset = detach_reset
        self.surrogate_function = surrogate_function

        self.step_mode = step_mode
        self.backend = backend

        self.store_v_seq = store_v_seq
        # Learnable scalar weights used to mix the two spike trains.
        self.alpha_s = torch.nn.Parameter(torch.tensor(0.5, dtype=torch.float))
        self.alpha_l = torch.nn.Parameter(torch.tensor(0.5, dtype=torch.float))

    @property
    def store_v_seq(self):
        """Whether ``multi_step_forward`` records ``v`` at every step."""
        return self._store_v_seq

    @store_v_seq.setter
    def store_v_seq(self, value: bool):
        self._store_v_seq = value
        # Lazily create the 'v_seq' memory the first time recording is enabled.
        if value and not hasattr(self, 'v_seq'):
            self.register_memory('v_seq', None)

    @staticmethod
    @torch.jit.script
    def jit_hard_reset(v: torch.Tensor, spike: torch.Tensor, v_reset: float):
        # Where spike == 1 the state becomes v_reset, elsewhere it is kept.
        v = (1. - spike) * v + spike * v_reset
        return v

    @staticmethod
    @torch.jit.script
    def jit_soft_reset(v: torch.Tensor, spike: torch.Tensor, v_threshold: float):
        # Subtract v_threshold from the state wherever a spike occurred.
        v = v - spike * v_threshold
        return v

    @abstractmethod
    def neuronal_charge(self, x: torch.Tensor):
        """Update the states from input ``x``; must set ``self.v`` and ``self.v_s``."""
        raise NotImplementedError

    @abstractmethod
    def neuronal_reset(self, spike_s, spike_l):
        """Reset the states after firing; called by ``single_step_forward``."""
        raise NotImplementedError

    def neuronal_fire(self):
        """Return the spike obtained from ``v`` only."""
        return self.surrogate_function(self.v - self.v_threshold, 2.0)

    def sl_neuronal_fire(self):
        """
        Fire from both states against the same threshold.

        :return: ``(s_s, s_l)``: spikes derived from ``v`` and from ``v_s``
        """
        s_s = self.surrogate_function(self.v - self.v_threshold, 2.0)
        s_l = self.surrogate_function(self.v_s - self.v_threshold, 2.0)
        return s_s, s_l

    def extra_repr(self):
        return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, step_mode={self.step_mode}, backend={self.backend}'

    def single_step_forward(self, x: torch.Tensor):
        """
        Run one step: charge, fire, mix the two spikes, reset.

        :param x: input of the current step
        :return: ``alpha_s * s_s + alpha_l * s_l``
        """
        self.v_float_to_tensor(x)
        self.neuronal_charge(x)
        s_s, s_l = self.sl_neuronal_fire()
        spike = self.alpha_s * s_s + self.alpha_l * s_l
        self.neuronal_reset(s_s, s_l)
        return spike

    def multi_step_forward(self, x_seq: torch.Tensor):
        """
        Apply ``single_step_forward`` to ``x_seq[:, t, :, :]`` for ``t`` in 0 and 1.

        NOTE(review): the step count is hard-coded to 2, so any further
        entries along dim 1 are ignored -- confirm this is intended.

        :param x_seq: 4-D input; dim 1 is iterated
        :return: stacked step outputs, permuted with ``(1, 0, 2, 3)``
        """
        y_seq = []
        if self.store_v_seq:
            v_seq = []
        for t in range(2):
            y = self.single_step_forward(x_seq[:, t, :, :])
            y_seq.append(y)
            if self.store_v_seq:
                v_seq.append(self.v)
        if self.store_v_seq:
            self.v_seq = torch.stack(v_seq)
        outputs = torch.stack(y_seq, dim=0)
        outputs = outputs.permute(1, 0, 2, 3)
        return outputs

    def v_float_to_tensor(self, x: torch.Tensor):
        """Replace a float-valued ``v`` by a tensor shaped like ``x`` filled with that value."""
        if isinstance(self.v, float):
            v_init = self.v
            self.v = torch.full_like(x.data, v_init)
|
| 1128 |
+
|
| 1129 |
+
|
| 1130 |
+
|
| 1131 |
+
class TCLIFNode2(BaseNode1):
    """
    Neuron with two coupled states ``v1`` and ``v2``, built on ``BaseNode1``.

    ``v2`` is exposed as ``v`` and ``v1`` as ``v_s``; each is thresholded
    separately by the base class to give two spike trains.
    """

    def __init__(self,
                 v_threshold=0.8,
                 v_reset=0.,
                 surrogate_function: Callable = None,
                 detach_reset=False,
                 hard_reset=False,
                 step_mode='s',
                 k=2,
                 decay_factor: torch.Tensor = torch.tensor([0.8, 0.2, 0.3, 0.7], dtype=torch.float),
                 gamma: float = 0.5):
        """
        :param hard_reset: use hard reset instead of the default soft reset
        :param k: number of ``v<i>`` states registered (``v1`` .. ``vk``)
        :param decay_factor: four initial coefficients of the charge equations
        :param gamma: amount subtracted from ``v1`` on a soft reset

        The remaining arguments are forwarded to ``BaseNode1``.
        """
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode)
        self.k = k
        for i in range(1, self.k + 1):
            self.register_memory(f'v{i}', 0.)

        # Alias of the memory dict, used for the indexed v<i> accesses below.
        self.names = self._memories
        self.hard_reset = hard_reset
        self.gamma = gamma
        # Clone before wrapping: nn.Parameter shares storage with the tensor it
        # is given, and the default tensor is created once at definition time,
        # so without the clone every default-built instance would alias the
        # same coefficients.
        self.decay_factor = torch.nn.Parameter(decay_factor.detach().clone())
        # Learnable cross-coupling coefficients between the two states.
        self.kk = torch.nn.Parameter(torch.tensor([0.8], dtype=torch.float))
        self.yy = torch.nn.Parameter(torch.tensor([0.1], dtype=torch.float))

    @property
    def supported_backends(self):
        """Backends available for the current ``step_mode``."""
        if self.step_mode == 's':
            return ('torch',)
        elif self.step_mode == 'm':
            return ('torch', 'cupy')
        else:
            raise ValueError(self.step_mode)

    def neuronal_charge(self, x: torch.Tensor):
        """
        Update ``v1`` then ``v2`` from input ``x``.

        Order matters: the ``v2`` update reads the already-updated ``v1``.
        """
        self.names['v1'] = self.decay_factor[0] * self.names['v1'] + self.decay_factor[1] * x - self.yy * self.names['v2']
        self.names['v2'] = self.decay_factor[2] * self.names['v2'] + self.decay_factor[3] * x - self.kk * self.names['v1']
        self.v = self.names['v2']
        self.v_s = self.names['v1']

    def neuronal_reset(self, spike_s, spike_l):
        """
        Reset the states after firing.

        Soft reset subtracts ``gamma`` from ``v1`` where ``spike_l`` fired and
        ``v_threshold`` from ``v2`` where ``spike_s`` fired. Hard reset sets
        ``v2`` .. ``vk`` to ``v_reset`` where ``spike_s`` fired and leaves
        ``v1`` untouched.
        """
        if not self.hard_reset:
            self.names['v1'] = self.jit_soft_reset(self.names['v1'], spike_l, self.gamma)
            self.names['v2'] = self.jit_soft_reset(self.names['v2'], spike_s, self.v_threshold)
        else:
            # The hard-reset branch used an undefined name (spike_d) and raised
            # NameError; spike_s matches the sibling TSLIFNode implementation.
            for i in range(2, self.k + 1):
                key = f'v{i}'
                self.names[key] = self.jit_hard_reset(self.names[key], spike_s, self.v_reset)

    def forward(self, x: torch.Tensor):
        """Always run a single step, regardless of ``step_mode``."""
        return super().single_step_forward(x)

    def extra_repr(self):
        return f"v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, " \
               f"hard_reset={self.hard_reset}, " \
               f"gamma={self.gamma}, k={self.k}, step_mode={self.step_mode}, backend={self.backend}"
|
| 1185 |
+
|
| 1186 |
+
|
| 1187 |
+
|
| 1188 |
+
|
| 1189 |
+
|
| 1190 |
+
class TCLIFNode(BaseNode):
    """
    Neuron with two coupled states ``v1`` and ``v2``, built on ``BaseNode``.

    ``v2`` is exposed as ``v`` and ``v1`` as ``v_s``; each is thresholded
    separately by the base class to give two spike trains.
    """

    def __init__(self,
                 v_threshold=1.0,
                 v_reset=0.,
                 surrogate_function: Callable = None,
                 detach_reset=False,
                 hard_reset=False,
                 step_mode='s',
                 k=2,
                 decay_factor: torch.Tensor = torch.tensor([0.8, 0.2, 0.3, 0.7], dtype=torch.float),
                 gamma: float = 0.5):
        """
        :param hard_reset: use hard reset instead of the default soft reset
        :param k: number of ``v<i>`` states registered (``v1`` .. ``vk``)
        :param decay_factor: four initial coefficients of the charge equations
        :param gamma: amount subtracted from ``v1`` on a soft reset

        The remaining arguments are forwarded to ``BaseNode``.
        """
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode)
        self.k = k
        for i in range(1, self.k + 1):
            self.register_memory(f'v{i}', 0.)

        # Alias of the memory dict, used for the indexed v<i> accesses below.
        self.names = self._memories
        self.hard_reset = hard_reset
        self.gamma = gamma
        # Clone before wrapping: nn.Parameter shares storage with the tensor it
        # is given, and the default tensor is created once at definition time,
        # so without the clone every default-built instance would alias the
        # same coefficients.
        self.decay_factor = torch.nn.Parameter(decay_factor.detach().clone())
        # Learnable cross-coupling coefficients between the two states.
        self.kk = torch.nn.Parameter(torch.tensor([0.8], dtype=torch.float))
        self.yy = torch.nn.Parameter(torch.tensor([0.1], dtype=torch.float))

    @property
    def supported_backends(self):
        """Backends available for the current ``step_mode``."""
        if self.step_mode == 's':
            return ('torch',)
        elif self.step_mode == 'm':
            return ('torch', 'cupy')
        else:
            raise ValueError(self.step_mode)

    def neuronal_charge(self, x: torch.Tensor):
        """
        Update ``v1`` then ``v2`` from input ``x``.

        Order matters: the ``v2`` update reads the already-updated ``v1``.
        """
        self.names['v1'] = self.decay_factor[0] * self.names['v1'] + self.decay_factor[1] * x - self.yy * self.names['v2']
        self.names['v2'] = self.decay_factor[2] * self.names['v2'] + self.decay_factor[3] * x - self.kk * self.names['v1']
        self.v = self.names['v2']
        self.v_s = self.names['v1']

    def neuronal_reset(self, spike_s, spike_l):
        """
        Reset the states after firing.

        Soft reset subtracts ``gamma`` from ``v1`` where ``spike_l`` fired and
        ``v_threshold`` from ``v2`` where ``spike_s`` fired. Hard reset sets
        ``v2`` .. ``vk`` to ``v_reset`` where ``spike_s`` fired and leaves
        ``v1`` untouched.
        """
        if not self.hard_reset:
            self.names['v1'] = self.jit_soft_reset(self.names['v1'], spike_l, self.gamma)
            self.names['v2'] = self.jit_soft_reset(self.names['v2'], spike_s, self.v_threshold)
        else:
            # The hard-reset branch used an undefined name (spike_d) and raised
            # NameError; spike_s matches the sibling TSLIFNode implementation.
            for i in range(2, self.k + 1):
                key = f'v{i}'
                self.names[key] = self.jit_hard_reset(self.names[key], spike_s, self.v_reset)

    def forward(self, x: torch.Tensor):
        """Always run a single step, regardless of ``step_mode``."""
        return super().single_step_forward(x)

    def extra_repr(self):
        return f"v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, " \
               f"hard_reset={self.hard_reset}, " \
               f"gamma={self.gamma}, k={self.k}, step_mode={self.step_mode}, backend={self.backend}"
|
| 1243 |
+
|
| 1244 |
+
|
| 1245 |
+
|
| 1246 |
+
|
| 1247 |
+
|
| 1248 |
+
class TSFormer(nn.Module):
    """
    Spiking transformer forecaster.

    Pipeline: optional per-sample normalisation, spike encoding, position
    embedding, linear projection, an initial LIF layer, a stack of ``Block``
    modules, averaging over the leading axis and a linear read-out reshaped
    to ``(-1, pre_length, feature_size)``.
    """

    def __init__(
        self,
        args,
        dim: int = 256,
        d_ff: Optional[int] = None,
        num_pe_neuron: int = 40,
        pe_type: str = "neuron",
        pe_mode: str = "concat",  # "add" or concat
        neuron_pe_scale: float = 10000.0,  # "100" or "1000" or "10000"
        depths: int = 2,
        common_thr: float = 1.0,
        max_length: int = 5000,
        num_steps: int = 4,
        heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: float = 0.125,
        input_size: Optional[int] = None,
        weight_file: Optional[Path] = None,
    ):
        """
        :param args: run configuration; ``T``, ``blocks``, ``feature_size``,
            ``pre_length``, ``seq_length`` and ``normalize`` are read from it
        :param dim: width of the encoder projection and of every block

        NOTE(review): ``d_ff``, ``input_size`` and ``weight_file`` are accepted
        but never read; ``self.dim`` / ``self.d_ff`` are fixed to 256 / 1024;
        the number of blocks comes from ``depths`` while ``self.depths``
        stores ``args.blocks`` -- confirm these are intentional.
        """
        super().__init__()
        self.args = args
        self.dim = 256
        self.d_ff = 1024
        self.T = args.T
        self.depths = args.blocks
        self.pe_type = pe_type
        self.pe_mode = pe_mode
        self.num_pe_neuron = num_pe_neuron
        self.input_size = args.feature_size
        self.pre_length = args.pre_length
        self.feature_size = args.feature_size
        self._snn_backend = "spikingjelly"

        self.temporal_encoder = SpikeEncoder[self._snn_backend]["conv"](num_steps)
        self.pe = PositionEmbedding(
            pe_type=pe_type,
            pe_mode=pe_mode,
            neuron_pe_scale=neuron_pe_scale,
            input_size=self.input_size,
            max_len=max_length,
            num_pe_neuron=self.num_pe_neuron,
            dropout=0.1,
            num_steps=num_steps,
        )

        # Concatenated "neuron"/"random" embeddings widen the encoder input.
        widened = self.pe_mode == "concat" and self.pe_type in ("neuron", "random")
        encoder_in = self.input_size + num_pe_neuron if widened else self.input_size
        self.encoder = nn.Linear(encoder_in, dim)

        # tau, detach_reset and backend are not parameters of this constructor;
        # they are presumably module-level names defined elsewhere in the file.
        self.init_lif = neuron.LIFNode(
            tau=tau,
            step_mode="m",
            detach_reset=detach_reset,
            surrogate_function=surrogate.ATan(),
            v_threshold=common_thr,
            backend=backend,
        )

        layers = []
        for _ in range(depths):
            layers.append(
                Block(
                    length=max_length,
                    tau=tau,
                    common_thr=common_thr,
                    dim=dim,
                    d_ff=self.d_ff,
                    heads=heads,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                )
            )
        self.blocks = nn.ModuleList(layers)

        self.apply(self._init_weights)

        # Created after apply(), so this head keeps nn.Linear's default init.
        self.fc = nn.Linear(args.seq_length * dim, args.pre_length * args.feature_size)

    def _init_weights(self, m):
        """Initialise Linear weights with N(0, 0.02) and zero bias; LayerNorm with weight 1 and bias 0."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        """
        :param x: input batch (shape comments below are carried over from the original code)
        :return: ``(out, aux)`` where ``aux`` holds a placeholder ``'gate_l0'`` entry
        """
        functional.reset_net(self)

        if self.args.normalize:
            # Statistics are detached so no gradient flows through them.
            mean = x.mean(dim=1, keepdim=True).detach()  # shape [B, 1, D]
            x = x - mean
            variance = torch.var(x, dim=1, keepdim=True, unbiased=False)
            std = torch.sqrt(variance + 1e-5).detach()
            x = x / std

        x = self.temporal_encoder(x)  # B L C -> T B C L
        x = x.transpose(-2, -1)  # T B L C
        if self.pe_type != "none":
            x = self.pe(x)  # T B L C'
        T, B, L, _ = x.shape
        projected = self.encoder(x.flatten(0, 1))
        x = projected.reshape(T, B, L, -1)  # T B L D
        x = self.init_lif(x)

        for blk in self.blocks:
            x = blk(x)  # T B L D

        out = x.mean(0)  # B L D
        flat = out.flatten(-2, -1)
        out = self.fc(flat).reshape(-1, self.pre_length, self.feature_size)
        if self.args.normalize:
            out = out * std + mean  # denormalization
        aux = {'gate_l0': torch.tensor(0.0, device=out.device)}  # placeholder
        return out, aux
|
| 1365 |
+
|
model/TS_GRU.py
ADDED
|
@@ -0,0 +1,640 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Callable
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from spikingjelly.activation_based import surrogate as sj_surrogate
|
| 4 |
+
from snntorch import utils
|
| 5 |
+
import snntorch as snn
|
| 6 |
+
from snntorch import surrogate
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
import numpy as np
|
| 10 |
+
import copy
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import math
|
| 13 |
+
from abc import abstractmethod
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@torch.jit.script
def heaviside(x: torch.Tensor):
    # Step function: 1 where x >= 0 and 0 elsewhere, cast back to x's dtype/device.
    fired = x >= 0
    return fired.to(x)
|
| 19 |
+
|
| 20 |
+
@torch.jit.script
def atan_backward(grad_output: torch.Tensor, x: torch.Tensor, alpha: float):
    # Surrogate gradient: alpha / 2 / (1 + (pi / 2 * alpha * x)^2) * grad_output.
    # The second return value is None, i.e. no gradient is produced for alpha.
    scaled = math.pi / 2 * alpha * x
    denom = 1 + scaled.pow_(2)  # in place on the temporary, x itself is untouched
    return alpha / 2 / denom * grad_output, None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class SG(torch.autograd.Function):
    """Heaviside spike function whose backward pass uses ``atan_backward``."""

    @staticmethod
    def forward(ctx, x, alpha=2.0):
        # Only keep what backward needs when a gradient can actually be requested.
        needs_grad = x.requires_grad
        if needs_grad:
            ctx.save_for_backward(x)
            ctx.alpha = alpha
        return heaviside(x)

    @staticmethod
    def backward(ctx, grad_output):
        saved = ctx.saved_tensors
        return atan_backward(grad_output, saved[0], ctx.alpha)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class MemoryModule(nn.Module):
|
| 40 |
+
def __init__(self):
|
| 41 |
+
"""
|
| 42 |
+
* :ref:`API in English <MemoryModule.__init__-en>`
|
| 43 |
+
|
| 44 |
+
.. _MemoryModule.__init__-cn:
|
| 45 |
+
|
| 46 |
+
``MemoryModule`` 是SpikingJelly中所有有状态(记忆)模块的基类。
|
| 47 |
+
|
| 48 |
+
* :ref:`中文API <MemoryModule.__init__-cn>`
|
| 49 |
+
|
| 50 |
+
.. _MemoryModule.__init__-en:
|
| 51 |
+
|
| 52 |
+
``MemoryModule`` is the base class of all stateful modules in SpikingJelly.
|
| 53 |
+
|
| 54 |
+
"""
|
| 55 |
+
super().__init__()
|
| 56 |
+
self._memories = {}
|
| 57 |
+
self._memories_rv = {}
|
| 58 |
+
|
| 59 |
+
def register_memory(self, name: str, value):
|
| 60 |
+
"""
|
| 61 |
+
* :ref:`API in English <MemoryModule.register_memory-en>`
|
| 62 |
+
|
| 63 |
+
.. _MemoryModule.register_memory-cn:
|
| 64 |
+
|
| 65 |
+
:param name: 变量的名字
|
| 66 |
+
:type name: str
|
| 67 |
+
:param value: 变量的值
|
| 68 |
+
:type value: any
|
| 69 |
+
|
| 70 |
+
将变量存入用于保存有状态变量(例如脉冲神经元的膜电位)的字典中。这个变量的重置值会被设置为 ``value``。每次调用 ``self.reset()``
|
| 71 |
+
函数后, ``self.name`` 都会被重置为 ``value``。
|
| 72 |
+
|
| 73 |
+
* :ref:`中文API <MemoryModule.register_memory-cn>`
|
| 74 |
+
|
| 75 |
+
.. _MemoryModule.register_memory-en:
|
| 76 |
+
|
| 77 |
+
:param name: variable's name
|
| 78 |
+
:type name: str
|
| 79 |
+
:param value: variable's value
|
| 80 |
+
:type value: any
|
| 81 |
+
|
| 82 |
+
Register the variable to memory dict, which saves stateful variables (e.g., the membrane potential of a
|
| 83 |
+
spiking neuron). The reset value of this variable will be ``value``. ``self.name`` will be set to ``value`` after
|
| 84 |
+
each calling of ``self.reset()``.
|
| 85 |
+
|
| 86 |
+
"""
|
| 87 |
+
assert not hasattr(self, name), f'{name} has been set as a member variable!'
|
| 88 |
+
self._memories[name] = value
|
| 89 |
+
self.set_reset_value(name, value)
|
| 90 |
+
|
| 91 |
+
def reset(self):
|
| 92 |
+
"""
|
| 93 |
+
* :ref:`API in English <MemoryModule.reset-en>`
|
| 94 |
+
|
| 95 |
+
.. _MemoryModule.reset-cn:
|
| 96 |
+
|
| 97 |
+
重置所有有状态变量为默认值。
|
| 98 |
+
|
| 99 |
+
* :ref:`中文API <MemoryModule.reset-cn>`
|
| 100 |
+
|
| 101 |
+
.. _MemoryModule.reset-en:
|
| 102 |
+
|
| 103 |
+
Reset all stateful variables to their default values.
|
| 104 |
+
"""
|
| 105 |
+
for key in self._memories.keys():
|
| 106 |
+
self._memories[key] = copy.deepcopy(self._memories_rv[key])
|
| 107 |
+
|
| 108 |
+
def set_reset_value(self, name: str, value):
|
| 109 |
+
self._memories_rv[name] = copy.deepcopy(value)
|
| 110 |
+
|
| 111 |
+
def __getattr__(self, name: str):
|
| 112 |
+
if '_memories' in self.__dict__:
|
| 113 |
+
memories = self.__dict__['_memories']
|
| 114 |
+
if name in memories:
|
| 115 |
+
return memories[name]
|
| 116 |
+
|
| 117 |
+
return super().__getattr__(name)
|
| 118 |
+
|
| 119 |
+
def __setattr__(self, name: str, value) -> None:
|
| 120 |
+
_memories = self.__dict__.get('_memories')
|
| 121 |
+
if _memories is not None and name in _memories:
|
| 122 |
+
_memories[name] = value
|
| 123 |
+
else:
|
| 124 |
+
super().__setattr__(name, value)
|
| 125 |
+
|
| 126 |
+
def __delattr__(self, name):
|
| 127 |
+
if name in self._memories:
|
| 128 |
+
del self._memories[name]
|
| 129 |
+
del self._memories_rv[name]
|
| 130 |
+
else:
|
| 131 |
+
return super().__delattr__(name)
|
| 132 |
+
|
| 133 |
+
def __dir__(self):
|
| 134 |
+
module_attrs = dir(self.__class__)
|
| 135 |
+
attrs = list(self.__dict__.keys())
|
| 136 |
+
parameters = list(self._parameters.keys())
|
| 137 |
+
modules = list(self._modules.keys())
|
| 138 |
+
buffers = list(self._buffers.keys())
|
| 139 |
+
memories = list(self._memories.keys())
|
| 140 |
+
keys = module_attrs + attrs + parameters + modules + buffers + memories
|
| 141 |
+
keys = [key for key in keys if not key[0].isdigit()]
|
| 142 |
+
|
| 143 |
+
return sorted(keys)
|
| 144 |
+
|
| 145 |
+
def memories(self):
    """
    Iterate over the current values of all stateful variables.

    :return: an iterator over all stateful variables
    :rtype: Iterator
    """
    yield from self._memories.values()
|
| 163 |
+
|
| 164 |
+
def named_memories(self):
    """
    Iterate over ``(name, value)`` pairs of all stateful variables.

    :return: an iterator over all stateful variables and their names
    :rtype: Iterator
    """
    yield from self._memories.items()
|
| 183 |
+
|
| 184 |
+
def detach(self):
    """
    Detach all stateful variables from the computation graph.

    Only memories that are ``torch.Tensor`` instances are touched; they are
    detached in place. Other values (floats, ``None``, ...) are left alone.

    .. admonition:: Tip
        :class: tip

        This function can be used to implement TBPTT (Truncated Back
        Propagation Through Time).
    """
    for state in self._memories.values():
        if isinstance(state, torch.Tensor):
            state.detach_()
|
| 213 |
+
|
| 214 |
+
def _apply(self, fn):
    """
    Apply ``fn`` to every tensor-valued memory, then to the parent's state.

    Non-tensor memories are left unchanged.
    """
    tensor_keys = [key for key, value in self._memories.items()
                   if isinstance(value, torch.Tensor)]
    for key in tensor_keys:
        self._memories[key] = fn(self._memories[key])
    return super()._apply(fn)
|
| 219 |
+
|
| 220 |
+
def _replicate_for_data_parallel(self):
    """
    Build the data-parallel replica and give it its own memory table.

    The table is a shallow copy: the replica gets a separate mapping, but
    the stored values themselves are shared with this module.
    """
    clone = super()._replicate_for_data_parallel()
    clone._memories = self._memories.copy()
    return clone
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class StepModule:
    """Mixin that stores a validated ``step_mode`` on the instance."""

    def supported_step_mode(self):
        """
        Return the step modes this module supports.

        :return: a tuple that contains the supported step modes
        :rtype: tuple[str]
        """
        return ('s', 'm')

    @property
    def step_mode(self):
        """
        :return: the current step mode of this module
        :rtype: str
        """
        return self._step_mode

    @step_mode.setter
    def step_mode(self, value: str):
        """
        Set the step mode of this module to ``value``.

        :param value: the step mode
        :type value: str
        :raises ValueError: if ``value`` is not in ``supported_step_mode()``
        """
        if value in self.supported_step_mode():
            self._step_mode = value
            return
        raise ValueError(f'step_mode can only be {self.supported_step_mode()}, but got "{value}"!')
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class BaseNode(MemoryModule):
    """
    Base class for spiking nodes that fire from two membrane potentials.

    Subclasses must implement ``neuronal_charge`` and provide
    ``neuronal_reset(spike_s, spike_l)``; the latter is called by
    ``single_step_forward`` but is not defined in this class.
    """

    def __init__(self,
                 v_threshold: float = 1.,
                 v_reset: float = 0.,
                 surrogate_function: Callable = None,
                 detach_reset: bool = False,
                 step_mode='s', backend='torch',
                 store_v_seq: bool = True):
        """
        :param v_threshold: firing threshold (must be a ``float``)
        :param v_reset: reset potential, or ``None``
        :param surrogate_function: callable invoked as ``f(v - v_threshold, 2.0)``
        :param detach_reset: stored on the instance; not read in this class
        :param step_mode: stored on the instance
        :param backend: stored on the instance
        :param store_v_seq: whether ``multi_step_forward`` keeps ``v`` per step
        """
        assert v_reset is None or isinstance(v_reset, float)
        assert isinstance(v_threshold, float)
        assert isinstance(detach_reset, bool)
        super().__init__()

        # 'v' starts at the reset potential; without one it starts at zero
        # and a second memory 'v_s' is registered as well.
        self.register_memory('v', 0. if v_reset is None else v_reset)
        if v_reset is None:
            self.register_memory('v_s', 0.)

        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.detach_reset = detach_reset
        self.surrogate_function = surrogate_function

        self.step_mode = step_mode
        self.backend = backend

        self.store_v_seq = store_v_seq
        # Learnable weights that mix the two spike outputs in single_step_forward.
        self.alpha_s = torch.nn.Parameter(torch.tensor(0.5, dtype=torch.float))
        self.alpha_l = torch.nn.Parameter(torch.tensor(0.5, dtype=torch.float))

    @property
    def store_v_seq(self):
        """Whether ``multi_step_forward`` records ``v`` at every step."""
        return self._store_v_seq

    @store_v_seq.setter
    def store_v_seq(self, value: bool):
        self._store_v_seq = value
        # Lazily create the 'v_seq' memory the first time recording is enabled.
        if value and not hasattr(self, 'v_seq'):
            self.register_memory('v_seq', None)

    @staticmethod
    @torch.jit.script
    def jit_hard_reset(v: torch.Tensor, spike: torch.Tensor, v_reset: float):
        # Where spike == 1 the potential becomes v_reset; where 0 it is kept.
        return (1. - spike) * v + spike * v_reset

    @staticmethod
    @torch.jit.script
    def jit_soft_reset(v: torch.Tensor, spike: torch.Tensor, v_threshold: float):
        # Subtract the threshold from the potential wherever a spike occurred.
        return v - spike * v_threshold

    @abstractmethod
    def neuronal_charge(self, x: torch.Tensor):
        """Update the membrane potentials from input ``x`` (subclass hook)."""
        raise NotImplementedError

    def neuronal_fire(self):
        """Return the surrogate output for ``v - v_threshold``."""
        return self.surrogate_function(self.v - self.v_threshold, 2.0)

    def sl_neuronal_fire(self):
        """Return the pair of spike outputs computed from ``v`` and ``v_s``."""
        threshold = self.v_threshold
        fire = self.surrogate_function
        spike_from_v = fire(self.v - threshold, 2.0)
        spike_from_v_s = fire(self.v_s - threshold, 2.0)
        return spike_from_v, spike_from_v_s

    def extra_repr(self):
        return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, step_mode={self.step_mode}, backend={self.backend}'

    def single_step_forward(self, x: torch.Tensor):
        """
        Run one charge / fire / reset cycle and return the mixed spike output.

        The output is ``alpha_s * s_s + alpha_l * s_l`` and is computed before
        ``neuronal_reset`` is invoked.
        """
        self.v_float_to_tensor(x)
        self.neuronal_charge(x)
        spike_s, spike_l = self.sl_neuronal_fire()
        mixed = self.alpha_s * spike_s + self.alpha_l * spike_l
        self.neuronal_reset(spike_s, spike_l)

        return mixed

    def multi_step_forward(self, x_seq: torch.Tensor):
        """
        Apply ``single_step_forward`` to each slice ``x_seq[:, t]``.

        The step count is taken from ``x_seq.shape[-1]`` while slices are taken
        along dim 1, and the stacked result is permuted with ``(1, 0)``.
        NOTE(review): this appears to assume a 2-D ``x_seq`` -- confirm.
        """
        steps = x_seq.shape[-1]
        spikes = []
        potentials = []
        for t in range(steps):
            spikes.append(self.single_step_forward(x_seq[:, t]))
            if self.store_v_seq:
                potentials.append(self.v)
        if self.store_v_seq:
            self.v_seq = torch.stack(potentials)

        return torch.stack(spikes, dim=0).permute(1, 0)

    def v_float_to_tensor(self, x: torch.Tensor):
        """Expand a scalar ``v`` into a tensor shaped like ``x``."""
        if not isinstance(self.v, float):
            return
        self.v = torch.full_like(x.data, self.v)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class TSLIFNode(BaseNode):
    """
    Spiking node with two coupled membrane potentials, ``v1`` and ``v2``.

    ``v1`` and ``v2`` are stored as memories; after every charge step they
    are mirrored into ``v_s`` and ``v`` respectively, which is what the
    firing functions of ``BaseNode`` read.
    """

    def __init__(self,
                 v_threshold=1.0,
                 v_reset=0.,
                 surrogate_function: Callable = None,
                 detach_reset=False,
                 hard_reset=False,
                 step_mode='s',
                 k=2,
                 decay_factor: torch.Tensor = torch.tensor([0.8, 0.2, 0.3, 0.7], dtype=torch.float),
                 gamma: float = 0.5):
        """
        :param v_threshold: firing threshold, also used to soft-reset ``v2``
        :param v_reset: reset potential used by the hard-reset branch
        :param surrogate_function: callable invoked as ``f(v - v_threshold, 2.0)``
        :param detach_reset: forwarded to ``BaseNode``
        :param hard_reset: choose hard reset instead of soft reset
        :param step_mode: forwarded to ``BaseNode``
        :param k: number of ``v<i>`` memories to register (``v1`` .. ``vk``)
        :param decay_factor: four initial coefficients used by ``neuronal_charge``
        :param gamma: amount subtracted from ``v1`` on a soft reset
        """
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode)
        self.k = k
        for i in range(1, self.k + 1):
            self.register_memory('v' + str(i), 0.)
        # Alias of the memory table, used for direct keyed access below.
        self.names = self._memories
        self.hard_reset = hard_reset
        self.gamma = gamma
        # The default ``decay_factor`` tensor is created once, when the
        # function is defined, and ``nn.Parameter(t)`` shares storage with
        # ``t``. Cloning gives every node its own coefficients instead of
        # silently tying all nodes (and the default itself) together.
        self.decay_factor = torch.nn.Parameter(decay_factor.detach().clone())
        self.kk = torch.nn.Parameter(torch.tensor([0.8], dtype=torch.float))
        self.yy = torch.nn.Parameter(torch.tensor([0.1], dtype=torch.float))

    @property
    def supported_backends(self):
        """Backends available for the current ``step_mode``."""
        if self.step_mode == 's':
            return ('torch',)
        elif self.step_mode == 'm':
            return ('torch', 'cupy')
        else:
            raise ValueError(self.step_mode)

    def neuronal_charge(self, x: torch.Tensor):
        """
        Update ``v1`` then ``v2`` from input ``x``.

        The order matters: ``v1`` is updated from the previous ``v2``, and
        ``v2`` is then updated from the freshly computed ``v1``.
        """
        self.names['v1'] = self.decay_factor[0] * self.names['v1'] + self.decay_factor[1] * x - self.yy * self.names['v2']
        self.names['v2'] = self.decay_factor[2] * self.names['v2'] + self.decay_factor[3] * x - self.kk * self.names['v1']
        self.v = self.names['v2']
        self.v_s = self.names['v1']

    def neuronal_reset(self, spike_s, spike_l):
        """
        Reset the potentials after firing.

        Soft reset subtracts ``gamma`` from ``v1`` (gated by ``spike_l``) and
        ``v_threshold`` from ``v2`` (gated by ``spike_s``). Hard reset only
        touches ``v2`` .. ``vk`` and uses ``spike_s`` for all of them.
        NOTE(review): ``v1`` is not reset in the hard branch -- confirm that
        this is intended.
        """
        if not self.hard_reset:
            self.names['v1'] = self.jit_soft_reset(self.names['v1'], spike_l, self.gamma)
            self.names['v2'] = self.jit_soft_reset(self.names['v2'], spike_s, self.v_threshold)
        else:
            for i in range(2, self.k + 1):
                self.names['v' + str(i)] = self.jit_hard_reset(self.names['v' + str(i)], spike_s, self.v_reset)

    def forward(self, x: torch.Tensor):
        """Run a single step, regardless of ``step_mode``."""
        return super().single_step_forward(x)

    def extra_repr(self):
        return f"v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, " \
               f"hard_reset={self.hard_reset}, " \
               f"gamma={self.gamma}, k={self.k}, step_mode={self.step_mode}, backend={self.backend}"
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
class GRUCell(nn.Module):
    """
    GRU-style gate computation followed by a ``TSLIFNode`` spiking output.

    The hidden state is re-created as zeros on every call, so any state that
    persists between calls lives in ``self.tslif``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_steps: int = 4,
        grad_slope: float = 25.0,
        beta: float = 0.99,
        output_mems: bool = False,
    ):
        """
        :param input_size: size of the feature dimension of the input
        :param hidden_size: size of the hidden state
        :param num_steps: expected length of the trailing time dimension
        :param grad_slope: accepted for interface compatibility; not read here
        :param beta: stored on the instance
        :param output_mems: stored as ``self.full_rec``
        """
        super().__init__()
        self.spike_grad = surrogate.atan(alpha=2.0)
        self.input_size = input_size
        self.num_steps = num_steps
        self.hidden_size = hidden_size
        self.beta = beta
        self.full_rec = output_mems

        self.linear_ih = nn.Linear(input_size, 3 * hidden_size)
        self.linear_hh = nn.Linear(hidden_size, 3 * hidden_size)
        self.surrogate_function1 = sj_surrogate.ATan()

        self.tslif = TSLIFNode(
            surrogate_function=SG.apply
        )

    def _gated_update(self, y_ih, y_hh, h):
        """Combine the three input/hidden projections into the new hidden state."""
        r = self.surrogate_function1(y_ih[0] + y_hh[0])
        z = self.surrogate_function1(y_ih[1] + y_hh[1])
        n = self.surrogate_function1(y_ih[2] + r * y_hh[2])
        return (1.0 - z) * n + z * h

    def forward(self, inputs):
        """
        Compute the spiking output for ``inputs``.

        Accepted layouts are ``(..., input_size)`` and
        ``(..., input_size, num_steps)``; the first one wins when both match.

        :raises ValueError: if ``inputs`` matches neither layout
        """
        if inputs.size(-1) == self.input_size:
            h = torch.zeros(
                size=[inputs.shape[0], self.hidden_size],
                dtype=torch.float,
                device=inputs.device,
            )
            y_ih = torch.split(self.linear_ih(inputs), self.hidden_size, dim=1)
            y_hh = torch.split(self.linear_hh(h), self.hidden_size, dim=1)
        elif inputs.size(-1) == self.num_steps and inputs.size(-2) == self.input_size:
            inputs = inputs.transpose(-1, -2)  # BC, T, H
            h = torch.zeros(
                size=[inputs.shape[0], self.hidden_size, self.num_steps],
                dtype=torch.float,
                device=inputs.device,
            )
            # The linear layers act on the last dim, so move features last for
            # the projection and back to dim 1 before splitting into gates.
            y_ih = torch.split(
                self.linear_ih(inputs).transpose(-1, -2), self.hidden_size, dim=1
            )
            y_hh = torch.split(
                self.linear_hh(h.transpose(-1, -2)).transpose(-1, -2),
                self.hidden_size,
                dim=1,
            )
        else:
            raise ValueError(
                f"Input size mismatch! Got {inputs.size()} but expected "
                f"(..., {self.input_size}, {self.num_steps}) or (..., {self.input_size})"
            )

        cur = self._gated_update(y_ih, y_hh, h)
        return self.tslif(cur)
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
class DeltaEncoder(nn.Module):
    """Encode first-order differences of a series into spikes."""

    def __init__(self, output_size: int):
        super().__init__()
        self.norm = nn.BatchNorm2d(1)
        self.enc = nn.Linear(1, output_size)
        self.lif = snn.Leaky(
            beta=0.99, spike_grad=SG.apply, init_hidden=True, output=False
        )

    def forward(self, inputs: torch.Tensor):
        """
        :param inputs: tensor laid out as (batch, L, C)
        :return: output of the leaky neuron for a (batch, output_size, C, L) input
        """
        # Difference along L; the first step has no predecessor and stays 0.
        step_change = torch.zeros_like(inputs)
        step_change[:, 1:] = inputs[:, 1:, :] - inputs[:, :-1, :]

        as_image = step_change.unsqueeze(1).permute(0, 1, 3, 2)  # batch, 1, C, L
        normed = self.norm(as_image)
        projected = self.enc(normed.permute(0, 2, 3, 1))  # batch, C, L, output_size
        return self.lif(projected.permute(0, 3, 1, 2))  # batch, output_size, C, L
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
class ConvEncoder(nn.Module):
    """Encode a series into spikes with a convolution along the time axis."""

    def __init__(self, output_size: int, kernel_size: int = 3):
        super().__init__()
        # Kernel height 1 keeps channels independent; the width spans time.
        time_conv = nn.Conv2d(
            in_channels=1,
            out_channels=output_size,
            kernel_size=(1, kernel_size),
            stride=1,
            padding=(0, kernel_size // 2),
        )
        self.encoder = nn.Sequential(time_conv, nn.BatchNorm2d(output_size))
        self.lif = snn.Leaky(
            beta=0.99,
            spike_grad=surrogate.atan(alpha=2.0),
            init_hidden=True,
            output=False,
        )

    def forward(self, inputs: torch.Tensor):
        """
        :param inputs: tensor laid out as (batch, L, C)
        :return: output of the leaky neuron for a (batch, output_size, C, L) input
        """
        as_image = inputs.permute(0, 2, 1).unsqueeze(1)  # batch, 1, C, L
        return self.lif(self.encoder(as_image))
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
class TSGRU(nn.Module):
    """
    Forecasting model: spike encoder, a stack of ``GRUCell`` layers, linear head.

    All sizes are read from ``args`` (``hidden_size``, ``T``, ``feature_size``,
    ``pre_length``, ``blocks``); the like-named constructor arguments are kept
    for interface compatibility but are not read.
    """

    def __init__(
        self,
        args,
        hidden_size: int,
        layers: int = 1,
        num_steps: int = 50,
        grad_slope: float = 25.0,
        input_size: Optional[int] = None,
        max_length: Optional[int] = None,
        weight_file: Optional[Path] = None,
        encoder_type: Optional[str] = "conv",
    ):
        """
        :param args: configuration object providing the sizes listed above and
            a ``normalize`` flag
        :param grad_slope: forwarded to every ``GRUCell``
        :param encoder_type: ``"conv"`` or ``"delta"``
        :raises ValueError: if ``encoder_type`` is not recognised
        """
        super().__init__()

        self.hidden_size = args.hidden_size
        self.num_steps = args.T
        self.input_size = args.feature_size
        self.pre_length = args.pre_length
        self.layers = args.blocks
        self.args = args

        if encoder_type == "conv":
            self.encoder = ConvEncoder(self.hidden_size)
        elif encoder_type == "delta":
            self.encoder = DeltaEncoder(self.hidden_size)
        else:
            raise ValueError(f"Unknown encoder type {encoder_type}")

        self.net = nn.Sequential(
            *[
                GRUCell(
                    self.hidden_size,
                    self.hidden_size,
                    num_steps=self.num_steps,
                    grad_slope=grad_slope,
                    output_mems=(i == self.layers - 1),
                )
                for i in range(self.layers)
            ]
        )

        self.__output_size = self.hidden_size
        self.fc = nn.Linear(self.__output_size, self.pre_length)

        # Keep the original placement on the first GPU, but do not crash at
        # construction time on hosts without CUDA.
        if torch.cuda.is_available():
            self.to('cuda:0')

    def forward(self, inputs: torch.Tensor):
        """
        :param inputs: tensor of size (batch, length, channels)
        :return: ``(preds, aux)`` where ``preds`` has size
            (batch, pre_length, channels) and ``aux`` holds a placeholder entry
        """
        utils.reset(self.encoder)
        for layer in self.net:
            utils.reset(layer)
        # NOTE(review): whether ``utils.reset`` also clears the ``TSLIFNode``
        # memories inside each layer cannot be seen from here -- confirm that
        # they are reset between batches.

        bs, length, c_num = inputs.size()

        if self.args.normalize:
            # Per-sample, per-channel statistics over the time axis.
            mean = inputs.mean(dim=1, keepdim=True).detach()  # shape [B, 1, D]
            inputs = inputs - mean

            std = torch.sqrt(torch.var(inputs, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
            inputs = inputs / std

        h = self.encoder(inputs)
        hidden_size = h.size(1)
        h = h.permute(0, 2, 3, 1).reshape(bs * c_num, length, hidden_size)  # (BC, L, H)

        # Feed the steps one by one; only the output of the last step is used.
        for i in range(length):
            spks = self.net(h[:, i, :])

        spks = spks.reshape(bs, c_num * hidden_size, -1)  # B, CH, Time Step

        spks = spks[:, :, -1]  # keep the last time step, (B, CH)
        # No squeeze here: with pre_length == 1 it would drop the output axis
        # and make the permute below fail; for any other length it is a no-op.
        preds = self.fc(spks.view(bs, c_num, -1))  # B, C, O
        preds = preds.permute(0, 2, 1).contiguous()  # B, O, C
        if self.args.normalize:
            preds = preds * std + mean  # denormalize
        aux = {'gate_l0': torch.tensor(0.0, device=preds.device)}  # placeholder
        return preds, aux

    @property
    def output_size(self):
        """Size of the feature vector fed to the linear head."""
        return self.__output_size
|
model/TS_TCN.py
ADDED
|
@@ -0,0 +1,1030 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Callable
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn.utils import weight_norm
|
| 6 |
+
import snntorch as snn
|
| 7 |
+
from snntorch import surrogate
|
| 8 |
+
from snntorch import utils
|
| 9 |
+
import numpy as np
|
| 10 |
+
import math
|
| 11 |
+
import copy
|
| 12 |
+
from spikingjelly.activation_based import surrogate, neuron
|
| 13 |
+
from abc import abstractmethod
|
| 14 |
+
import warnings
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Replace ``surrogate.atan`` so that every ``surrogate.atan(...)`` call below
# returns ``SG.apply``; the ``alpha`` argument is accepted but ignored.
# ``surrogate`` is the name bound by the last import above (the spikingjelly
# import rebinds the snntorch one). ``SG`` is only looked up when the lambda
# is called -- presumably it is defined later in this file; TODO confirm.
surrogate.atan = lambda alpha=2.0: SG.apply
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` entries along the final axis of a 3-D tensor."""

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """
        :param x: tensor with at least three dimensions
        :return: contiguous tensor without the trailing ``chomp_size`` entries
        """
        # ``x[..., :-0]`` is an empty slice, so a zero chomp has to be
        # handled explicitly to return the input unchanged.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, : -self.chomp_size].contiguous()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Chomp2d(nn.Module):
    """Drop the last ``chomp_size`` entries along the final axis of a 4-D tensor."""

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """
        :param x: tensor with at least four dimensions
        :return: contiguous tensor without the trailing ``chomp_size`` entries
        """
        # ``x[..., :-0]`` is an empty slice, so a zero chomp has to be
        # handled explicitly to return the input unchanged.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :, : -self.chomp_size].contiguous()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class RepeatEncoder(nn.Module):
    """Encode a series into spikes by tiling it ``output_size`` times."""

    def __init__(self, output_size: int):
        super().__init__()
        self.out_size = output_size
        self.lif = snn.Leaky(
            beta=0.99, spike_grad=surrogate.atan(), init_hidden=True, output=False
        )

    def forward(self, inputs: torch.Tensor):
        """
        :param inputs: tensor laid out as (batch, L, C)
        :return: output of the leaky neuron for a (batch, out_size, C, L) input
        """
        # One more repeat factor than input dims: a new leading axis of
        # length out_size is created, all existing axes are kept as they are.
        repeat_factors = (self.out_size,) + (1,) * inputs.dim()
        tiled = inputs.repeat(repeat_factors)  # out_size, batch, L, C
        tiled = tiled.permute(1, 0, 3, 2)  # batch, out_size, C, L
        return self.lif(tiled)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class ConvEncoder(nn.Module):
    """Encode a series into spikes with a convolution along the time axis."""

    def __init__(self, output_size: int, kernel_size: int = 3):
        super().__init__()
        # Kernel height 1 keeps channels independent; the width spans time.
        time_conv = nn.Conv2d(
            in_channels=1,
            out_channels=output_size,
            kernel_size=(1, kernel_size),
            stride=1,
            padding=(0, kernel_size // 2),
        )
        self.encoder = nn.Sequential(time_conv, nn.BatchNorm2d(output_size))
        self.lif = snn.Leaky(
            beta=0.99,
            spike_grad=surrogate.atan(alpha=2.0),
            init_hidden=True,
            output=False,
        )

    def forward(self, inputs: torch.Tensor):
        """
        :param inputs: tensor laid out as (batch, L, C)
        :return: output of the leaky neuron for a (batch, output_size, C, L) input
        """
        as_image = inputs.permute(0, 2, 1).unsqueeze(1)  # batch, 1, C, L
        return self.lif(self.encoder(as_image))
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class DeltaEncoder(nn.Module):
    """Encode first-order differences of a series into spikes."""

    def __init__(self, output_size: int):
        super().__init__()
        self.norm = nn.BatchNorm2d(1)
        self.enc = nn.Linear(1, output_size)
        self.lif = snn.Leaky(
            beta=0.99, spike_grad=surrogate.atan(), init_hidden=True, output=False
        )

    def forward(self, inputs: torch.Tensor):
        """
        :param inputs: tensor laid out as (batch, L, C)
        :return: output of the leaky neuron for a (batch, output_size, C, L) input
        """
        # Difference along L; the first step has no predecessor and stays 0.
        step_change = torch.zeros_like(inputs)
        step_change[:, 1:] = inputs[:, 1:, :] - inputs[:, :-1, :]

        as_image = step_change.unsqueeze(1).permute(0, 1, 3, 2)  # batch, 1, C, L
        normed = self.norm(as_image)
        projected = self.enc(normed.permute(0, 2, 3, 1))  # batch, C, L, output_size
        return self.lif(projected.permute(0, 3, 1, 2))  # batch, output_size, C, L
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# Registry of encoder classes, keyed first by backend name and then by
# encoder type. Both backends expose the same three encoders; each backend
# gets its own inner dict.
SpikeEncoder = {
    backend: {
        "repeat": RepeatEncoder,
        "conv": ConvEncoder,
        "delta": DeltaEncoder,
    }
    for backend in ("snntorch", "spikingjelly")
}
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def generate_ones_and_minus_ones_matrix(rows, cols):
    """
    Return a ``rows`` x ``cols`` float tensor of random -1 / +1 entries.

    Entries are drawn with a single ``torch.randint`` call from the global
    torch random generator.
    """
    bits = torch.randint(0, 2, (rows, cols))
    # Map {0, 1} onto {-1, +1} while still in integer arithmetic.
    signs = 2 * bits - 1
    return signs.float()
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class RandomPE(nn.Module):
    """
    Positional encoding made of a fixed random table of -1 / +1 codes.

    In ``"concat"`` mode the code (``num_pe_neuron`` wide) is appended to the
    feature axis; in ``"add"`` mode the code is ``d_model`` wide and is added
    to the features.
    """

    def __init__(
        self,
        d_model,
        pe_mode="concat",
        num_pe_neuron=10,
        neuron_pe_scale=1000.0,
        dropout=0.1,
        num_steps=4,
    ):
        super().__init__()
        self.max_len = 5000  # different from windows
        self.pe_mode = pe_mode
        self.neuron_pe_scale = neuron_pe_scale
        self.dropout = nn.Dropout(p=dropout)
        if self.pe_mode == "concat":
            self.num_pe_neuron = copy.deepcopy(num_pe_neuron)
        elif self.pe_mode == "add":
            self.num_pe_neuron = copy.deepcopy(d_model)

        table = generate_ones_and_minus_ones_matrix(
            self.max_len, self.num_pe_neuron
        )  # MaxL, Neur
        pe = table.unsqueeze(1)  # MaxL, 1, Neur
        print("pe.shape: ", pe.shape)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        :param x: tensor laid out as (T, B, L, D)
        :return: tensor laid out as (T, B, L, D') after dropout, where D' is
            D plus the code width in ``"concat"`` mode and D otherwise
        """
        T, B, L, _ = x.shape
        # Fold time and length into one position axis: B, TL, D
        flat = x.permute(1, 0, 2, 3).flatten(1, 2)
        positions = flat.size(-2)
        if self.pe_mode == "concat":
            # TL, 1, Neur -> TL, B, Neur -> B, TL, Neur
            codes = self.pe[:positions, :].repeat(1, B, 1).transpose(0, 1)
            flat = torch.concat([flat, codes], dim=-1)  # B, TL, D'
        elif self.pe_mode == "add":
            # [B, TL, D] + [1, TL, Neur]
            flat = flat + self.pe[:positions, :].transpose(0, 1)
        # Unfold the position axis again: TL, B, D -> T, L, B, D -> T, B, L, D
        restored = flat.transpose(0, 1).reshape(T, L, B, -1).permute(0, 2, 1, 3)
        return self.dropout(restored)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class NeuronPE(nn.Module):
    """
    Binary positional encoding obtained by thresholding sinusoids.

    Even code columns are 1 where ``sin(pos * freq) - 0.8 >= 0``, odd columns
    where ``cos(pos * freq) - 0.8 >= 0``, and 0 elsewhere. In ``"concat"``
    mode the code (``num_pe_neuron`` wide) is appended to the feature axis;
    in ``"add"`` mode the code is ``d_model`` wide and is added to the
    features.
    """

    def __init__(
        self,
        d_model,
        pe_mode="concat",
        num_pe_neuron=10,
        neuron_pe_scale=10000.0,
        dropout=0.1,
        num_steps=4,
    ):
        super().__init__()
        self.max_len = 50000  # different from windows
        self.pe_mode = pe_mode
        self.neuron_pe_scale = neuron_pe_scale
        self.dropout = nn.Dropout(p=dropout)
        if self.pe_mode == "concat":
            self.num_pe_neuron = copy.deepcopy(num_pe_neuron)
        elif self.pe_mode == "add":
            self.num_pe_neuron = copy.deepcopy(d_model)

        width = self.num_pe_neuron
        log_step = -math.log(neuron_pe_scale) / width
        position = torch.arange(0, self.max_len, dtype=torch.float).unsqueeze(1)  # MaxL, 1
        # An odd width has one more even column than odd columns, hence the
        # two separately sized frequency vectors.
        even_freq = torch.exp(torch.arange(0, width, 2).float() * log_step)
        odd_freq = torch.exp(torch.arange(0, width - 1, 2).float() * log_step)

        at_zero = torch.tensor([1.0])  # heaviside value where the input is exactly 0
        pe = torch.zeros(self.max_len, width)  # MaxL, Neur
        pe[:, 0::2] = torch.heaviside(torch.sin(position * even_freq) - 0.8, at_zero)
        pe[:, 1::2] = torch.heaviside(torch.cos(position * odd_freq) - 0.8, at_zero)
        pe = pe.unsqueeze(1)  # MaxL, 1, Neur
        print("pe.shape: ", pe.shape)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        :param x: tensor laid out as (T, B, L, D)
        :return: tensor laid out as (T, B, L, D') after dropout, where D' is
            D plus the code width in ``"concat"`` mode and D otherwise
        """
        T, B, L, _ = x.shape
        # Fold time and length into one position axis: B, TL, D
        flat = x.permute(1, 0, 2, 3).flatten(1, 2)
        positions = flat.size(-2)
        if self.pe_mode == "concat":
            # TL, 1, Neur -> TL, B, Neur -> B, TL, Neur
            codes = self.pe[:positions, :].repeat(1, B, 1).transpose(0, 1)
            flat = torch.concat([flat, codes], dim=-1)  # B, TL, D'
        elif self.pe_mode == "add":
            # [B, TL, D] + [1, TL, Neur]
            flat = flat + self.pe[:positions, :].transpose(0, 1)
        # Unfold the position axis again: TL, B, D -> T, L, B, D -> T, B, L, D
        restored = flat.transpose(0, 1).reshape(T, L, B, -1).permute(0, 2, 1, 3)
        return self.dropout(restored)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class StaticPE(nn.Module):
    r"""Fixed sinusoidal positional encoding that is summed onto the input.

    The encoding has the same width as the embeddings so the two can be
    added. Even channels use a sine and odd channels a cosine of the position
    scaled by a geometric frequency progression:

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})
        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{model}})

    where ``pos`` is the position in the sequence and ``i`` indexes the
    channel pair.

    Args:
        d_model: Width of the encoding (last dimension of the input).
        dropout: Dropout probability applied after the sum.
        max_len: Number of positions precomputed in the ``pe`` buffer.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # MaxL, 1
        sin_freq = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        # Shorter by one entry than sin_freq when d_model is odd.
        cos_freq = torch.exp(
            torch.arange(0, d_model - 1, 2).float() * (-math.log(10000.0) / d_model)
        )

        table = torch.zeros(max_len, d_model)  # MaxL, D
        table[:, 0::2] = torch.sin(positions * sin_freq)
        table[:, 1::2] = torch.cos(positions * cos_freq)
        table = table.unsqueeze(0).transpose(0, 1)  # MaxL, 1, D
        self.register_buffer("pe", table)

    def forward(self, x):
        """Add the first ``x.size(0)`` encoding rows to ``x`` and apply dropout.

        ``x`` is indexed position-first: its leading dimension selects the
        rows of the table and the middle dimension is broadcast.
        """
        encoded = x + self.pe[: x.size(0), :]
        return self.dropout(encoded)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class ConvPE(nn.Module):
    """Convolutional position encoding with a spiking activation.

    A kernel-3 ``Conv1d`` followed by ``BatchNorm1d`` and a multi-step LIF
    neuron produces a feature map that is added back onto the input as a
    residual.

    Args:
        d_model: Channel count of the convolution and batch norm.
        dropout: Dropout probability applied to the spiking features.
        max_len: Accepted for signature parity; not used by this class.
        num_steps: Number of simulation steps ``T`` folded into the second
            input dimension.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, num_steps=4):
        super().__init__()
        self.T = num_steps
        self.rpe_conv = nn.Conv1d(
            d_model,
            d_model,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.rpe_bn = nn.BatchNorm1d(d_model)
        self.rpe_lif = neuron.LIFNode(
            step_mode="m",
            detach_reset=True,
            surrogate_function=surrogate.ATan(),
            v_threshold=1.0,
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return ``x`` plus its convolutional spiking features.

        ``x`` has shape ``(L, TB, D)``; the middle dimension is split into
        ``(T, TB / T)`` before the LIF neuron and merged again afterwards.
        """
        L, TB, D = x.shape
        feat = self.rpe_conv(x.permute(1, 2, 0))  # TB, D, L
        feat = self.rpe_bn(feat)
        feat = feat.reshape(self.T, int(TB / self.T), D, L).contiguous()  # T, B, D, L
        spikes = self.rpe_lif(feat).flatten(0, 1)  # TB, D, L
        spikes = self.dropout(spikes)
        return x + spikes.permute(2, 0, 1)  # L, TB, D
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class PositionEmbedding(nn.Module):
    """Dispatching wrapper around the position encodings in this file.

    Args:
        input_size: Width of the last input dimension.
        pe_type: One of ``"learn"``, ``"none"``, ``"conv"``, ``"static"``,
            ``"neuron"`` or ``"random"``.
        max_len: Maximum number of positions for the learned, conv and
            static encodings.
        pe_mode: ``"add"`` or ``"concat"``; forwarded to the neuron and
            random encodings.
        num_pe_neuron: Encoding width for the neuron/random encodings.
        neuron_pe_scale: Frequency base for the neuron/random encodings.
        dropout: Dropout probability forwarded to the wrapped encoding.
        num_steps: Number of simulation steps forwarded to the wrapped
            encoding.

    Raises:
        ValueError: If ``pe_type`` is not one of the supported names.
    """

    def __init__(
        self,
        input_size: int,
        pe_type: str,
        max_len: int = 5000,
        pe_mode: str = "add",
        num_pe_neuron: int = 10,
        neuron_pe_scale: float = 1000.0,
        dropout=0.1,
        num_steps=4,
    ):
        super().__init__()
        self.emb_type = pe_type
        if pe_type in ["learn", "none"]:
            self.emb = nn.Embedding(max_len, input_size)
        elif pe_type == "conv":
            self.emb = ConvPE(
                d_model=input_size,
                max_len=max_len,
                dropout=dropout,
                num_steps=num_steps,
            )
        elif pe_type == "static":
            self.emb = StaticPE(d_model=input_size, max_len=max_len, dropout=dropout)
        elif pe_type == "neuron":
            self.emb = NeuronPE(
                d_model=input_size,
                pe_mode=pe_mode,
                num_pe_neuron=num_pe_neuron,
                neuron_pe_scale=neuron_pe_scale,
                dropout=dropout,
                num_steps=num_steps,
            )
        elif pe_type == "random":
            self.emb = RandomPE(
                d_model=input_size,
                pe_mode=pe_mode,
                num_pe_neuron=num_pe_neuron,
                neuron_pe_scale=neuron_pe_scale,
                dropout=dropout,
                num_steps=num_steps,
            )
        else:
            raise ValueError("Unknown embedding type: {}".format(pe_type))

    def forward(self, x):
        """Apply the configured encoding.

        For ``"learn"`` the sequence length is read from the second-to-last
        dimension, so both ``(TB, L, D)`` and ``(T, B, L, D)`` inputs are
        accepted. The other active types expect ``(T, B, L, D)``. With
        ``"none"`` the input is returned unchanged.
        """
        if self.emb_type == "learn":
            # Index positions along the sequence axis (second to last) and
            # let broadcasting cover every leading dimension. The previous
            # code read size(1)/size(0), which is only right for 3-D input.
            positions = torch.arange(
                end=x.size(-2), device=x.device
            )  # [0, 1, ..., L-1]
            x = x + self.emb(positions)  # (..., L, D) + (L, D)
        elif self.emb_type in ["static", "conv"]:
            T, B, L, _ = x.shape  # x: T, B, L, D
            x = x.flatten(0, 1)  # TB, L, D
            x = self.emb(x.transpose(0, 1)).transpose(0, 1)  # TB, L, D'
            x = x.reshape(T, B, L, -1)
        elif self.emb_type in ["neuron", "random"]:
            T, B, L, _ = x.shape  # x: T, B, L, D
            x = self.emb(x)
            x = x.reshape(T, B, L, -1)
        return x  # T, B, L, D'
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
@torch.jit.script
def heaviside(x: torch.Tensor):
    """Return 1 where ``x >= 0`` and 0 elsewhere, with the dtype and device of ``x``."""
    fired = x >= 0
    return fired.to(x)
|
| 382 |
+
|
| 383 |
+
@torch.jit.script
def atan_backward(grad_output: torch.Tensor, x: torch.Tensor, alpha: float):
    """Backward pass of the arctan surrogate gradient.

    Scales ``grad_output`` by ``alpha / 2 / (1 + (pi / 2 * alpha * x) ** 2)``.
    The second element of the returned pair is always ``None``.
    """
    scaled = math.pi / 2 * alpha * x
    denominator = 1 + scaled.pow(2)
    grad_x = alpha / 2 / denominator * grad_output
    return grad_x, None
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
class SG(torch.autograd.Function):
    """Spike function: Heaviside step forward, arctan surrogate backward."""

    @staticmethod
    def forward(ctx, x, alpha=2.0):
        """Return ``heaviside(x)``; stash ``x`` and ``alpha`` when a gradient is needed."""
        needs_grad = x.requires_grad
        if needs_grad:
            ctx.save_for_backward(x)
            ctx.alpha = alpha
        return heaviside(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the surrogate gradient for ``x`` and ``None`` for ``alpha``."""
        (saved_x,) = ctx.saved_tensors
        return atan_backward(grad_output, saved_x, ctx.alpha)
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class MemoryModule(nn.Module):
    """Base class for stateful (memory-carrying) modules.

    State variables live in ``self._memories`` and their reset values in
    ``self._memories_rv``. Attribute access, assignment and deletion are
    routed through those dictionaries for every registered name.
    """

    def __init__(self):
        """Create the module with empty memory and reset-value dictionaries."""
        super().__init__()
        self._memories = {}
        self._memories_rv = {}

    def register_memory(self, name: str, value):
        """Register a stateful variable.

        :param name: variable's name
        :type name: str
        :param value: variable's value
        :type value: any

        The variable is stored in the memory dict (which holds stateful
        variables such as the membrane potential of a spiking neuron) and
        ``value`` becomes its reset value: ``self.name`` is set back to
        ``value`` after each call of ``self.reset()``.
        """
        assert not hasattr(self, name), f'{name} has been set as a member variable!'
        self._memories[name] = value
        self.set_reset_value(name, value)

    def reset(self):
        """Reset all stateful variables to their default values."""
        # Entries are overwritten in place so that any other reference to
        # the same dictionary object keeps seeing the current state.
        for key in self._memories:
            self._memories[key] = copy.deepcopy(self._memories_rv[key])

    def set_reset_value(self, name: str, value):
        """Store a deep copy of ``value`` as the reset value of ``name``."""
        self._memories_rv[name] = copy.deepcopy(value)

    def __getattr__(self, name: str):
        # Only reached when normal attribute lookup fails; serve registered
        # memories first, then defer to nn.Module.
        memories = self.__dict__.get('_memories', {})
        if name in memories:
            return memories[name]
        return super().__getattr__(name)

    def __setattr__(self, name: str, value) -> None:
        memories = self.__dict__.get('_memories')
        if memories is None or name not in memories:
            super().__setattr__(name, value)
            return
        memories[name] = value

    def __delattr__(self, name):
        if name not in self._memories:
            return super().__delattr__(name)
        del self._memories[name]
        del self._memories_rv[name]

    def __dir__(self):
        keys = (
            dir(self.__class__)
            + list(self.__dict__.keys())
            + list(self._parameters.keys())
            + list(self._modules.keys())
            + list(self._buffers.keys())
            + list(self._memories.keys())
        )
        # Eliminate attrs that are not legal Python variable names
        return sorted(key for key in keys if not key[0].isdigit())

    def memories(self):
        """Iterate over all stateful variables.

        :return: an iterator over all stateful variables
        :rtype: Iterator
        """
        yield from self._memories.values()

    def named_memories(self):
        """Iterate over all stateful variables together with their names.

        :return: an iterator over all stateful variables and their names
        :rtype: Iterator
        """
        yield from self._memories.items()

    def detach(self):
        """Detach all stateful tensor variables from the computation graph.

        Tip: this can be used to implement TBPTT (Truncated Back Propagation
        Through Time).
        """
        for value in self._memories.values():
            if isinstance(value, torch.Tensor):
                value.detach_()

    def _apply(self, fn):
        # Apply ``fn`` to tensor-valued memories as well, then let nn.Module
        # handle parameters and buffers.
        for key, value in self._memories.items():
            if isinstance(value, torch.Tensor):
                self._memories[key] = fn(value)
        return super()._apply(fn)

    def _replicate_for_data_parallel(self):
        # Give each replica its own (shallow) copy of the memory dict.
        replica = super()._replicate_for_data_parallel()
        replica._memories = self._memories.copy()
        return replica
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
class StepModule:
    """Mixin that gives a module a validated ``step_mode`` attribute."""

    def supported_step_mode(self):
        """Return the step modes supported by this module.

        :return: a tuple that contains the supported step modes
        :rtype: tuple[str]
        """
        return ('s', 'm')

    @property
    def step_mode(self):
        """Return the step mode currently used by this module.

        :return: the current step mode of this module
        :rtype: str
        """
        return self._step_mode

    @step_mode.setter
    def step_mode(self, value: str):
        """Set the step mode of this module to ``value``.

        :param value: the step mode
        :type value: str
        :raises ValueError: if ``value`` is not a supported step mode
        """
        allowed = self.supported_step_mode()
        if value in allowed:
            self._step_mode = value
            return
        raise ValueError(f'step_mode can only be {self.supported_step_mode()}, but got "{value}"!')
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
class BaseNode(MemoryModule):
    """Base spiking neuron with two membrane potentials, ``v`` and ``v_s``.

    Each step charges the neuron, fires one spike train from ``v`` and one
    from ``v_s``, mixes them with the learnable weights ``alpha_s`` and
    ``alpha_l``, and then resets. Subclasses must implement
    ``neuronal_charge(x)`` and ``neuronal_reset(s_s, s_l)``.

    Args:
        v_threshold: Firing threshold (must be a ``float``).
        v_reset: Reset potential (``float``) or ``None``; also used as the
            initial value of both potentials (0 when ``None``).
        surrogate_function: Callable invoked as ``f(v - v_threshold, 2.0)``
            to produce spikes.
        detach_reset: Stored on the module; not consulted by this base class.
        step_mode: Stored on the module.
        backend: Stored on the module.
        store_v_seq: When true, ``multi_step_forward`` keeps the per-step
            ``v`` values in the ``v_seq`` memory.
    """

    def __init__(self,
                 v_threshold: float = 1.,
                 v_reset: float = 0.,
                 surrogate_function: Callable = None,
                 detach_reset: bool = False,
                 step_mode='s', backend='torch',
                 store_v_seq: bool = True):
        assert isinstance(v_reset, float) or v_reset is None
        assert isinstance(v_threshold, float)
        assert isinstance(detach_reset, bool)
        super().__init__()

        # Register both potentials in every configuration. Previously v_s
        # was only a memory when v_reset was None, so with a float v_reset
        # it was never cleared by reset() nor handled by _apply().
        initial_v = 0. if v_reset is None else v_reset
        self.register_memory('v', initial_v)
        self.register_memory('v_s', initial_v)

        self.v_threshold = v_threshold

        self.v_reset = v_reset
        self.detach_reset = detach_reset
        self.surrogate_function = surrogate_function

        self.step_mode = step_mode
        self.backend = backend

        self.store_v_seq = store_v_seq

        # Learnable mixing weights of the two spike trains.
        self.alpha_s = torch.nn.Parameter(torch.tensor(0.5, dtype=torch.float))
        self.alpha_l = torch.nn.Parameter(torch.tensor(0.5, dtype=torch.float))

    @property
    def store_v_seq(self):
        """Whether ``multi_step_forward`` records the membrane potential."""
        return self._store_v_seq

    @store_v_seq.setter
    def store_v_seq(self, value: bool):
        self._store_v_seq = value
        if value:
            # Lazily create the memory slot the first time recording is on.
            if not hasattr(self, 'v_seq'):
                self.register_memory('v_seq', None)

    @staticmethod
    @torch.jit.script
    def jit_hard_reset(v: torch.Tensor, spike: torch.Tensor, v_reset: float):
        """Set ``v`` to ``v_reset`` where ``spike`` is 1, keep it where 0."""
        v = (1. - spike) * v + spike * v_reset

        return v

    @staticmethod
    @torch.jit.script
    def jit_soft_reset(v: torch.Tensor, spike: torch.Tensor, v_threshold: float):
        """Subtract ``v_threshold`` from ``v`` wherever ``spike`` is 1."""
        v = v - spike * v_threshold
        return v

    @abstractmethod
    def neuronal_charge(self, x: torch.Tensor):
        """Update the membrane potentials from input ``x`` (subclass hook)."""
        raise NotImplementedError

    def neuronal_fire(self):
        """Return the spikes obtained from ``v`` alone."""
        return self.surrogate_function(self.v - self.v_threshold, 2.0)

    def sl_neuronal_fire(self):
        """Return the pair of spikes obtained from ``v`` and from ``v_s``."""
        s_s = self.surrogate_function(self.v - self.v_threshold, 2.0)
        s_l = self.surrogate_function(self.v_s - self.v_threshold, 2.0)
        return s_s, s_l

    def extra_repr(self):
        return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, step_mode={self.step_mode}, backend={self.backend}'

    def single_step_forward(self, x: torch.Tensor):
        """Run one charge / fire / reset cycle and return the mixed spikes."""
        self.v_float_to_tensor(x)
        self.neuronal_charge(x)
        s_s, s_l = self.sl_neuronal_fire()
        spike = self.alpha_s * s_s + self.alpha_l * s_l
        self.neuronal_reset(s_s, s_l)

        return spike

    def multi_step_forward(self, x_seq: torch.Tensor):
        """Run ``single_step_forward`` over the last axis of ``x_seq``.

        Step ``t`` consumes ``x_seq[:, t]``. The stacked outputs are
        rearranged with ``permute(1, 0)``, so each per-step output must be
        1-D.
        """
        T = x_seq.shape[-1]
        y_seq = []
        if self.store_v_seq:
            v_seq = []
        for t in range(T):
            y = self.single_step_forward(x_seq[:, t])
            y_seq.append(y)
            if self.store_v_seq:
                v_seq.append(self.v)
        if self.store_v_seq:
            self.v_seq = torch.stack(v_seq)

        outputs = torch.stack(y_seq, dim=0).permute(1, 0)

        return outputs

    def v_float_to_tensor(self, x: torch.Tensor):
        """Expand a scalar ``v`` into a tensor shaped like ``x``."""
        if isinstance(self.v, float):
            v_init = self.v
            self.v = torch.full_like(x.data, v_init)
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
class TSLIFNode(BaseNode):
    """LIF neuron with two coupled membrane potentials ``v1`` and ``v2``.

    ``v1`` and ``v2`` are leaky integrators of the input that inhibit each
    other through the learnable gains ``yy`` and ``kk``. After charging,
    ``v`` aliases ``v2`` and ``v_s`` aliases ``v1``; the base class fires one
    spike train from each.

    Args:
        v_threshold: Firing threshold, also the soft-reset amount for ``v2``.
        v_reset: Reset potential used in hard-reset mode.
        surrogate_function: Spike function, called as ``f(v - thr, 2.0)``.
        detach_reset: When true, spikes are detached before being used in
            the reset so no gradient flows through the reset path.
        hard_reset: Use hard reset instead of soft reset.
        step_mode: Stored on the module; ``forward`` always runs one step.
        k: Number of ``v<i>`` memories to register.
        decay_factor: Four initial coefficients: leak and input gain of
            ``v1``, then leak and input gain of ``v2``. The values are copied.
        gamma: Soft-reset amount for ``v1``.
    """

    def __init__(self,
                 v_threshold=1.0,
                 v_reset=0.,
                 surrogate_function: Callable = None,
                 detach_reset=False,
                 hard_reset=False,
                 step_mode='s',
                 k=2,
                 decay_factor: torch.Tensor = torch.tensor([0.8, 0.2, 0.3, 0.7], dtype=torch.float),
                 gamma: float = 0.5):
        super(TSLIFNode, self).__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode)
        self.k = k
        for i in range(1, self.k + 1):
            self.register_memory('v' + str(i), 0.)
        # Alias of the memory dict, used for direct keyed access below.
        self.names = self._memories
        self.hard_reset = hard_reset
        self.gamma = gamma
        # Copy the tensor: the default argument is a single object created
        # at definition time, and nn.Parameter shares storage with the
        # tensor it wraps, so without the clone every default-constructed
        # neuron would train the same underlying values.
        self.decay_factor = torch.nn.Parameter(decay_factor.detach().clone())
        self.kk = torch.nn.Parameter(torch.tensor([0.8], dtype=torch.float))
        self.yy = torch.nn.Parameter(torch.tensor([0.1], dtype=torch.float))

    @property
    def supported_backends(self):
        """Backends available for the current step mode."""
        if self.step_mode == 's':
            return ('torch',)
        elif self.step_mode == 'm':
            return ('torch', 'cupy')
        else:
            raise ValueError(self.step_mode)

    def neuronal_charge(self, x: torch.Tensor):
        """Integrate ``x`` into ``v1`` and ``v2``.

        The order matters: ``v1`` is updated from the previous ``v2``, and
        ``v2`` is then updated from the already-updated ``v1``.
        """
        self.names['v1'] = self.decay_factor[0] * self.names['v1'] + self.decay_factor[1] * x - self.yy * self.names['v2']
        self.names['v2'] = self.decay_factor[2] * self.names['v2'] + self.decay_factor[3] * x - self.kk * self.names['v1']
        self.v = self.names['v2']
        self.v_s = self.names['v1']

    def neuronal_reset(self, spike_s, spike_l):
        """Reset the potentials after firing.

        Soft reset subtracts ``gamma`` from ``v1`` where ``spike_l`` fired
        and ``v_threshold`` from ``v2`` where ``spike_s`` fired. Hard reset
        sets ``v2`` .. ``v<k>`` to ``v_reset`` where ``spike_s`` fired and
        leaves ``v1`` untouched.
        """
        if self.detach_reset:
            # Honour the constructor flag: cut the gradient through the
            # reset path. It was previously accepted but ignored.
            spike_s = spike_s.detach()
            spike_l = spike_l.detach()
        if not self.hard_reset:
            self.names['v1'] = self.jit_soft_reset(self.names['v1'], spike_l, self.gamma)
            self.names['v2'] = self.jit_soft_reset(self.names['v2'], spike_s, self.v_threshold)
        else:
            for i in range(2, self.k + 1):
                self.names['v' + str(i)] = self.jit_hard_reset(self.names['v' + str(i)], spike_s, self.v_reset)

    def forward(self, x: torch.Tensor):
        """Run a single simulation step on ``x`` and return the mixed spikes."""
        return super().single_step_forward(x)

    def extra_repr(self):
        return f"v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, " \
               f"hard_reset={self.hard_reset}, " \
               f"gamma={self.gamma}, k={self.k}, step_mode={self.step_mode}, backend={self.backend}"
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
|
| 804 |
+
class SpikeTemporalBlock(nn.Module):
    """Residual temporal-convolution block with spiking activations.

    Two weight-normalised ``(1, kernel_size)`` convolutions, each followed by
    batch norm, a ``Chomp2d`` trim and a ``TSLIFNode`` whose output is
    averaged over ``num_steps`` presentations of the same input. The result
    is summed with a (possibly 1x1-projected) copy of the block input and
    passed through a third spiking neuron.

    Args:
        n_inputs: Input channel count.
        n_outputs: Output channel count.
        kernel_size: Kernel width along the last axis.
        stride: Convolution stride.
        dilation: Dilation along the last axis.
        padding: Padding along the last axis; also passed to ``Chomp2d``.
        num_steps: Number of times each neuron is stepped per forward pass.
    """

    def __init__(
        self,
        n_inputs,
        n_outputs,
        kernel_size,
        stride,
        dilation,
        padding,
        num_steps=4,
    ):
        super().__init__()
        self.num_steps = num_steps

        conv_kwargs = dict(
            stride=stride,
            padding=(0, padding),
            dilation=(1, dilation),
        )

        self.conv1 = weight_norm(
            nn.Conv2d(n_inputs, n_outputs, (1, kernel_size), **conv_kwargs)
        )
        self.bn1 = nn.BatchNorm2d(n_outputs)
        self.chomp1 = Chomp2d(padding)
        self.tslif1 = TSLIFNode(surrogate_function=SG.apply)

        self.conv2 = weight_norm(
            nn.Conv2d(n_outputs, n_outputs, (1, kernel_size), **conv_kwargs)
        )
        self.bn2 = nn.BatchNorm2d(n_outputs)
        self.chomp2 = Chomp2d(padding)
        self.tslif2 = TSLIFNode(surrogate_function=SG.apply)

        # A 1x1 projection is only needed when the channel count changes.
        if n_inputs != n_outputs:
            self.downsample = nn.Conv2d(n_inputs, n_outputs, (1, 1))
        else:
            self.downsample = None

        self.tslif = TSLIFNode(surrogate_function=SG.apply)

    def init_weights(self):
        """Re-draw the convolution weights from ``N(0, 0.01)``."""
        convs = [self.conv1, self.conv2]
        if self.downsample is not None:
            convs.append(self.downsample)
        for conv in convs:
            conv.weight.data.normal_(0, 0.01)

    def _spike_rate(self, node, current):
        """Step ``node`` ``num_steps`` times on ``current`` and average the spikes."""
        recorded = [node(current) for _ in range(self.num_steps)]
        return torch.stack(recorded, dim=-1).mean(-1)  # B, H, C, L

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor with ``n_outputs`` channels."""
        out1 = self.chomp1(self.bn1(self.conv1(x)))
        spks1 = self._spike_rate(self.tslif1, out1)

        out2 = self.chomp2(self.bn2(self.conv2(spks1)))
        spks2 = self._spike_rate(self.tslif2, out2)

        if torch.isnan(spks2).any() or torch.isinf(spks2).any():
            print("illegal value in TemporalBlock2D")

        shortcut = x if self.downsample is None else self.downsample(x)

        return self._spike_rate(self.tslif, spks2 + shortcut)
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
class TSTCN(nn.Module):
    """Spiking temporal convolutional network for multivariate forecasting.

    Pipeline: optional per-series normalisation, spike encoder, optional
    position embedding, a stack of ``SpikeTemporalBlock`` layers ending in a
    single channel, then two linear layers that map the feature axis to
    ``args.feature_size`` and the time axis to ``args.pre_length``.

    Args:
        args: Namespace providing ``hidden_size``, ``T``, ``feature_size``,
            ``pre_length``, ``blocks``, ``kernel_size``, ``seq_length`` and
            ``normalize``.
        num_levels: Not used; the number of levels is read from
            ``args.blocks``.
        channel: Channel count of every temporal block except the last.
        dilation: Base of the per-level dilation (``dilation ** level``).
        stride: Convolution stride of the temporal blocks.
        kernel_size: Not used; the kernel size is read from
            ``args.kernel_size``.
        dropout: Not used by this class.
        max_length: ``max_len`` forwarded to the position embedding.
        encoder_type: Key into ``SpikeEncoder["snntorch"]``.
        pe_type: Position-embedding type; ``"none"`` skips the embedding.
        pe_mode: ``"concat"`` or ``"add"`` for neuron/random embeddings.
        num_pe_neuron: Extra feature channels added in ``"concat"`` mode.
        neuron_pe_scale: Frequency base forwarded to the position embedding.
    """

    def __init__(
        self,
        args,
        num_levels: int = 3,
        channel: int = 16,
        dilation: int = 2,
        stride: int = 1,
        kernel_size: int = 2,
        dropout: float = 0.2,
        max_length: int = 100,
        encoder_type: str = "conv",
        pe_type: str = "neuron",
        pe_mode: str = "concat",
        num_pe_neuron: int = 40,
        neuron_pe_scale: float = 1000.0,
    ):
        super().__init__()

        self.hidden_size = args.hidden_size
        self.num_steps = args.T
        self.input_size = args.feature_size
        self.feature_size = args.feature_size
        self.pre_length = args.pre_length
        self.num_levels = args.blocks
        self.pe_type = pe_type
        self.pe_mode = pe_mode
        self.num_pe_neuron = num_pe_neuron
        self.kernel_size = args.kernel_size
        self.args = args

        self._snn_backend = "snntorch"
        self.encoder = SpikeEncoder[self._snn_backend][encoder_type](self.hidden_size)

        self.pe = PositionEmbedding(
            pe_type=pe_type,
            pe_mode=pe_mode,
            neuron_pe_scale=neuron_pe_scale,
            input_size=self.input_size,
            max_len=max_length,
            num_pe_neuron=self.num_pe_neuron,
            dropout=0.1,
            num_steps=self.num_steps,
        )

        layers = []
        # `channel` features per level, then one final block that collapses
        # to a single channel.
        num_channels = [channel] * self.num_levels
        num_channels.append(1)
        for i in range(self.num_levels + 1):
            dilation_size = dilation**i
            in_channels = self.hidden_size if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            layers += [
                SpikeTemporalBlock(
                    in_channels,
                    out_channels,
                    self.kernel_size,
                    stride=stride,
                    dilation=dilation_size,
                    padding=(self.kernel_size - 1) * dilation_size,
                    num_steps=self.num_steps,
                )
            ]

        self.network = nn.Sequential(*layers)

        if (self.pe_type == "neuron" and self.pe_mode == "concat") or (
            self.pe_type == "random" and self.pe_mode == "concat"
        ):
            self.__output_size = self.feature_size + num_pe_neuron
        else:
            # fc1 consumes the feature axis, which keeps its original width
            # when no PE channels are concatenated. This was previously
            # args.seq_length, which only matched when the two were equal.
            self.__output_size = self.feature_size

        self.fc1 = nn.Linear(self.__output_size, args.feature_size)
        self.fc2 = nn.Linear(args.seq_length, self.pre_length)
        # Keep the historical default placement on the first GPU, but do not
        # fail construction on hosts without CUDA.
        if torch.cuda.is_available():
            self.to('cuda:0')

    def forward(self, inputs: torch.Tensor):
        """Forecast ``pre_length`` steps from ``inputs`` of shape ``(B, L, D)``.

        Returns:
            ``(preds, aux)`` where ``preds`` is the forecast (de-normalised
            when ``args.normalize`` is set) and ``aux`` is a dict holding a
            placeholder ``'gate_l0'`` scalar.
        """
        utils.reset(self.encoder)

        if self.args.normalize:
            mean = inputs.mean(dim=1, keepdim=True).detach()  # shape [B, 1, D]
            inputs = inputs - mean

            std = torch.sqrt(torch.var(inputs, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
            inputs = inputs / std

        inputs = self.encoder(inputs)  # B, H, C, L

        if self.pe_type != "none":
            # The position embedding works on (H, B, L, C); permute there
            # and back.
            inputs = self.pe(inputs.permute(1, 0, 3, 2)).permute(1, 0, 3, 2)

        spks = self.network(inputs)
        spks = spks.squeeze(1)  # B, C', L
        preds = self.fc1(spks.permute(0, 2, 1))  # B, L, C
        preds = self.fc2(preds.permute(0, 2, 1))  # B, C, pre_length
        preds = preds.permute(0, 2, 1).contiguous()
        if self.args.normalize:
            preds = preds * std + mean  # denormalize

        # Create auxiliary output
        aux = {'gate_l0': torch.tensor(0.0, device=preds.device)}  # placeholder

        return preds, aux

    @property
    def output_size(self):
        """Input width of ``fc1`` (size of the feature axis after embedding)."""
        return self.__output_size
|
model/iSpikformer.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from spikingjelly.clock_driven.neuron import MultiStepLIFNode
|
| 4 |
+
|
| 5 |
+
class SPE(nn.Module):
    """Spiking patch embedding.

    Splits the time axis into ``patch_num`` patches, projects each patch to
    ``patch_dim`` values, repeats the result over ``T`` simulation steps and
    passes it through batch norm and a multi-step LIF neuron.

    Args:
        input_len: Length of the input time axis.
        patch_num: Number of patches the time axis is divided into.
        patch_dim: Output width of the patch projection.
        T: Number of simulation steps.
        tau: Time constant of the LIF neuron.
        D: Number of variables; stored on the module.
    """

    def __init__(self, input_len, patch_num, patch_dim, T, tau, D):
        super().__init__()
        patch_len = input_len // patch_num
        self.patch_projector = nn.Linear(patch_len, patch_dim)
        self.bn = nn.BatchNorm2d(patch_dim)
        self.encoder_lif = MultiStepLIFNode(tau=tau, detach_reset=False, backend='torch')

        self.D = D
        self.T = T
        self.patch_dim = patch_dim
        self.patch_num = patch_num

    def forward(self, x):
        """Embed ``x`` of shape ``(B, L, D)`` into ``(T, B, patch_dim, patch_num, D)``."""
        B, L, D = x.shape
        pn, pd, steps = self.patch_num, self.patch_dim, self.T

        patches = x.view(B, pn, L // pn, D).contiguous()  # B, pn, L/pn, D
        patches = patches.transpose(-1, -2).contiguous()  # B, pn, D, L/pn
        projected = self.patch_projector(patches)  # B, pn, D, pd

        # Present the same projection at every simulation step.
        repeated = projected.repeat(steps, 1, 1, 1, 1)  # T, B, pn, D, pd
        repeated = repeated.permute(0, 1, 4, 2, 3).contiguous()  # T, B, pd, pn, D

        normed = self.bn(repeated.flatten(0, 1))  # TB, pd, pn, D
        normed = normed.view(steps, B, pd, pn, D)
        return self.encoder_lif(normed)
|
| 31 |
+
|
| 32 |
+
class iSSA(nn.Module):
    """Spiking self-attention across variables.

    Query, key and value are linear maps over the patch axis, each followed
    by batch norm and a LIF neuron. Attention is computed as ``(q @ k) @ v``
    without a softmax, then normalised and passed through a fourth LIF
    neuron.

    Args:
        patch_num: Size of the patch axis (in/out width of the projections).
        D: Number of variables; not used by this class.
        patch_dim: Channel count of the batch-norm layers.
        tau: Time constant of the LIF neurons.
        alpha: Not used by this class.
    """

    def __init__(self, patch_num, D, patch_dim, tau, alpha):
        super().__init__()
        self.lin1 = nn.Linear(patch_num, patch_num)
        self.lin2 = nn.Linear(patch_num, patch_num)
        self.lin3 = nn.Linear(patch_num, patch_num)

        self.lif1 = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')
        self.lif2 = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')
        self.lif3 = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')
        self.lif4 = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')

        self.b1 = nn.BatchNorm2d(patch_dim)
        self.b2 = nn.BatchNorm2d(patch_dim)
        self.b3 = nn.BatchNorm2d(patch_dim)
        self.b4 = nn.BatchNorm2d(patch_dim)

    def forward(self, x):
        """Attend over ``x`` of shape ``(T, B, patch_dim, patch_num, D)``; shape is preserved."""
        T, B, pd, pn, D = x.shape

        tokens = x.transpose(-1, -2).contiguous()  # T, B, pd, D, pn

        # Project, normalise over the merged (T, B) batch, restore the time
        # axis, then spike.
        q = self.b1(self.lin1(tokens).flatten(0, 1)).view(T, B, pd, D, -1)
        k = self.b2(self.lin2(tokens).flatten(0, 1)).view(T, B, pd, D, -1)
        v = self.b3(self.lin3(tokens).flatten(0, 1)).view(T, B, pd, D, -1)

        q = self.lif1(q)
        k = self.lif2(k).transpose(-1, -2).contiguous()  # T, B, pd, pn, D
        v = self.lif3(v)

        scores = q @ k  # T, B, pd, D, D
        mixed = (scores @ v).flatten(0, 1)  # TB, pd, D, pn
        mixed = self.b4(mixed).view(T, B, pd, D, pn)
        mixed = self.lif4(mixed)
        return mixed.transpose(-1, -2).contiguous()  # T, B, pd, pn, D
|
| 80 |
+
|
| 81 |
+
class iSpikformer(nn.Module):
    """Spiking forecaster: SPE embedding, stacked iSSA blocks and a dense head.

    Args:
        args: namespace read for ``normalize`` (per-series standardisation of
            the input and de-standardisation of the output) and, when
            present, ``device``.
        input_len, patch_num, patch_dim, T, tau, D: forwarded to ``SPE``.
        blocks: number of ``iSSA`` blocks.
        pred_len: output features of the final linear layer.
        alpha: forwarded to ``iSSA``.
        hidden_dim: width of the hidden dense layer.
    """

    def __init__(self, args, input_len, patch_num, patch_dim, T, blocks, D, pred_len, tau, alpha, hidden_dim):
        super().__init__()
        self.emb = SPE(input_len, patch_num, patch_dim, T, tau, D)
        self.args = args
        self.attn = nn.ModuleList(
            [iSSA(patch_num, D, patch_dim, tau, alpha) for _ in range(blocks)]
        )

        self.dense1 = nn.Linear(patch_num * patch_dim, hidden_dim)
        self.dense2 = nn.Linear(hidden_dim, pred_len)
        self.bn = nn.BatchNorm1d(D)
        self.activ = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')

        # Honour the configured device instead of always forcing 'cuda:0';
        # the previous target stays the default when args carries no device.
        self.to(getattr(args, 'device', 'cuda:0'))

    def forward(self, x):
        """Forecast from ``x`` of shape (B, L, D).

        Returns:
            A tuple ``(forecast, aux)`` where ``forecast`` is the average over
            the T axis, shape (B, pred_len, D), and ``aux`` holds a zero
            placeholder under the key ``'gate_l0'``.
        """
        normalize = self.args.normalize
        if normalize:
            # Standardise every series over the length axis; the statistics
            # are detached so no gradient flows through them.
            mean = x.mean(dim=1, keepdim=True).detach()
            x = x - mean
            std = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
            x = x / std

        x = self.emb(x)
        T, B, _, _, D = x.shape

        for block in self.attn:
            x = block(x)

        # (T, B, pd, pn, D) -> (T, B, D, pd * pn)
        x = x.permute(0, 1, 4, 2, 3).contiguous()
        x = x.flatten(-2, -1)
        x = self.dense1(x)

        # BatchNorm1d(D) needs a (N, D, hidden) tensor, so T and B are merged.
        x = x.flatten(0, 1)
        x = self.bn(x)
        # NOTE(review): the LIF node receives the merged (T*B, D, hidden)
        # tensor. If MultiStepLIFNode treats dim 0 as the time axis, as the
        # (T, B, ...) inputs elsewhere in this file suggest, membrane state is
        # carried across batch entries here. Left unchanged to keep trained
        # results reproducible; confirm whether this is intended.
        x = self.activ(x)
        x = self.dense2(x)
        x = x.transpose(-1, -2).contiguous()
        x = x.view(T, B, -1, D)

        if normalize:
            # (B, 1, D) statistics broadcast against (T, B, pred_len, D).
            x = x * std
            x = x + mean

        aux = {
            'gate_l0': torch.tensor(0.0, device=x.device)  # placeholder
        }

        return x.mean(dim=0), aux
|
| 129 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy
|
| 2 |
+
pandas
|
| 3 |
+
torch
|
| 4 |
+
scikit-learn
|
| 5 |
+
snntorch
|
| 6 |
+
spikingjelly
|
scripts/ecl.sh
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Train every model on the electricity (ECL) dataset, one run after another.
# Each run's stdout is written to logs/LongForecasting/<LOG_NAME>.log.

LOG_DIR="./logs/LongForecasting"

# -p also creates ./logs and does nothing when the directories already exist.
mkdir -p "$LOG_DIR"

# run_model MODEL HIDDEN_SIZE LOG_NAME [EXTRA_ARGS...]
# Launches train.py with the settings shared by all ECL runs, followed by any
# model-specific EXTRA_ARGS. Every option is its own properly separated word,
# so no argument can fuse with the next line's option.
run_model() {
    local model="$1"
    local hidden_size="$2"
    local log_name="$3"
    shift 3

    python train.py \
        --model "$model" \
        --data electricity \
        --feature_size 370 \
        --embed_size 128 \
        --hidden_size "$hidden_size" \
        --batch_size 16 \
        --train_ratio 0.7 \
        --val_ratio 0.2 \
        --seq_length 12 \
        --pre_length 12 \
        --train_epochs 100 \
        --learning_rate 0.00001 \
        "$@" \
        --device cuda:0 > "$LOG_DIR/$log_name.log"
}

run_model FGN            256 ECL_FGN
run_model SpikF          256 ECL_SpikF          --T 16 --blocks 2
run_model iSpikformer    256 ECL_iSpikformer    --blocks 2
run_model SpikF_GO       256 ECL_SpikFGO        --energy_loss True
run_model SpikF_GO_CPG   256 ECL_SpikFGOCPG     --energy_loss True
run_model SpikeRNN_CPG   128 ECL_SpikeRNNCPG    --blocks 2
run_model SpikeGRU       64  ECL_SpikeGRU
run_model SpikeTCN_CPG   64  ECL_SpikeTCNCPG    --blocks 3
run_model Spikformer_CPG 128 ECL_SpikformerCPG  --blocks 2
run_model TSTCN          64  ECL_TSTCN          --kernel_size 3 --blocks 3
run_model TSGRU          64  ECL_TSGRU
run_model TSFormer       64  ECL_TSFormer
|
train.py
ADDED
|
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
import snntorch as snn
|
| 6 |
+
import time
|
| 7 |
+
import os
|
| 8 |
+
import numpy as np
|
| 9 |
+
import warnings
|
| 10 |
+
from spikingjelly.clock_driven import functional
|
| 11 |
+
|
| 12 |
+
from data.data_loader import (
|
| 13 |
+
Dataset_ECG, Dataset_Dhfm, Dataset_Solar, Dataset_Wiki, Dataset_PEMS_BAY
|
| 14 |
+
)
|
| 15 |
+
from utils.utils import save_model_ts, load_model_ts, evaluate
|
| 16 |
+
|
| 17 |
+
from model.FourierGNN import FGN
|
| 18 |
+
from model.SpikF import SpikF
|
| 19 |
+
from model.iSpikformer import iSpikformer
|
| 20 |
+
from model.SpikF_GO import SpikF_GO
|
| 21 |
+
from model.SpikF_GO_CPG import SpikF_GO_CPG
|
| 22 |
+
from model.TS_GRU import TSGRU
|
| 23 |
+
from model.TS_TCN import TSTCN
|
| 24 |
+
from model.TS_Former import TSFormer
|
| 25 |
+
from model.SpikeGRU import SpikeGRU
|
| 26 |
+
from model.Spikformer_CPG import Spikformer_CPG
|
| 27 |
+
from model.SpikeRNN_CPG import SpikeRNN_CPG
|
| 28 |
+
from model.SpikeTCN_CPG import SpikeTCN_CPG
|
| 29 |
+
from model.TS_TCN import TSLIFNode
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def remove(model):
    """Reset spiking-neuron state on ``model`` or on every item it yields.

    An object exposing ``reset()`` is reset through that method; otherwise,
    if it has a ``v`` attribute, that attribute is set to ``0.0``. Warnings
    matching "not base.MemoryModule" are suppressed while doing so.
    """
    if model is None:
        return

    def _clear(neuron):
        # Prefer the object's own reset; fall back to zeroing ``v``.
        if hasattr(neuron, 'reset'):
            neuron.reset()
        elif hasattr(neuron, 'v'):
            neuron.v = 0.0

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*not base.MemoryModule.*")
        # An iterable is treated as a collection of neurons; the container
        # itself is not reset in that case.
        targets = model if hasattr(model, '__iter__') else (model,)
        for neuron in targets:
            _clear(neuron)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def reset_states(model):
    """Recursively reset the state of spiking neurons reachable from ``model``.

    Dispatch order:
      * iterable: recurse into every item;
      * object with ``modules()``: reset each ``snn.Leaky`` / ``TSLIFNode``
        submodule, zeroing ``v`` if its ``reset()`` raises;
      * object with ``reset()``: call it;
      * object with ``v``: set it to ``0.0``.

    Warnings matching "not base.MemoryModule" are suppressed throughout.
    """
    if model is None:
        return

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*not base.MemoryModule.*")

        if hasattr(model, '__iter__'):
            for child in model:
                reset_states(child)
            return

        if hasattr(model, 'modules'):
            for layer in model.modules():
                if not isinstance(layer, (snn.Leaky, TSLIFNode)):
                    continue
                try:
                    layer.reset()
                except Exception:
                    # Best effort: fall back to clearing the membrane value.
                    if hasattr(layer, 'v'):
                        layer.v = 0.0
            return

        if hasattr(model, 'reset'):
            model.reset()
        elif hasattr(model, 'v'):
            model.v = 0.0
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _inverse_if_possible(arr: np.ndarray, scaler):
    """Undo scaling on an array whose last axis holds the features.

    ``arr`` is flattened to 2-D (rows, features), passed through
    ``scaler.inverse_transform`` and reshaped back to its original shape.
    The input is returned untouched when ``scaler`` is ``None``, when it has
    no ``inverse_transform`` attribute, or when ``arr`` has fewer than two
    dimensions.
    """
    can_invert = scaler is not None and hasattr(scaler, "inverse_transform")
    if not can_invert or arr.ndim < 2:
        return arr

    n_features = arr.shape[-1]
    restored = scaler.inverse_transform(arr.reshape(-1, n_features))
    return restored.reshape(arr.shape)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def compute_scores_scaled_and_orig(trues: np.ndarray, preds: np.ndarray, scaler):
    """Evaluate predictions in scaled space and after inverse scaling.

    Returns:
        A pair ``(score_scaled, score_orig)``: the result of ``evaluate`` on
        the arrays as given, and on the arrays passed through
        ``_inverse_if_possible`` with ``scaler``.
    """
    scaled = evaluate(trues, preds)

    restored_true, restored_pred = (
        _inverse_if_possible(values, scaler) for values in (trues, preds)
    )
    return scaled, evaluate(restored_true, restored_pred)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _fmt_score(tag, score):
    """Render a ``(mape, mae, rmse, r2, rse)`` tuple as one log line.

    MAPE is shown multiplied by 100; every value uses a 10-wide field with
    six decimals.
    """
    mape, mae, rmse, r2, rse = score
    fields = (
        ("MAPE", mape * 100.0),
        ("MAE", mae),
        ("RMSE", rmse),
        ("R2", r2),
        ("RSE", rse),
    )
    body = "; ".join(f"{label} {value:10.6f}" for label, value in fields)
    return f"{tag}: {body}."
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# args
def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because any non-empty string is truthy, so the flag could never be
    switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``bool('')``.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description='SpikF-GO: Spiking Fourier Graph Operators for Multivariate Time Series Forecasting')
parser.add_argument('--data', type=str, default='ECG', help='data set')
parser.add_argument('--feature_size', type=int, default=140, help='feature size')
parser.add_argument('--seq_length', type=int, default=12, help='input length')
parser.add_argument('--pre_length', type=int, default=12, help='predict length')
parser.add_argument('--embed_size', type=int, default=128, help='embedding dimensions')
parser.add_argument('--hidden_size', type=int, default=256, help='hidden dimensions')
parser.add_argument('--train_epochs', type=int, default=100, help='train epochs')
parser.add_argument('--batch_size', type=int, default=4, help='input data batch size')
parser.add_argument('--learning_rate', type=float, default=0.00001, help='optimizer learning rate')
parser.add_argument('--exponential_decay_step', type=int, default=5)
parser.add_argument('--validate_freq', type=int, default=1)
parser.add_argument('--early_stop', type=_str2bool, default=False)
parser.add_argument('--decay_rate', type=float, default=0.5)
parser.add_argument('--train_ratio', type=float, default=0.6)
parser.add_argument('--val_ratio', type=float, default=0.2)
parser.add_argument('--device', type=str, default='cuda:0', help='device')
parser.add_argument('--tau', type=float, default=2.0, help='tau')
parser.add_argument('--alpha', type=float, default=1.0)
parser.add_argument('--T', type=int, default=4)
parser.add_argument('--proj_dim', type=int, default=32, help='proj dim')
parser.add_argument('--model', type=str, default='FGN', help='model name')

parser.add_argument('--patch_num', type=int, default=4)
parser.add_argument('--patch_dim', type=int, default=16)
parser.add_argument('--blocks', type=int, default=1)
parser.add_argument('--energy_loss', type=_str2bool, default=False)
# store_false: both options default to True and passing the flag disables them.
parser.add_argument('--normalize', action='store_false', help='Disable normalization')
parser.add_argument('--affine', action='store_false', help='Disable affine layer')
parser.add_argument('--kernel_size', type=int, default=16)

args = parser.parse_args()
print(f'Training configs: {args}')
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# Dataset name -> location on disk. Every entry uses loader type '0'.
data_parser = {
    name: {'root_path': path, 'type': '0'}
    for name, path in (
        ('traffic', 'data/traffic.npy'),
        ('ECG', 'data/ECG_data.csv'),
        ('COVID', 'data/covid.csv'),
        ('electricity', 'data/electricity.csv'),
        ('solar', './data/solar'),
        ('metr', 'data/metr.csv'),
        ('wiki', 'data/wiki.csv'),
        ('pems_bay', 'data/pems-bay.h5'),
    )
}

# Dataset name -> Dataset class used to read it.
data_dict = {
    'ECG': Dataset_ECG,
    'COVID': Dataset_ECG,
    'traffic': Dataset_Dhfm,
    'solar': Dataset_Solar,
    'wiki': Dataset_Wiki,
    'electricity': Dataset_ECG,
    'metr': Dataset_ECG,
    'pems_bay': Dataset_PEMS_BAY,
}

if args.data not in data_parser:
    raise ValueError(f"Unknown dataset {args.data}. Available: {list(data_parser.keys())}")

data_info = data_parser[args.data]
Data = data_dict[args.data]

# Arguments shared by the train / val / test splits.
_split_kwargs = dict(
    root_path=data_info['root_path'],
    seq_len=args.seq_length,
    pre_len=args.pre_length,
    type=data_info['type'],
    train_ratio=args.train_ratio,
    val_ratio=args.val_ratio,
)

# The training split is built without a scaler; whatever scaler it exposes
# afterwards (if any) is handed to the validation and test splits.
train_set = Data(flag='train', scaler=None, **_split_kwargs)
train_scaler = getattr(train_set, "scaler", None)
val_set = Data(flag='val', scaler=train_scaler, **_split_kwargs)
test_set = Data(flag='test', scaler=train_scaler, **_split_kwargs)

# Only the training loader shuffles and drops the last incomplete batch.
train_dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
val_dataloader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False)
test_dataloader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False)

for _label, _split in (("Train", train_set), ("Val", val_set), ("Test", test_set)):
    print(f"{_label} samples:", len(_split))

# Models whose neuron state is reset by reset_states()/remove() before each
# batch instead of functional.reset_net() after it.
MODELS_SET2 = ["TSGRU", "TSTCN", "TSFormer", "Spikformer_CPG", "SpikeGRU", "SpikeRNN_CPG", "SpikeTCN_CPG"]
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def validate(model, vali_loader, scaler):
    """Evaluate ``model`` on ``vali_loader`` and print scaled/original metrics.

    Reads the module-level ``args``, ``MODELS_SET2`` and ``forecast_loss``.
    The model is switched to eval mode for the pass and back to train mode
    before returning.

    Args:
        model: network returning ``(forecast, aux)``; a 4-D forecast is
            averaged over its first axis.
        vali_loader: iterable of ``(x, y)`` batches.
        scaler: passed to ``compute_scores_scaled_and_orig`` for the
            original-scale metrics.

    Returns:
        ``loss_total / max(1, cnt)``, the mean of the per-batch losses.
    """
    model.eval()
    cnt = 0
    loss_total = 0.0
    preds_list = []
    trues_list = []

    # Nothing is back-propagated here, so skip building autograd graphs.
    with torch.no_grad():
        for x, y in vali_loader:
            # Models in MODELS_SET2 have their neuron state cleared before
            # the forward pass; TSGRU only resets its first layer's tslif.
            if args.model in MODELS_SET2 and args.model != 'TSGRU':
                reset_states(model=model)
            elif args.model == 'TSGRU':
                remove(model=model.net[0].tslif)

            x = x.float().to(args.device)
            y = y.float().to(args.device)

            forecast, _ = model(x)
            if len(forecast.shape) == 4:
                forecast = forecast.mean(dim=0)

            loss = forecast_loss(forecast, y)
            loss_total += float(loss)
            cnt += 1

            # All other models are reset after the forward pass instead.
            if args.model not in MODELS_SET2:
                functional.reset_net(model)

            preds_list.append(forecast.detach().cpu().numpy())
            trues_list.append(y.detach().cpu().numpy())

    preds = np.concatenate(preds_list, axis=0)
    trues = np.concatenate(trues_list, axis=0)

    score_scaled, score_orig = compute_scores_scaled_and_orig(trues, preds, scaler)

    print(_fmt_score("SCALED", score_scaled))
    print(_fmt_score("ORIG  ", score_orig))

    model.train()
    return loss_total / max(1, cnt)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def test(model, result_test_file, scaler, load_epoch=97):
    """Load a checkpoint and score it on the module-level ``test_dataloader``.

    Reads the module-level ``args``, ``MODELS_SET2`` and ``test_dataloader``.
    The model is left in eval mode.

    Args:
        model: network passed to ``load_model_ts``; the loaded model returns
            ``(forecast, aux)`` and a 4-D forecast is averaged over its first
            axis.
        result_test_file: forwarded to ``load_model_ts``.
        scaler: passed to ``compute_scores_scaled_and_orig`` for the
            original-scale metrics.
        load_epoch: forwarded to ``load_model_ts``.

    Returns:
        The pair ``(score_scaled, score_orig)``.
    """
    model = load_model_ts(model, result_test_file, load_epoch)
    model.eval()

    preds_list = []
    trues_list = []

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for x, y in test_dataloader:
            # Models in MODELS_SET2 have their neuron state cleared before
            # the forward pass; TSGRU only resets its first layer's tslif.
            if args.model in MODELS_SET2 and args.model != 'TSGRU':
                reset_states(model=model)
            elif args.model == 'TSGRU':
                remove(model=model.net[0].tslif)

            x = x.float().to(args.device)
            y = y.float().to(args.device)

            forecast, _ = model(x)
            if len(forecast.shape) == 4:
                forecast = forecast.mean(dim=0)

            # All other models are reset after the forward pass instead.
            if args.model not in MODELS_SET2:
                functional.reset_net(model)

            preds_list.append(forecast.detach().cpu().numpy())
            trues_list.append(y.detach().cpu().numpy())

    preds = np.concatenate(preds_list, axis=0)
    trues = np.concatenate(trues_list, axis=0)

    score_scaled, score_orig = compute_scores_scaled_and_orig(trues, preds, scaler)

    print(_fmt_score("SCALED", score_scaled))
    print(_fmt_score("ORIG  ", score_orig))

    return score_scaled, score_orig
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def build_opt_sched(model, lr=3e-4, wd=0.01, gate_lr_ratio=0.3,
                    warmup_epochs=8, total_epochs=100):
    """Create an AdamW optimizer with a linear-warmup + cosine schedule.

    Trainable parameters are split into two groups:

    * ``decay``    -- everything else; trained with ``weight_decay=wd``.
    * ``no_decay`` -- biases, 1-D parameters, parameters whose name contains
      ``norm``/``bn``/``embeddings``/``time_basis``, and ``freq_gate``
      ``log_alpha`` parameters; trained with ``weight_decay=0``.

    Args:
        model: Module whose ``named_parameters()`` are optimized.
        lr: Base learning rate for both groups.
        wd: Weight decay applied to the ``decay`` group only.
        gate_lr_ratio: Accepted for interface compatibility; currently unused
            (gate parameters train at the base learning rate).
        warmup_epochs: Number of scheduler steps of linear warmup (0.1x -> 1x).
        total_epochs: Total scheduler steps; the cosine phase spans
            ``max(1, total_epochs - warmup_epochs)`` steps down to ``0.1 * lr``.

    Returns:
        Tuple ``(optimizer, scheduler)``.
    """
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        name_l = name.lower()
        is_gate_alpha = 'freq_gate' in name_l and 'log_alpha' in name_l
        is_bias = name.endswith('bias')
        is_norm = ('norm' in name_l) or ('bn' in name_l)
        is_embed = ('embeddings' in name_l) or ('time_basis' in name_l)
        if is_gate_alpha or is_bias or is_norm or is_embed or p.ndim == 1:
            no_decay.append(p)
        else:
            decay.append(p)

    optim = torch.optim.AdamW([
        {'params': decay, 'lr': lr, 'weight_decay': wd},
        {'params': no_decay, 'lr': lr, 'weight_decay': 0.0},
    ], betas=(0.9, 0.99), eps=1e-8)

    warmup = torch.optim.lr_scheduler.LinearLR(
        optim, start_factor=0.1, end_factor=1.0, total_iters=warmup_epochs
    )
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim, T_max=max(1, total_epochs - warmup_epochs), eta_min=lr * 0.1
    )
    # Warmup runs first; the cosine schedule takes over at step `warmup_epochs`.
    sched = torch.optim.lr_scheduler.SequentialLR(
        optim, schedulers=[warmup, cosine], milestones=[warmup_epochs]
    )
    return optim, sched
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
if __name__ == '__main__':
    # Script entry point: for every seed, build the model selected by
    # args.model, train it, evaluate a saved checkpoint on the test split and
    # finally report mean/std of the scaled metrics over all runs.
    # NOTE(review): the nesting of the per-epoch logging/checkpoint statements
    # was reconstructed -- the validation log is emitted only on validation
    # epochs and a checkpoint is written after every epoch; confirm.

    seeds = [2021, 2022, 2023, 2024, 2025]
    n_runs = len(seeds)

    scaled_results = {'mape': [], 'mae': [], 'rmse': [], 'r2': [], 'rse': []}
    orig_results = {'mape': [], 'mae': [], 'rmse': [], 'r2': [], 'rse': []}

    for run_idx, seed in enumerate(seeds):
        print(f"\n{'='*60}")
        print(f"Starting Run {run_idx + 1}/{n_runs} | seed={seed}")
        print(f"{'='*60}")

        torch.manual_seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

        # Train and test artefacts share one directory per run.
        result_train_file = os.path.join('output', args.data, args.model, f'train_run_{run_idx+1}_seed_{seed}')
        result_test_file = os.path.join('output', args.data, args.model, f'train_run_{run_idx+1}_seed_{seed}')
        os.makedirs(result_train_file, exist_ok=True)
        os.makedirs(result_test_file, exist_ok=True)

        device = torch.device(args.device if torch.cuda.is_available() else "cpu")

        # --- model selection -------------------------------------------------
        if args.model == 'SpikF_GO':
            model = SpikF_GO(args, pre_length=args.pre_length, embed_size=args.embed_size,
                             feature_size=args.feature_size, seq_length=args.seq_length, hidden_size=args.hidden_size)
        elif args.model == 'SpikF_GO_CPG':
            model = SpikF_GO_CPG(args, pre_length=args.pre_length, embed_size=args.embed_size,
                                 feature_size=args.feature_size, seq_length=args.seq_length, hidden_size=args.hidden_size)
        elif args.model == 'FGN':
            model = FGN(args, pre_length=args.pre_length, embed_size=args.embed_size,
                        feature_size=args.feature_size, seq_length=args.seq_length, hidden_size=args.hidden_size)
        elif args.model == 'SpikF':
            model = SpikF(args, input_len=args.seq_length, patch_num=args.patch_num, patch_dim=args.patch_dim,
                          T=args.T, blocks=args.blocks, D=args.feature_size, pred_len=args.pre_length,
                          tau=args.tau, alpha=args.alpha, hidden_dim=args.hidden_size)
        elif args.model == 'iSpikformer':
            model = iSpikformer(args, input_len=args.seq_length, patch_num=args.patch_num, patch_dim=args.patch_dim,
                                T=args.T, blocks=args.blocks, D=args.feature_size, pred_len=args.pre_length,
                                tau=args.tau, alpha=args.alpha, hidden_dim=args.hidden_size)
        elif args.model == 'TSGRU':
            model = TSGRU(args, hidden_size=args.hidden_size, layers=args.blocks,
                          num_steps=args.T, input_size=args.feature_size)
        elif args.model == 'TSTCN':
            model = TSTCN(args=args, num_levels=args.blocks)
        elif args.model == 'TSFormer':
            model = TSFormer(args=args)
        elif args.model == 'Spikformer_CPG':
            model = Spikformer_CPG(args=args)
        elif args.model == 'SpikeGRU':
            model = SpikeGRU(args, hidden_size=args.hidden_size, layers=args.blocks,
                             num_steps=args.T, input_size=args.feature_size)
        elif args.model == 'SpikeRNN_CPG':
            model = SpikeRNN_CPG(args, hidden_size=args.hidden_size, layers=args.blocks,
                                 num_steps=args.T, input_size=args.feature_size)
        elif args.model == 'SpikeTCN_CPG':
            model = SpikeTCN_CPG(args=args, num_levels=args.blocks)
        else:
            raise ValueError(f"Unknown model: {args.model}")

        # --- optimizer / scheduler --------------------------------------------
        # The two SpikF-GO variants use AdamW with warmup + cosine decay; every
        # other model uses RMSprop with exponential decay.
        if args.model in ('SpikF_GO', 'SpikF_GO_CPG'):
            my_optim, my_lr_scheduler = build_opt_sched(
                model, lr=args.learning_rate, wd=0.01,
                warmup_epochs=max(4, args.train_epochs // 8), total_epochs=args.train_epochs
            )
        else:
            my_optim = torch.optim.RMSprop(params=model.parameters(), lr=args.learning_rate, eps=1e-08)
            my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=my_optim, gamma=args.decay_rate)

        model = model.to(device)
        forecast_loss = nn.MSELoss(reduction='mean').to(device)

        # train
        for epoch in range(args.train_epochs):
            epoch_start_time = time.time()
            model.train()
            loss_total = 0.0
            cnt = 0
            # Stays None when the loader yields no batch, so the logging below
            # cannot hit an undefined name.
            aux = None

            for x, y in train_dataloader:
                # Per-batch state reset; which helper applies depends on the model family.
                if args.model in MODELS_SET2 and args.model != 'TSGRU':
                    reset_states(model=model)
                elif args.model == 'TSGRU':
                    remove(model=model.net[0].tslif)

                x = x.float().to(device)
                y = y.float().to(device)

                forecast, aux = model(x)

                # 4-D forecasts carry an extra leading axis of length args.T;
                # tile the target so the loss is taken against every slice.
                if len(forecast.shape) == 4:
                    y_rep = y.repeat(args.T, 1, 1, 1)
                else:
                    y_rep = y

                if (args.model in ['SpikF_GO', 'SpikF_GO_CPG']) and args.energy_loss:
                    energy_lambda = 20.0
                    mse = forecast_loss(forecast, y_rep)
                    # Penalty weight follows the (detached) MSE, so it scales
                    # with the current error without contributing gradients.
                    adaptive_lambda = (mse.detach() / 100.0) * energy_lambda
                    loss = mse + adaptive_lambda * aux["rho_hat"]
                else:
                    loss = forecast_loss(forecast, y_rep)

                my_optim.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                my_optim.step()

                loss_total += float(loss)
                cnt += 1

                if args.model not in MODELS_SET2:
                    functional.reset_net(model)

            if (epoch + 1) % args.exponential_decay_step == 0:
                my_lr_scheduler.step()

            if (epoch + 1) % args.validate_freq == 0:
                val_loss = validate(model, val_dataloader, train_scaler)
                # Not every model is known to return a mapping as its second
                # output; fall back to zeros instead of crashing on `.get`.
                aux_stats = aux if hasattr(aux, 'get') else {}
                enc_rate_v = float(aux_stats.get('enc_rate', 0.0))
                gate_l0_v = float(aux_stats.get('rho_hat', 0.0))
                freq_act_v = float(aux_stats.get('freq_mask_active', 0.0))

                print('Run {} | epoch {:03d} | {:5.2f}s | train_loss {:5.4f} | val_loss {:5.4f} | enc_rate {:.3f} | gate_L0 {:.3f} | f_active {:.3f}'.format(
                    run_idx + 1, epoch, (time.time() - epoch_start_time), loss_total / max(1, cnt), val_loss,
                    enc_rate_v, gate_l0_v, freq_act_v))

            save_model_ts(model, result_train_file, epoch)

        save_model_ts(model, result_train_file, f'final_run_{run_idx+1}')

        print("--- TEST ---")
        # Checkpoint 97 is evaluated when it exists; shorter runs fall back to
        # their last saved epoch instead of failing on a missing file.
        test_epoch = min(97, max(0, args.train_epochs - 1))
        score_scaled, score_orig = test(model, result_test_file, train_scaler, load_epoch=test_epoch)

        scaled_results['mape'].append(score_scaled[0])
        scaled_results['mae'].append(score_scaled[1])
        scaled_results['rmse'].append(score_scaled[2])
        scaled_results['r2'].append(score_scaled[3])
        scaled_results['rse'].append(score_scaled[4])

        orig_results['mape'].append(score_orig[0])
        orig_results['mae'].append(score_orig[1])
        orig_results['rmse'].append(score_orig[2])
        orig_results['r2'].append(score_orig[3])
        orig_results['rse'].append(score_orig[4])

        print(f"Run {run_idx + 1} completed.")
        print(_fmt_score("Results", score_scaled))

    def _mean_std(arr):
        # Mean and population standard deviation as plain floats.
        arr = np.asarray(arr, dtype=np.float64)
        return float(np.mean(arr)), float(np.std(arr))

    print(f"\n{'='*60}")
    print("FINAL RESULTS ACROSS RUNS ")
    print(f"{'='*60}")

    for tag, store in [("SCALED", scaled_results)]:
        # MAPE is reported in percent; the other metrics are reported as-is.
        mape_pct = np.asarray(store['mape'], dtype=np.float64) * 100.0
        m_mean, m_std = _mean_std(mape_pct)
        a_mean, a_std = _mean_std(store['mae'])
        r_mean, r_std = _mean_std(store['rmse'])
        r2_mean, r2_std = _mean_std(store['r2'])
        rse_mean, rse_std = _mean_std(store['rse'])

        print(f"\n[{tag}]")
        print(f"MAPE: {mape_pct} | mean={m_mean:.6f} std={m_std:.6f}")
        print(f"MAE : {np.array(store['mae'])} | mean={a_mean:.6f} std={a_std:.6f}")
        print(f"RMSE: {np.array(store['rmse'])} | mean={r_mean:.6f} std={r_std:.6f}")
        print(f"R2  : {np.array(store['r2'])} | mean={r2_mean:.6f} std={r2_std:.6f}")
        print(f"RSE : {np.array(store['rse'])} | mean={rse_mean:.6f} std={rse_std:.6f}")

    summary_file = os.path.join('output', args.data, args.model, 'summary_results.txt')
    os.makedirs(os.path.dirname(summary_file), exist_ok=True)

    with open(summary_file, 'w') as f:
        f.write(f"Results across {n_runs} runs:\n")
        f.write(f"Seeds used: {seeds}\n\n")

        for tag, store in [("SCALED", scaled_results)]:
            mape_pct = np.asarray(store['mape'], dtype=np.float64) * 100.0
            m_mean, m_std = _mean_std(mape_pct)
            a_mean, a_std = _mean_std(store['mae'])
            r_mean, r_std = _mean_std(store['rmse'])
            r2_mean, r2_std = _mean_std(store['r2'])
            rse_mean, rse_std = _mean_std(store['rse'])

            f.write(f"[{tag}]\n")
            f.write(f"MAPE - Individual: {mape_pct}\n")
            f.write(f"MAPE - Mean: {m_mean:.6f}, Std: {m_std:.6f}\n")
            f.write(f"MAE - Individual: {np.array(store['mae'])}\n")
            f.write(f"MAE - Mean: {a_mean:.6f}, Std: {a_std:.6f}\n")
            f.write(f"RMSE - Individual: {np.array(store['rmse'])}\n")
            f.write(f"RMSE - Mean: {r_mean:.6f}, Std: {r_std:.6f}\n")
            f.write(f"R2 - Individual: {np.array(store['r2'])}\n")
            f.write(f"R2 - Mean: {r2_mean:.6f}, Std: {r2_std:.6f}\n")
            f.write(f"RSE - Individual: {np.array(store['rse'])}\n")
            f.write(f"RSE - Mean: {rse_mean:.6f}, Std: {rse_std:.6f}\n\n")

    print(f"\nSaved summary to: {summary_file}")
|
utils/utils.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding:utf-8 -*-
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def concat_fun(inputs, axis=-1):
    """Join a sequence of tensors along ``axis``.

    A one-element sequence is returned as that element itself, without going
    through ``torch.cat``.
    """
    if len(inputs) != 1:
        return torch.cat(inputs, dim=axis)
    return inputs[0]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def slice_arrays(arrays, start=None, stop=None):
    """Slice one array-like or every array-like in a list.

    * ``arrays`` is a list (a lone ``np.ndarray`` is wrapped into one):
      with an index collection in ``start`` every non-``None`` entry is
      indexed as ``x[start]``; otherwise every entry is sliced as
      ``x[start:stop]`` (a single-entry list returns that slice unwrapped).
    * ``arrays`` is any other object: it is indexed with ``start`` when
      ``start`` has a length, sliced when ``start`` supports item access, and
      ``[None]`` is returned otherwise.

    Arguments:
        arrays: Single array or list of arrays.
        start: Integer start index, or a list/array of indices.
        stop: Integer stop index; must be None when `start` is a list.

    Returns:
        A slice of the array(s), or ``[None]`` when `arrays` is None.

    Raises:
        ValueError: If `start` is a list and `stop` is not None.
    """
    if arrays is None:
        return [None]

    if isinstance(arrays, np.ndarray):
        arrays = [arrays]

    if isinstance(start, list) and stop is not None:
        raise ValueError('The stop argument has to be None if the value of start '
                         'is a list.')

    # Anything with a length is treated as a collection of indices.
    index_like = hasattr(start, '__len__')
    if index_like and hasattr(start, 'shape'):
        # hdf5 datasets only support list objects as indices
        start = start.tolist()

    if isinstance(arrays, list):
        if index_like:
            return [None if item is None else item[start] for item in arrays]
        if len(arrays) == 1:
            return arrays[0][start:stop]
        return [None if item is None else item[start:stop] for item in arrays]

    if index_like:
        return arrays[start]
    if hasattr(start, '__getitem__'):
        return arrays[start:stop]
    return [None]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def save_model(model, model_dir, epoch=None):
    """Serialize the whole ``model`` object to ``<model_dir>/<epoch>_dhfm.pt``.

    Args:
        model: Object handed to ``torch.save`` as-is (not just a state dict).
        model_dir: Target directory, created if missing. ``None`` makes the
            call a no-op.
        epoch: Used as the file-name prefix. Any falsy value (``None``, ``0``,
            ``''``) yields an empty prefix, i.e. the file ``_dhfm.pt``;
            ``load_model`` resolves names the same way.
    """
    if model_dir is None:
        return
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(model_dir, exist_ok=True)
    epoch = str(epoch) if epoch else ''
    file_name = os.path.join(model_dir, epoch + '_dhfm.pt')
    with open(file_name, 'wb') as f:
        torch.save(model, f)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def load_model(model_dir, epoch=None):
    """Load an object previously written by ``save_model``.

    Args:
        model_dir: Directory to read from; created if it does not exist.
            A falsy value makes the call return ``None``.
        epoch: File-name prefix. Any falsy value (``None``, ``0``, ``''``)
            selects ``_dhfm.pt``, mirroring ``save_model``.

    Returns:
        The deserialized object, or ``None`` when ``model_dir`` is falsy or
        the checkpoint file does not exist.
    """
    if not model_dir:
        return
    epoch = str(epoch) if epoch else ''
    file_name = os.path.join(model_dir, epoch + '_dhfm.pt')
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(model_dir, exist_ok=True)
    if not os.path.exists(file_name):
        return
    with open(file_name, 'rb') as f:
        # SECURITY: the file is a full pickle of the model object, and
        # unpickling can execute arbitrary code -- only load trusted files.
        # Recent PyTorch releases default to weights_only=True, which rejects
        # such whole-object pickles, so full unpickling is requested
        # explicitly.
        try:
            model = torch.load(f, weights_only=False)
        except TypeError:
            # PyTorch versions that predate the weights_only argument.
            f.seek(0)
            model = torch.load(f)
    return model
|
| 92 |
+
|
| 93 |
+
def masked_MAPE(v, v_, axis=None):
    '''
    Mean absolute percentage error that ignores zero-valued ground truth.

    :param v: np.ndarray or int, ground truth.
    :param v_: np.ndarray or int, prediction.
    :param axis: axis to do calculation.
    :return: MAPE averaged over the entries where ``v != 0``; positions with
        no valid entry come back as NaN when the result is an array.
    '''
    mask = (v == 0)
    # Entries with v == 0 produce inf/NaN here and are masked out below, so
    # the divide-by-zero / invalid-value warnings carry no information.
    with np.errstate(divide='ignore', invalid='ignore'):
        percentage = np.abs(v_ - v) / np.abs(v)
    if np.any(mask):
        masked_array = np.ma.masked_array(percentage, mask=mask)  # mask the dividing-zero as invalid
        result = masked_array.mean(axis=axis)
        if isinstance(result, np.ma.MaskedArray):
            # Fully masked slices have no defined mean; report them as NaN.
            return result.filled(np.nan)
        else:
            return result
    return np.mean(percentage, axis).astype(np.float64)
|
| 111 |
+
|
| 112 |
+
"""
|
| 113 |
+
original
|
| 114 |
+
def MAPE(v, v_, axis=None):
|
| 115 |
+
'''
|
| 116 |
+
Mean absolute percentage error.
|
| 117 |
+
:param v: np.ndarray or int, ground truth.
|
| 118 |
+
:param v_: np.ndarray or int, prediction.
|
| 119 |
+
:param axis: axis to do calculation.
|
| 120 |
+
:return: int, MAPE averages on all elements of input.
|
| 121 |
+
'''
|
| 122 |
+
mape = (np.abs(v_ - v) / np.abs(v)+1e-5).astype(np.float64)
|
| 123 |
+
mape = np.where(mape > 5, 5, mape)
|
| 124 |
+
return np.mean(mape, axis)
|
| 125 |
+
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
def MAPE(v, v_, axis=None):
    '''
    Clipped mean absolute percentage error.

    Each ratio is computed as ``|v_ - v| / (|v| + 1e-5)`` and capped at 5
    before averaging.

    :param v: np.ndarray or int, ground truth.
    :param v_: np.ndarray or int, prediction.
    :param axis: axis to do calculation.
    :return: float, MAPE averaged over the requested axis (all elements by default).
    '''
    cap = 5
    abs_err = np.abs(v_ - v)
    scale = np.abs(v) + 1e-5  # small offset keeps zero targets finite
    ratio = (abs_err / scale).astype(np.float64)
    clipped = np.where(ratio > cap, cap, ratio)  # clip extreme values
    return np.mean(clipped, axis)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
#def MAPE(true, pred):
|
| 142 |
+
# return np.mean(np.abs((pred - true) / (true+1e-5)))
|
| 143 |
+
|
| 144 |
+
def smape(P, A):
    """Symmetric MAPE over the positions where the actual value ``A`` is > 0.

    :param P: np.ndarray, predictions.
    :param A: np.ndarray, actual values.
    :return: mean of ``2 * |A - P| / (|A| + |P|)`` over the selected entries.
    """
    positive_idx = np.nonzero(A > 0)
    pred = P[positive_idx]
    actual = A[positive_idx]

    numerator = 2 * np.abs(actual - pred)
    denominator = np.abs(actual) + np.abs(pred)
    return np.mean(numerator / denominator)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def R2(y, y_hat, axis=None, eps=1e-12):
    """
    Coefficient of determination, ``1 - SS_res / (SS_tot + eps)``.

    :param y: array-like, ground truth.
    :param y_hat: array-like, prediction.
    :param axis: None for one global score, or an int / tuple of ints naming
        the axes to reduce over (the remaining axes are kept).
    :param eps: guard added to the denominator; also the threshold below
        which the target variance is treated as zero.
    :return: float for a scalar reduction, otherwise a float64 array. NaN is
        returned wherever ``SS_tot < eps`` (constant targets), since the score
        is not well-defined there.
    """
    truth = np.asarray(y, dtype=np.float64)
    pred = np.asarray(y_hat, dtype=np.float64)

    # Residual sum of squares.
    residual_ss = np.sum((truth - pred) ** 2, axis=axis)

    # Total sum of squares around the mean taken over the same axes.
    centered = truth - np.mean(truth, axis=axis, keepdims=True)
    total_ss = np.sum(centered ** 2, axis=axis)

    score = 1.0 - (residual_ss / (total_ss + eps))

    if not np.isscalar(total_ss):
        return np.where(total_ss < eps, np.nan, score).astype(np.float64)
    if total_ss < eps:
        return np.nan
    return float(score)
|
| 181 |
+
|
| 182 |
+
def RSE(v, v_, axis=None, eps=1e-12):
    '''
    Root relative squared error:
        sqrt( sum((v_ - v)^2) / (sum((v - mean(v))^2) + eps) )

    :param v: np.ndarray or int, ground truth.
    :param v_: np.ndarray or int, prediction.
    :param axis: axis to do calculation.
    :param eps: guard added to the denominator.
    :return: float64, RSE over all elements of input (or reduced by axis).
    '''
    truth = np.asarray(v, dtype=np.float64)
    pred = np.asarray(v_, dtype=np.float64)

    deviation = truth - np.mean(truth, axis=axis, keepdims=True)
    squared_error = np.sum((pred - truth) ** 2, axis=axis)
    spread = np.sum(deviation ** 2, axis=axis)

    ratio = squared_error / (spread + eps)
    return np.sqrt(ratio).astype(np.float64)
|
| 198 |
+
|
| 199 |
+
def RMSE(v, v_, axis=None):
    '''
    Root mean squared error.

    :param v: np.ndarray or int, ground truth.
    :param v_: np.ndarray or int, prediction.
    :param axis: axis to do calculation.
    :return: float64, RMSE averaged over the requested axis (all elements by default).
    '''
    squared_error = (v_ - v) ** 2
    mean_squared = np.mean(squared_error, axis)
    return np.sqrt(mean_squared).astype(np.float64)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def MAE(v, v_, axis=None):
    '''
    Mean absolute error.

    :param v: np.ndarray or int, ground truth.
    :param v_: np.ndarray or int, prediction.
    :param axis: axis to do calculation.
    :return: float64, MAE averaged over the requested axis (all elements by default).
    '''
    abs_error = np.abs(v_ - v)
    averaged = np.mean(abs_error, axis)
    return averaged.astype(np.float64)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def evaluate(y, y_hat, by_step=False, by_node=False):
    '''
    Compute the forecasting metrics, optionally broken down per step / node.

    :param y: array in shape of [count, time_step, node].
    :param y_hat: in same shape with y.
    :param by_step: evaluate by time_step dim.
    :param by_node: evaluate by node dim.
    :return: (MAPE, MAE, RMSE, R2, RSE) when neither flag is set; otherwise
        (MAPE, MAE, RMSE, R2) reduced over axis 0 (both flags), axes (0, 2)
        (by_step only) or axes (0, 1) (by_node only).
    '''
    if by_step and by_node:
        axis = 0
    elif by_step:
        axis = (0, 2)
    elif by_node:
        axis = (0, 1)
    else:
        # Global scores: the only variant that also reports RSE.
        return MAPE(y, y_hat), MAE(y, y_hat), RMSE(y, y_hat), R2(y, y_hat), RSE(y, y_hat)

    return (MAPE(y, y_hat, axis=axis), MAE(y, y_hat, axis=axis),
            RMSE(y, y_hat, axis=axis), R2(y, y_hat, axis=axis))
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def save_model_ts(model, path, epoch):
    """Write ``model.state_dict()`` to ``<path>/epoch_<epoch>.pth``.

    Args:
        model: Module whose state dict is saved (weights only, not the object).
        path: Target directory; created if missing.
        epoch: Tag inserted into the file name (an int or any string).
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(path, exist_ok=True)
    filename = 'epoch_{}.pth'.format(epoch)
    f = os.path.join(path, filename)
    # Save state_dict instead of the entire model
    torch.save(model.state_dict(), f)
|
| 246 |
+
|
| 247 |
+
def load_model_ts(model, path, epoch):
    """Load ``<path>/epoch_<epoch>.pth`` into an existing model instance.

    Args:
        model: Module whose architecture matches the saved state dict.
        path: Directory that contains the checkpoint.
        epoch: Tag used in the checkpoint file name.

    Returns:
        The same ``model`` object, with the loaded weights.
    """
    filename = 'epoch_{}.pth'.format(epoch)
    f = os.path.join(path, filename)
    # Deserialize onto CPU so a checkpoint written on a GPU can be restored on
    # a CPU-only host; load_state_dict then copies the values into the model's
    # own parameters on whatever device they live on.
    state = torch.load(f, map_location='cpu')
    model.load_state_dict(state)
    return model
|