[SNNTorch] Tutorial 3 - The Leaky integrate and fire neuron
์ด์ Tutorial 2์์ ์ค๊ณ๋ LIF neuron model์ ๋ณต์กํ๊ณ hyper parameter์ ์กฐ์ ์ด ํ์ํ๋ค. ์ด๋ ํ๋ผ๋ฏธํฐ์ ์ถ์ ์ ์ด๋ ต๊ณ SNN์ผ๋ก์ ํ์ฅ์ ํ ์ ๋ ๋ณต์กํด ์ง์ผ๋ก ๋จ์ํ๋ฅผ ์งํํ๋ค.
LIF neuron model ๋จ์ํ
๊ฐ์ ์จ : $\beta$
Euler ๋ฐฉ๋ฒ์ ์ด์ฉํด passive membrane ๋ชจ๋ธ์ ํด๋ ๋ค์๊ณผ ๊ฐ๋ค.
\[V_m(t+ฮt)=V_m(t)+ฯฮt(โV_m(t)+I(t)R)\]์ด๋ ์ ๋ ฅ ์ ๋ฅ $I(t)$๊ฐ ์๋ ๊ฒฝ์ฐ๋ฅผ ๊ฐ์ ํ๋ค.
\[V_m(t+ฮt)=V_m(t)โ\frac{ฮt}{ฯ}V_m(t)\]์ด๋ ๋ง์ ์์ ๊ฐ์ ์จ์ $\beta$๋ผ๊ณ ํ๋ฉฐ ์๋์ ๊ฐ๋ค.
\[\beta = \frac{V_m(t+\triangle t)}{V_m(t)}=1-\frac{\triangle t}{\tau}\]๊ฐ์ค ์ ๋ ฅ ์ ๋ฅ
t๊ฐ ์ํ์ค์ ์ ์๊ฐ ๋จ๊ณ๋ผ๊ณ ๊ฐ์ ํ๋ฉด $\triangle t = 1$๋ก ๋ณผ ์ ์๋ค. ๋ํ hyper parameter์๋ฅผ ์ค์ด๊ธฐ ์ํด $R = 1$๋ก ๊ฐ์ ํ๋ฉด ์๋์ ๊ฐ์ ์์ด ์ถ๋ ฅ๋๋ค.
\[\beta = 1 - \frac{1}{C} \to (1-\beta)I_{in}=\frac{1}{\tau}I_{in}\]์ด๋ $1-\beta$๋ฅผ ์ ๋ ฅ ์ ๋ฅ์ ๊ฐ์ค์น๋ผ๊ณ ๋ณด๋ฉฐ membrane ์ ์์ ์๊ฐ์ ์ผ๋ก ๊ธฐ์ฌํ๋ค๊ณ ๊ฐ์ ํ๋ค. ๋ํ ์๊ฐ ๊ตฌ๊ฐ์ด ์งง์์ neuron์ ํ๋์ Spike๋ง ๋ฐ์ํ ์ ์๋ค๊ณ ๊ฐ์ ํ๋ค.
\[V(t+1) = \beta V(t) + (1-\beta)I_{in}(t+1)\]deeplearning์์ ์ ๋ ฅ์ ๊ฐ์ค์น ๊ณ์๊ฐ ํ์ต ๊ฐ๋ฅํ parameter๋ก ์ฌ์ฉ๋๋ค. ์ด๋ ์ ํธ $V(t)$์ ๊ฐ์ค์น W์ ์ํธ์์ฉ์ ๋จ์ํ ํ๊ธฐ ์ํด ๋์ ๊ณฑํ ๊ฒฐ๊ณผ๋ก ํํํ๋ค.
\[V(t+1) = \beta V(t) + WX(t+1)\]Spikint & Reset
๋ง ์ ์๊ฐ ์๊ณ๊ฐ์ ์ด๊ณผํ๋ฉด ๋ด๋ฐ์ด ์ถ๋ ฅ ์คํ์ดํฌ๋ฅผ ๋ฐ์์ํจ๋ค.
\[S[t] = 1, if \;\;V(t)>V_{thr} \\ \;\;\;0, otherwise\]Spike๊ฐ ๋ฐ์ํ๋ฉด membrane ์ ์๋ ์ด๊ธฐํ๊ฐ ๋์ด์ผ ํ๋ค. ์ด๋ ๊ฐ์์ ์ํ ๋ฆฌ์ (reset by substraction) ๋ชจ๋ธ์ ๋ค์๊ณผ ๊ฐ๋ค.
\[V(t+1) = \beta V(t) + WX(t+1)-S(t)V_{thr}\]์ด๋ W๋ ํ์ต ๊ฐ๋ฅํ ํ๋ผ๋ฏธํฐ ์ด๋ฉฐ $V_{thr}$์ ์ข ์ข 1๋ก ์ค์ ๋๋ค.
def leaky_integrate_and_fire(mem, x, w, beta, threshold=1):
spk = (mem > threshold) # ๋ง ์ ์๊ฐ ์๊ณ๊ฐ์ ์ด๊ณผํ๋ฉด spk=1, ๊ทธ๋ ์ง ์์ผ๋ฉด 0
mem = beta * mem + w * x - spk * threshold
return spk, mem
delta_t = torch.tensor(1e-3)
tau = torch.tensor(5e-3)
beta = torch.exp(-delta_t / tau)
print(f"The decay rate is: {beta:.3f}")
num_steps = 200
# ์
๋ ฅ/์ถ๋ ฅ ์ด๊ธฐํ ๋ฐ ์์ ์คํ
์ ๋ฅ ์
๋ ฅ
x = torch.cat((torch.zeros(10), torch.ones(190) * 0.5), 0)
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec = []
spk_rec = []
# ๋ด๋ฐ ํ๋ผ๋ฏธํฐ
w = 0.4
beta = 0.819
# ๋ด๋ฐ ์๋ฎฌ๋ ์ด์
for step in range(num_steps):
spk, mem = leaky_integrate_and_fire(mem, x[step], w=w, beta=beta)
mem_rec.append(mem)
spk_rec.append(spk)