📌 Early Experiments on Sparse Autoencoding (SAE)

To uncover the many features hidden in a neural network's representations, I ran experiments that follow the recently studied sparse autoencoder (SAE) approach.

๋ชจ๋ธ์˜ ํ‘œํ˜„์„ ํ•ด์„ํ•œ๋‹ค๋Š” ๊ฒƒ์€ ํŠน์ • ์•กํ‹ฐ๋ฒ ์ด์…˜ ํŒจํ„ด์ด ์ง€๋‹ˆ๋Š” ์˜๋ฏธ๋ฅผ ๋ฐ์ดํ„ฐ์™€ ์—ฐ๊ฒฐ ์ง“๋Š” ๊ฒƒ์ด๋‹ค. ์ตœ๊ทผ ์•Œ๋ ค์ง„ ๋ฐ”์— ๋”ฐ๋ฅด๋ฉด ๋ชจ๋ธ์ด ํŠน์ง•์„ ์ €์žฅํ•˜๊ธฐ ์œ„ํ•ด์„œ ๋ณต์ˆ˜ ๊ฐœ์˜ ๋‰ด๋Ÿฐ์ด ํŠน์ • ํŒจํ„ด์„ ๋งŒ๋“ค์–ด์„œ ์ €์žฅํ•œ๋‹ค๋Š” ๊ฒƒ์ด ์•Œ๋ ค์กŒ๋‹ค. ์ด ์‚ฌ์‹ค์€ ๊ธฐ์กด์— binary์ ์ธ ๋‰ด๋Ÿฐ์˜ ์ผœ์ง€๊ณ  ๊บผ์ง€๋Š” ํ˜„์ƒ์œผ๋กœ ํŠน์ง•์˜ ์œ ๋ฌด๋ฅผ ํŒ๋‹จํ–ˆ๋˜ ๊ฒƒ์„ ๋„˜์–ด์„œ ํŠน์ • ์•กํ‹ฐ๋ฒ ์ด์…˜ ํŒจํ„ด์ด ์˜๋ฏธ๋ฅผ ์ง€๋‹Œ๋‹ค๋Š” ์‚ฌ์‹ค์„ ๋‚˜ํƒ€๋‚ธ๋‹ค. ์ถ•์•ฝ๋œ ์ฐจ์›์— ์ €์žฅ๋œ ๋ฌด์ˆ˜ํžˆ ๋งŽ์€ ํŠน์ง•์„ ์ฐพ๊ธฐ ์œ„ํ•ด์„œ ์—ฐ๊ตฌ์ž๋“ค์ด ์‚ฌ์šฉํ•˜๋Š” ๋ฐฉ์‹์€ dictionary learning์ด๋‹ค. ํŠน์ง•์ธ ํŒจํ„ด์„ ์ง‘์–ด๋„ฃ๊ณ , ๋‹ค์‹œ ๋ณต์›ํ•˜๋Š” ๊ฐ„๋‹จํ•œ ๋ฌธ์ œ์—์„œ ๋‚ด๋ถ€์— ๋ฌด์ˆ˜ํžˆ ๋งŽ์€ feature๋“ค์„ ์ €์žฅํ•˜๊ณ  ์„ ํƒํ•˜๋Š” ๋ฐฉ์‹์€, ๊ทธ๋“ค์˜ linear sum์ด ์‹ ๊ฒฝ๋ง์˜ ํŠน์ง•์„ ๋‚˜ํƒ€๋‚ธ๋‹ค๋Š” ๊ฐ€์ •์„ ๋ณด์ธ๋‹ค.

To follow this line of research, I implemented the code and ran experiments. I collected Llama2 activations over 100,000 Wikipedia documents and trained SAEs on them. This part of the work was meant to find out, starting from the simplest architecture, what works and what does not. In these experiments, reconstruction got worse as the layer of the model got higher, and GatedSAE performed worse than the plain SAE. This is most likely because I did not apply the Neuron Resampling technique.

Architecture

SAEs

\[f(x) = ReLU(W_{enc} (x - b_{dec}) + b_{enc})\]

\[\hat{x}(f) = W_{dec}f + b_{dec}\]

\[L(x) = || x - \hat{x}(f(x)) ||_2^2 + \lambda ||f(x)||_1\]
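The encoder, decoder, and loss above can be sketched in a few lines of NumPy. This is a minimal illustration rather than the training code used in these experiments: the function names, the shapes ($W_{enc}$ as `(M, d_model)`, $W_{dec}$ as `(d_model, M)`), the random initialization, and the `l1_coeff` value are all assumptions made for the example.

```python
import numpy as np

def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """Return (f, x_hat) for a batch x of shape (batch, d_model)."""
    # f(x) = ReLU(W_enc (x - b_dec) + b_enc), shape (batch, M)
    f = np.maximum(0.0, (x - b_dec) @ W_enc.T + b_enc)
    # x_hat(f) = W_dec f + b_dec, shape (batch, d_model)
    x_hat = f @ W_dec.T + b_dec
    return f, x_hat

def sae_loss(x, f, x_hat, l1_coeff):
    """Squared L2 reconstruction error plus L1 sparsity penalty, batch-averaged."""
    recon = np.sum((x - x_hat) ** 2, axis=-1)   # ||x - x_hat||_2^2 per example
    sparsity = np.sum(np.abs(f), axis=-1)       # ||f(x)||_1 per example
    return float(np.mean(recon + l1_coeff * sparsity))

# Toy sizes chosen for illustration only.
rng = np.random.default_rng(0)
d_model, M, batch = 8, 32, 4
W_enc = rng.normal(size=(M, d_model)) * 0.1
W_dec = rng.normal(size=(d_model, M)) * 0.1
b_enc = np.zeros(M)
b_dec = np.zeros(d_model)

x = rng.normal(size=(batch, d_model))
f, x_hat = sae_forward(x, W_enc, b_enc, W_dec, b_dec)
loss = sae_loss(x, f, x_hat, l1_coeff=1e-3)
```

Setting `l1_coeff` to zero reduces the loss to the pure reconstruction term, which is a quick way to check the two parts separately.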

Gated SAEs

\[\hat{f}(x) = 1[W_{gate}(x-b_{dec}) + b_{gate} > 0] \odot ReLU(W_{mag} (x- b_{dec}) + b_{mag})\]

where $1[\cdot >0]$ is the pointwise Heaviside step function and $\odot$ denotes elementwise multiplication. To reduce the number of weights, the authors tie the magnitude weights to the gate weights, \((W_{mag})_{ij} = (\exp(r_{mag}))_i \cdot (W_{gate})_{ij}\), where $r_{mag} \in \mathbb{R}^M$ is the rescaling parameter.
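A rough NumPy sketch of the gated encoder with tied weights is shown below. It assumes the gate pre-activation is $W_{gate}(x - b_{dec}) + b_{gate}$ and that $W_{gate}$ has shape `(M, d_model)`; the function name `gated_sae_encode` and the toy sizes are hypothetical and not taken from the experiment code.

```python
import numpy as np

def gated_sae_encode(x, W_gate, b_gate, r_mag, b_mag, b_dec):
    """Gated SAE features for a batch x of shape (batch, d_model)."""
    x_centered = x - b_dec
    # Gate path: a 0/1 mask saying which features are active.
    pi_gate = x_centered @ W_gate.T + b_gate
    gate = (pi_gate > 0).astype(x.dtype)
    # Magnitude path with tied weights: (W_mag)_ij = exp(r_mag)_i * (W_gate)_ij
    W_mag = np.exp(r_mag)[:, None] * W_gate
    mag = np.maximum(0.0, x_centered @ W_mag.T + b_mag)
    # Elementwise product of the mask and the magnitudes.
    return gate * mag

rng = np.random.default_rng(0)
M, d_model = 16, 4
W_gate = rng.normal(size=(M, d_model))
b_gate = rng.normal(size=M)
x = rng.normal(size=(3, d_model))

# With r_mag = 0 and b_mag = b_gate the two paths coincide,
# so the gated encoder reduces to a plain ReLU encoder.
f = gated_sae_encode(x, W_gate, b_gate,
                     r_mag=np.zeros(M), b_mag=b_gate, b_dec=np.zeros(d_model))
```

The weight tying means the magnitude path adds only `M` rescaling parameters and `M` biases on top of the gate path, instead of a second full weight matrix.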

\[L_{incorrect}(x) = L(x) = || x - \hat{x}(f(x)) ||_2^2 + \lambda ||f_{gate}(x)||_1\]

Evaluation

1. Reconstruction

How well the representation is reconstructed when it is recombined from the learned dictionary.

\[\Vert x - \hat{x}(f(x)) \Vert_2^2\]

2. L0

The number of dictionary items needed to reconstruct the representation.

Let L0 be the expected number of non-zero features:

\[L0 = \mathbb{E}_{x\sim D} \Vert f(x)\Vert_0\]
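The first two metrics are straightforward to compute from a batch of activations and features. The sketch below assumes `x`, `x_hat`, and `f` are NumPy arrays with one example per row; the helper names and the toy values are made up for illustration.

```python
import numpy as np

def reconstruction_error(x, x_hat):
    """Batch mean of the squared L2 error ||x - x_hat||_2^2."""
    return float(np.mean(np.sum((x - x_hat) ** 2, axis=-1)))

def mean_l0(f):
    """Batch mean of the number of non-zero features per example."""
    return float(np.mean(np.count_nonzero(f, axis=-1)))

# Toy values: two examples, four dictionary features.
f = np.array([[0.0, 1.5, 0.0, 0.2],
              [0.0, 0.0, 0.0, 3.0]])
x = np.array([[1.0, 2.0], [3.0, 4.0]])
x_hat = np.array([[1.0, 1.0], [3.0, 2.0]])

print(mean_l0(f))                      # 1.5 (2 active, then 1 active)
print(reconstruction_error(x, x_hat))  # 2.5 (errors of 1 and 4)
```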

3. Loss recovered

How well the representation is reconstructed when it is recombined from the learned dictionary, measured by the model's downstream performance.

\[1 - \frac{CE(\hat{x} \circ \hat{f}) - CE(ID)}{CE(\psi) - CE(ID)}\]

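where $CE(ID)$ is the language model's cross-entropy loss with the activations left unchanged, the first term in the numerator is the loss with the SAE reconstruction spliced in, and $CE(\psi)$ is the loss under an ablation baseline (zero ablation in the paper).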
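Once the three cross-entropy values have been measured, the metric is a one-line calculation. The sketch below is only that arithmetic: the argument names are hypothetical, the example numbers are invented, and how each loss is obtained from the model is not shown.

```python
def loss_recovered(ce_clean, ce_spliced, ce_ablated):
    """1 - (CE(spliced) - CE(clean)) / (CE(ablated) - CE(clean)).

    ce_clean:   loss of the unmodified model
    ce_spliced: loss with the SAE reconstruction substituted in
    ce_ablated: loss under the ablation baseline
    """
    denom = ce_ablated - ce_clean
    if denom == 0:
        raise ValueError("ablated and clean losses are equal; metric is undefined")
    return 1.0 - (ce_spliced - ce_clean) / denom

# Invented numbers for illustration: the SAE gives back 75% of the loss gap.
print(loss_recovered(ce_clean=2.0, ce_spliced=2.5, ce_ablated=4.0))  # 0.75
```

A value of 1 means splicing in the SAE costs nothing, and a value of 0 means it is as harmful as the ablation baseline.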

๋ฐ์ดํ„ฐ


Initial experimental results

  1. lr_scheduler: using cosine with hard restarts does not help training.
  2. GatedSAE performs worse than the plain SAE. (The paper reports using Neuron Resampling, and this appears to account for the performance gap, because once a gate turns on, training becomes biased toward that gate.)
  3. Of the early, middle, and later layers, the first two trained properly, but the last one did not.
  4. The L1 coefficient affects reconstruction in proportion to its magnitude.
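For reference, the shape of a cosine schedule with hard restarts can be sketched as follows. This is a generic version without warmup, written from the usual definition of the schedule; the scheduler implementation, number of cycles, and learning rate actually used in these experiments are not stated here, so `num_cycles` and `base_lr` are placeholder assumptions.

```python
import math

def cosine_hard_restarts(step, total_steps, num_cycles=3, base_lr=1e-3):
    """Learning rate at `step` for a cosine schedule that restarts num_cycles times."""
    if step >= total_steps:
        return 0.0
    # Position inside the current cycle, in [0, 1); integer math avoids float drift.
    cycle_pos = (step * num_cycles % total_steps) / total_steps
    # Cosine decay from base_lr toward 0 within each cycle.
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * cycle_pos))

total = 300
lrs = [cosine_hard_restarts(s, total) for s in range(total)]
# The rate decays within a cycle, then jumps back to base_lr at steps 100 and 200.
```

The abrupt jump back to the peak rate at each restart is the part that distinguishes this schedule from a single cosine decay.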

Why does reconstruction fail as the layer gets higher?

A feature should be deterministic and fixed in nature. Consider the following thought experiment.

  1. Show someone several Lego builds. (input batch)
  2. Ask them to pick out the components that make up the builds. (items)
  3. Ask them to combine those components to rebuild every one of the builds. (selection)
  4. Compare each result with the original build. (reconstruction)

In this process, steps 2 and 3 correspond to finding suitable item representations and a suitable selection through end-to-end deep learning. Along the way, either the item representations or the selection may turn out to be inadequate.

๋ ˆ์ด์–ด๊ฐ€ ๋†’์•„์งˆ์ˆ˜๋ก ํ•™์Šต๋˜์ง€ ์•Š์•˜๋˜ ๊ฒƒ์€ ๋‘ ๊ฐ€์ง€ ๋ชจ๋‘ ์˜ํ–ฅ์„ ๋ผ์น˜๋Š” ๊ฒƒ์œผ๋กœ ํ™•์ธ๋œ๋‹ค. ์ฆ‰, ๋ ˆ์ด์–ด๊ฐ€ ์˜ฌ๋ผ๊ฐ€๋ฉด์„œ ๋‹ค์Œ ๋ฌธ์ œ๊ฐ€ ์ƒ๊ธด๋‹ค.

Can suitable primitive components be found for the representation space of a complex layer? (No)