BERT papíron mentem keresztül, amely a GELU (Gaussian Lineáris Egység) , amely az egyenletet $$ GELU (x) = xP (X ≤ x) = xΦ (x). $$ ami viszont megközelítőleg $$ 0,5x (1 + tanh [\ sqrt {2 / π} (x + 0,044715x ^ 3)]) $$

Egyszerűsítené az egyenletet, és elmagyarázná, hogyan közelítették meg.

Válasz

GELU függvény

Bővíthetjük a $ \ mathcal {N} (0, 1) $ kumulatív eloszlást span> , azaz $ \ Phi (x) $ , az alábbiak szerint: $$ \ text {GELU} (x): = x {\ Bbb P} (X \ le x) = x \ Phi (x) = 0,5x \ bal (1+ \ text {erf} \ left (\ frac {x} {\ sqrt {2 }} \ right) \ right) $$

Ne feledje, hogy ez egy definíció , és nem egyenlet (vagy reláció). A szerzők megalapozták ezt a javaslatot, pl. sztochasztikus analógia , matematikailag azonban ez csak egy meghatározás.

Itt van a GELU cselekménye:

Tanh közelítés

Az ilyen típusú numerikus közelítéseknél a legfontosabb ötlet egy hasonló (elsősorban tapasztalatokon alapuló) függvény megtalálása, paraméterezése, majd illesztése a következőhöz: az eredeti függvény pontjainak halmaza.

Annak tudatában, hogy a $ \ text {erf} (x) $ nagyon közel van a $ \ text-hez {tanh} (x) $

és a $ \ text {erf} (\ frac {x} {\ sqrt {2}}) $ egybeesik a $ \ text {tanh} (\ sqrt { \ frac {2} {\ pi}} x) $ itt: $ x = 0 $ , ami $ \ sqrt {\ frac {2} {\ pi}} $ , folytatjuk a $$ \ text {tanh} \ left (\ sqrt {\ frac { 2} {\ pi}} (x + ax ^ 2 + bx ^ 3 + cx ^ 4 + dx ^ 5) \ jobbra) $$ (vagy több kifejezéssel) egy ponthalmazra $ \ left (x_i, \ text {erf} \ left (\ frac {x_i} {\ sqrt {2}} \ right) \ right) $ .

Ezt a függvényt 20 mintához illesztettem a $ (- 1.5, 1.5) $ ( ennek a webhelynek a használatával ), és itt vannak az együtthatók:

A $ a = c = d = 0 $ , $ b $ a becslések szerint 0,04495641 USD $ . Több mintával szélesebb tartományból (az a webhely csak 20 engedélyezett) az $ b $ együttható közelebb lesz a “s papírhoz 0,044715 USD $ . Végül kapunk

$ \ text {GELU} (x) = x \ Phi (x) = 0,5x \ left (1 + \ text {erf} \ left (\ frac {x} {\ sqrt {2}} \ right) \ right) \ simeq 0,5x \ left (1+ \ text {tanh} \ left (\ sqrt {\ frac { 2} {\ pi}} (x + 0,044715x ^ 3) \ jobbra \ jobbra) $

átlagos négyzetes hibával $ \ sim 10 ^ {- 8} $ a $ x \ esetén [-10, 10] $ .

Ne feledje, hogy ha mégis nem használja az első származtatott ügyletek kapcsolatát, a $ \ sqrt {\ frac {2} {\ pi}} $ kifejezést a következőkben szerepeltették volna a paraméterekben: $$ 0.5x \ left (1+ \ text {tanh} \ left (0.797885x + 0.035677x ^ 3 \ right) \ right) $$ ami kevésbé szép (kevésbé elemző , még numerikusabb)!

A paritás kihasználása

Ahogy azt @BookYourLuck , felhasználhatjuk a funkciók paritását a keresendő polinomok terének korlátozására. Vagyis mivel a $ \ text {erf} $ egy páratlan függvény, azaz $ f (-x) = – f (x) $ , és a $ \ text {tanh} $ szintén páratlan függvény, polinom függvény $ A $ \ text {tanh} $ belsejében lévő \ text {pol} (x) $ értéknek szintén páratlannak kell lennie (csak $ x $ ) $$ \ text {erf} (- x) \ simeq \ text {tanh} (\ text {pol} (-x)) = \ text {tanh} (- \ text {pol} (x)) = – \ text {tanh} (\ text {pol} (x)) \ simeq- \ text {erf} (x) $$

Korábban szerencsénk volt, hogy (majdnem) nulla együtthatókat kaptunk a páros hatványok $ x ^ 2 $ és $ x ^ 4 $ , azonban általában ez gyenge minőségű közelítésekhez vezethet, amelyeknek például van egy olyan kifejezésük, mint 0,23x ^ 2 $ dollár, amelyet külön feltételek (páros vagy páratlan) törölnek ahelyett, hogy egyszerűen a $ 0x ^ 2 $ lehetőséget választaná.

Szigmoid közelítés

Hasonló kapcsolat áll fenn $ \ text {erf} (x) $ és $ 2 \ left (\ sigma (x) – \ frac {1} {2} \ right) $ (sigmoid), amelyet a cikk másik közelítésként javasol, átlagos négyzethibával $ \ sim 10 ^ {- 4} $ a $ x \ esetén [-10, 10] $ .

Itt van egy Python-kód adatpontok előállításához, a függvények illesztéséhez és az átlagos négyzetes hibák kiszámításához:

import math import numpy as np import scipy.optimize as optimize def tahn(xs, a): return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs] def sigmoid(xs, a): return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs] print_points = 0 np.random.seed(123) # xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0, # .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2] # xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8))) # xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6))) xs = np.arange(-10, 10, 0.001) erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs]) ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs]) # Fit tanh and sigmoid curves to erf points tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs) print("Tanh fit: a=%5.5f" % tuple(tanh_popt)) sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs) print("Sigmoid fit: a=%5.5f" % tuple(sig_popt)) # curves used in https://mycurvefit.com: # 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5)) # 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3)) y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs]) tanh_error_paper = (np.square(ys - y_paper_tanh)).mean() y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs]) tanh_error_alt = (np.square(ys - y_alt_tanh)).mean() # curve used in https://mycurvefit.com: # 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5) y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs]) sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean() y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs]) sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean() print("Paper tanh error:", tanh_error_paper) print("Alternative tanh error:", tanh_error_alt) print("Paper sigmoid error:", sigmoid_error_paper) print("Alternative sigmoid error:", sigmoid_error_alt) if print_points == 1: print(len(xs)) for x, erf in zip(xs, erfs): print(x, erf) 

Kimenet:

Tanh fit: a=0.04485 Sigmoid fit: a=1.70099 Paper tanh error: 2.4329173471294176e-08 Alternative tanh error: 2.698034519269613e-08 Paper sigmoid error: 5.6479106346814546e-05 Alternative sigmoid error: 5.704246564663601e-05 

megjegyzések

  • Miért van szükség a közelítésre? Nem lehet, hogy ' nem csak erf függvényt használnak?

Válasz

Először vegye figyelembe, hogy $$ \ Phi (x) = \ frac12 \ mathrm {erfc} \ left (- \ frac {x} {\ sqrt {2}} \ right) = \ frac12 \ left (1 + \ mathrm {erf} \ left (\ frac {x} {\ sqrt2} \ right) \ right) $$ a $ \ paritás szerint mathrm {erf} $ . Meg kell mutatnunk, hogy $$ \ mathrm {erf} \ left (\ frac x {\ sqrt2} \ right) \ kb \ tanh \ left (\ sqrt {\ frac2 \ pi} \ left (x + ax ^ 3 \ right) \ right) $$ a $ a \ kb 0,044715 $ esetében.

A $ x $ nagy értéke esetén mindkét függvény a $ [- 1, 1 mezőben van korlátozva ] $ . A kis $ x $ esetében a megfelelő Taylor-sorozat a következőt tartalmazza: $$ \ tanh (x) = x – \ frac {x ^ 3} {3} + o (x ^ 3) $$ és $$ \ mathrm {erf} (x) = \ frac {2} {\ sqrt {\ pi}} \ left (x – \ frac {x ^ 3} {3} \ right) + o (x ^ 3). $$ Cserélve azt kapjuk, hogy $$ \ tanh \ left (\ sqrt {\ frac2 \ pi} \ left (x + ax ^ 3 \ right) \ right) = \ sqrt \ frac {2} {\ pi} \ left (x + \ left (a – \ frac {2} {3 \ pi} \ right) x ^ 3 \ right) + o (x ^ 3) $$ és $$ \ mathrm {erf } \ left (\ frac x {\ sqrt2} \ right) = \ sqrt \ frac2 \ pi \ left (x – \ frac {x ^ 3} {6} \ right) + o (x ^ 3). $$ A $ x ^ 3 $ együtthatóval egyenlővé téve $$ a \ kb 0,04553992412 $$ közel a papírhoz “s 0,044715 USD $ .

Vélemény, hozzászólás?

Az email címet nem tesszük közzé. A kötelező mezőket * karakterrel jelöltük