[SNNTorch] Tutorial 1 - Spike Encoding
๐ฏ Goal
1) Dataset์ spiking Dataset์ผ๋ก ๋ณํ
2) ์๊ฐํ ๋ฐฉ๋ฒ
3) random spike trains ์์ฑ ๋ฐฉ๋ฒ
Intro
์ฐ๋ฆฌ์ ๊ฐ๊ฐ์ ๊ฐ๊ฐ์ ์ ํธ(๋น, ๋์, ์๋ ฅ ๋ฑ๋ฑ)์ Spike๋ก ๋ณํ ๋ ๋ ๋๋๋ค.
SNN์ ๊ตฌ์ถํ๊ธฐ ์ํด์ ์ ๋ ฅ์์๋ Spike ๋ฐ์ดํฐ๋ฅผ ์ ๋ ฅํ๋ ๊ฒ์ด ๋ง์ผ๋ฉฐ ์ด๋ ๋ฐ์ดํฐ๋ฅผ ์ธ์ฝ๋ฉํ๋ ๋ฐ์๋ 3๊ฐ์ง์ ์์์์ ๋น๋กฏ๋๋ค.
-
Spike (a) - (b)
์๋ฌผํ์ ๋ด๋ฐ๋ค์ spike๋ฅผ ํตํด ์ ๋ณด๋ฅผ ์ฒ๋ฆฌํ๊ณ ์ํตํ๋ค. ๋๋ต 100mV์ ์ ์ ๋ณํ๋ฅผ ํตํด 1, 0์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ๋จ์ํ ํ๋ค. โ ๋ฐ์ดํฐ์ ๋จ์ํ๋ฅผ ํตํด ํ๋์จ์ด์์ ํจ์จ์ ์ผ๋ก ์ฒ๋ฆฌ ๊ฐ๋ฅํด์ง๋ค.
-
Sparsity (c)
์ ๊ฒฝํ ํ๋์จ์ด ๋ด๋ฐ์ ๋๋ถ๋ถ์ด ๋นํ์ฑ ์ฆ 0์ธ ์ํ๋ก ์ ์งํ๊ณ ํน์ ํ ์ํฉ์์๋ง 1๋ก ํ์ฑํ ๋๋ ํฌ์์ฑ(Sparsity)๊ฐ ์๋ค.
space complexity : ๋ชจ๋ ๋ฒกํฐ, ํ๋ ฌ์ด 0์ธ ํํ๋ ํํ๊ธฐ์ ํด๋น ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฌ์ฉ ํ ์ ์์ผ๋ฉฐ ๋ํ ๋๋ถ๋ถ์ ์์๊ฐ 0์ด๋ฉด 1์ธ ์์น๋ง ์ ์ฅํด๋ ๋จ์ผ๋ก ๋ฉ๋ชจ๋ฆฌ ์ ์ฝ์ด ํฌ๋ค.
time complexity : 0๊ณผ ๊ณฑํด์ง๋ฉด ํญ์ 0์ด๋ ์ด๋ฌํ ๊ณผ์ ์ ๊ณ์ฐ ๊ณผ์ ์ ์๋ต์ด ๊ฐ๋ฅํด์ง๋ค. โ ์ฐ์ฐ๋์ด ์ค์ด๋ค์ด ๊ณ์ฐ ์๊ฐ์ด ์ค๊ณ ์๋น ์ ๋ ฅ๋ ์ค์ด๋ ๋ค.
-
Static-suppression (d) - (e)
spike neuron์ ์ถ๋ ฅ์ 0,1 ์ด์ง์ ๊ฐ์ผ๋ก ํํ๋ ์ ์๋ค. ์ฆ ํ๋์จ์ด์ ๊ตฌํ์ด ๋จ์ํด ์ง์ผ๋ก์จ ํ๋์จ์ด ๊ตฌํ์ด ๋จ์ํด ์ง๋ค.
Setup
MNIST dataset์ ๋ก๋ํ๊ณ ํ๊ฒฝ ์ธํ ํ๋ ๊ณผ์
$ pip install snntorch
Import package, ํ๊ฒฝ ์ค์
import snntorch as snn
import torch
# Training Parameters
batch_size=128
data_path='/tmp/data/mnist'
num_classes = 10 # MNIST has 10 output classes
# Torch Variables
dtype = torch.float
MNIST dataset ๋ค์ด๋ก๋
from torchvision import datasets, transforms
# Define a transform
transform = transforms.Compose([
transforms.Resize((28,28)),
transforms.Grayscale(),
transforms.ToTensor(),
transforms.Normalize((0,), (1,))])
mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
from snntorch import utils
subset = 10
mnist_train = utils.data_subset(mnist_train, subset)
>>> print(f"The size of mnist_train is {len(mnist_train)}")
The size of mnist_train is 6000
Dataloader
Dataloader์ ๋ฐ์ดํฐ๋ฅผ ๋คํธ์ํฌ๋ก ์ ๋ฌํ๊ธฐ ์ํ ์ธํฐํ์ด์ค๋ก batch_size ํฌ๊ธฐ๋ก ๋ถํ ๋ ๋ฐ์ดํฐ๋ฅผ ๋ฐํํ๋ค.
from torch.utils.data import DataLoader
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
Spike encoding
SNN์ ์๊ฐ์ ๋ฐ๋ผ ๋ณํ๋ ๋ฐ์ดํฐ๋ฅผ ํ์ฉํ๋ค. ๊ทธ๋ฐ๋ฐ MNIST๋ ์๊ฐ์ ๋ฐ๋ผ ๋ณํ๋ ๋ฐ์ดํฐ ์ ์ด ์๋๋ค. ๋ฐ๋ผ์ MNIST๋ฅผ SNN์์ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ ๋ ๊ฐ์ง๊ฐ ์กด์ฌํ๋ค.
-
๋์ผ ์ํ ์ ๋ฌ
๋์ผํ ์ด๋ฏธ์ง๋ฅผ SNN์ ์ ๋ฌํ๋ค. ์ ์ง๋ ์ด๋ฏธ์ง์ ๋์์์ ํํ๋ก ์ธ์ฝ๋ฉ ํ๋ค๊ณ ๋ณด๋ฉด ๋๋ค. ๋ค๋ง ์ด๋ฅผ ์ด์ฉํ ๊ฒฝ์ฐ SNN์ ์๊ฐ์ ์์๋ฅผ ํ์ฉํ์ง ๋ชปํ๋ค๋ ๋จ์ ์ด ์กด์ฌํ๋ค.
-
์๊ฐ์ ๋ฐ๋ฅธ spike squence ๋ณํ
์ ๋ ฅ์ spike ์ด๋ก ๋ณํ ํ ์๊ฐ์ ๋ฐ๋ผ ๋ณํํ๋ Spike squence๋ก ๋ณํ๋๋ค.
Spike Encoding function
SNNTorch์์ ์ฌ์ฉ ๊ฐ๋ฅํ Spike encoding์ ์ธ ๊ฐ์ง๊ฐ ์กด์ฌํ๋ค.
1) Rate coding - ย spikingย frequency
2) latency coding - spikeย timing
3) delta modulation - temporalย changeย of input features to generate spikes
Rate coding
ํน์ ์๊ฐ ๋์ spike๊ฐ ๋ฐ์๋ ํ๋ฅ ์ ์ ๋ ฅํ๊ณ ๋ฒ ๋ฅด๋์ด ์ํ
\[P(R_{ij}=1)=X_{ij}=1-P(R_{ij}=0)\]$X_{ij}$์ Spike๊ฐ ์ฃผ์ด์ง time step์์ ์ผ์ด๋ ํ๋ฅ ๋ก ์ฌ์ฉ๋๋ค.
$R_{ij}$์
# Temporal Dynamics
num_steps = 10
# create vector filled with 0.5
raw_vector = torch.ones(num_steps) * 0.5
# pass each sample through a Bernoulli trial
rate_coded_vector = torch.bernoulli(raw_vector)
print(f"Converted vector: {rate_coded_vector}")
ํฐ ์์ ๋ฒ์น์ ์๊ฑฐํด num_steps๊ฐ ์ฆ๊ฐํ ์๋ก ์๋ raw๊ฐ์ ๊ฐ๊น์ ์ง๋ค.