🎯 Goal
1) Convert a dataset into a spiking dataset
2) How to visualize spiking data
3) How to generate random spike trains

Intro

์šฐ๋ฆฌ์˜ ๊ฐ๊ฐ์€ ๊ฐ๊ฐ์˜ ์‹ ํ˜ธ(๋น›, ๋ƒ„์ƒˆ, ์••๋ ฅ ๋“ฑ๋“ฑ)์„ Spike๋กœ ๋ณ€ํ™˜ ๋  ๋•Œ ๋А๋‚€๋‹ค.

To build an SNN it is natural to feed the network spikes at the input as well. The case for encoding the input data as spikes rests on three properties:

[Figure: spikes (a)-(b), sparsity (c), static suppression (d)-(e)]

  • Spike (a) - (b)

    Biological neurons process and communicate information through spikes. A voltage swing of roughly 100 mV is reduced to a binary 1 or 0. → Simplifying the data in this way makes it efficient to process in hardware.

  • Sparsity (c)

    Neurons in neuromorphic hardware are sparse: most of them stay inactive (0) most of the time and only switch to 1 in specific situations (a tiny sketch follows this list).

    space complexity : vectors and matrices that are mostly zeros are common, so when most elements are 0 it is enough to store only the positions of the 1s, which saves a large amount of memory.

    time complexity : anything multiplied by 0 is always 0, so those multiplications can be skipped. → Fewer operations means shorter computation time and lower power consumption.

  • Static-suppression (d) - (e)

    A spiking neuron's output is just a binary 0 or 1, and, like biological sensory systems, it responds mainly to changes in the input while suppressing static (unchanging) information. Encoding only what changes keeps the data simple, which in turn keeps the hardware implementation simple.
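
As a tiny illustration of the memory argument in the Sparsity bullet above, in plain PyTorch (the spike train values are made up):

import torch

# a sparse spike train: mostly zeros, only a few ones
spike_train = torch.tensor([0, 0, 1, 0, 0, 0, 1, 0, 0, 0])

# dense storage keeps all ten values; a sparse representation only needs the indices of the ones
spike_indices = torch.nonzero(spike_train).flatten()
print(spike_indices)  # tensor([2, 6])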


Setup

Load the MNIST dataset and set up the environment.

$ pip install snntorch

Import packages and set up the environment

import snntorch as snn
import torch

# Training Parameters
batch_size=128
data_path='/tmp/data/mnist'
num_classes = 10  # MNIST has 10 output classes

# Torch Variables
dtype = torch.float

Download the MNIST dataset

from torchvision import datasets, transforms

# Define a transform
transform = transforms.Compose([
            transforms.Resize((28,28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
from snntorch import utils

# reduce the training set by a factor of 10 (60,000 -> 6,000 samples) to speed things up
subset = 10
mnist_train = utils.data_subset(mnist_train, subset)

>>> print(f"The size of mnist_train is {len(mnist_train)}")
The size of mnist_train is 6000

Dataloader

The DataLoader is the interface that feeds data to the network; it serves the dataset split into batches of size batch_size.

from torch.utils.data import DataLoader

train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
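
A quick sanity check is to pull one batch from the loader and inspect its shape (the variable names below are just for illustration):

data, targets = next(iter(train_loader))
print(data.shape)     # torch.Size([128, 1, 28, 28])
print(targets.shape)  # torch.Size([128])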

Spike encoding

SNNs are designed to exploit time-varying data, but MNIST is not a time-varying dataset. There are two ways to use MNIST with an SNN:

  • ๋™์ผ ์ƒ˜ํ”Œ ์ „๋‹ฌ

    image.png

    ๋™์ผํ•œ ์ด๋ฏธ์ง€๋ฅผ SNN์— ์ „๋‹ฌํ•œ๋‹ค. ์ •์ง€๋œ ์ด๋ฏธ์ง€์˜ ๋™์˜์ƒ์˜ ํ˜•ํƒœ๋กœ ์ธ์ฝ”๋”ฉ ํ•œ๋‹ค๊ณ  ๋ณด๋ฉด ๋œ๋‹ค. ๋‹ค๋งŒ ์ด๋ฅผ ์ด์šฉํ•  ๊ฒฝ์šฐ SNN์˜ ์‹œ๊ฐ„์  ์š”์†Œ๋ฅผ ํ™œ์šฉํ•˜์ง€ ๋ชปํ•œ๋‹ค๋Š” ๋‹จ์ ์ด ์กด์žฌํ•œ๋‹ค.

  • Convert the input into a time-varying spike sequence

    [Figure: the input converted into a spike train over time]

    The input is converted into a spike train, i.e. a spike sequence that changes over time.
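
For the first option, the static image can simply be copied along a new leading time dimension with plain PyTorch; the second option is what the encoding functions below implement. A minimal sketch of the first option (num_steps and the variable names are illustrative):

import torch

num_steps = 10

# one batch of static MNIST images: [128, 1, 28, 28]
data, targets = next(iter(train_loader))

# repeat the same frame at every time step -> [num_steps, 128, 1, 28, 28]
spike_data = data.unsqueeze(0).repeat(num_steps, 1, 1, 1, 1)
print(spike_data.shape)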

Spike Encoding function

snnTorch offers three spike encoding options:

1) Rate coding - information is carried by spiking frequency
2) Latency coding - information is carried by spike timing
3) Delta modulation - the temporal change of input features generates spikes
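
In snnTorch these correspond to functions in the snntorch.spikegen module, which are used in the sketches below:

from snntorch import spikegen

# spikegen.rate    - rate coding (spike frequency)
# spikegen.latency - latency coding (spike timing)
# spikegen.delta   - delta modulation (spikes on input changes)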


Rate coding

The input value is used as the probability that a spike occurs during a given time step, and each time step is sampled as a Bernoulli trial:

\[P(R_{ij}=1)=X_{ij}=1-P(R_{ij}=0)\]

$X_{ij}$ is the input feature, used as the probability that a spike occurs at the given time step.
$R_{ij}$ is the resulting rate-coded spike at that time step (1 if a spike fires, 0 otherwise).

# Temporal Dynamics
num_steps = 10

# create vector filled with 0.5
raw_vector = torch.ones(num_steps) * 0.5

# pass each sample through a Bernoulli trial
rate_coded_vector = torch.bernoulli(raw_vector)
print(f"Converted vector: {rate_coded_vector}")

By the law of large numbers, as num_steps increases the mean of the rate-coded vector approaches the original raw value.
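
A quick check of that claim, followed by the same rate coding applied to a full MNIST batch with snntorch.spikegen.rate (a minimal sketch; num_steps = 100 is an arbitrary choice):

from snntorch import spikegen

# a longer spike train approximates the underlying rate (0.5) more closely
num_steps = 100
raw_vector = torch.ones(num_steps) * 0.5
rate_coded_vector = torch.bernoulli(raw_vector)
print(f"The output is spiking {rate_coded_vector.mean() * 100:.2f}% of the time.")

# rate-code a whole MNIST batch: each pixel intensity becomes a per-step spike probability
data, targets = next(iter(train_loader))
spike_data = spikegen.rate(data, num_steps=num_steps)
print(spike_data.shape)  # expected: [100, 128, 1, 28, 28]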


Latency coding
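
Latency coding carries information in when a spike fires: larger inputs fire earlier, smaller inputs fire later, and each input fires at most once. A minimal sketch using snntorch.spikegen.latency (num_steps, tau, and threshold are illustrative choices):

from snntorch import spikegen

# bright pixels spike early, dark pixels spike late (or are clipped away entirely)
data, targets = next(iter(train_loader))
spike_data = spikegen.latency(data, num_steps=100, tau=5, threshold=0.01,
                              clip=True, normalize=True, linear=True)
print(spike_data.shape)  # [100, 128, 1, 28, 28]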


Delta modulation
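
Delta modulation emits a spike only when the input changes by more than a threshold between time steps. A minimal sketch using snntorch.spikegen.delta (the example signal and threshold are made up for illustration):

import torch
from snntorch import spikegen

# a 1-D signal over time: rises, plateaus, then drops
signal = torch.tensor([0.0, 0.1, 0.4, 1.0, 1.0, 1.0, 0.3])

# spike wherever the increase since the previous step exceeds the threshold
spike_data = spikegen.delta(signal, threshold=0.2)
print(spike_data)  # spikes at the two large upward jumps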

