Ik was bezig met BERT-papier dat GELU (Gaussian Error Linear Unit) die vergelijking aangeeft als $$ GELU (x) = xP (X ≤ x) = xΦ (x). $$ die op zijn beurt wordt geschat op $$ 0,5x (1 + tanh [\ sqrt {2 / π} (x + 0,044715x ^ 3)]) $$
Kunt u de vergelijking vereenvoudigen en uitleggen hoe deze is benaderd.
Antwoord
GELU-functie
We kunnen de cumulatieve distributie van $ \ mathcal {N} (0, 1) $ , dwz $ \ Phi (x) $ , als volgt: $$ \ text {GELU} (x): = x {\ Bbb P} (X \ le x) = x \ Phi (x) = 0,5x \ left (1+ \ text {erf} \ left (\ frac {x} {\ sqrt {2 }} \ right) \ right) $$
Merk op dat dit een definitie is, geen vergelijking (of een relatie). Auteurs hebben enkele rechtvaardigingen gegeven voor dit voorstel, bijv. een stochastische analogie , hoe wiskundig ook, dit is slechts een definitie.
Hier is de plot van GELU:
Tanh-benadering
Voor dit soort numerieke benaderingen is het belangrijkste idee om een vergelijkbare functie te vinden (voornamelijk gebaseerd op ervaring), deze te parametriseren en deze vervolgens aan te passen aan een reeks punten van de oorspronkelijke functie.
Wetende dat $ \ text {erf} (x) $ zeer dicht bij $ \ text ligt {tanh} (x) $
en eerste afgeleide van $ \ text {erf} (\ frac {x} {\ sqrt {2}}) $ valt samen met die van $ \ text {tanh} (\ sqrt { \ frac {2} {\ pi}} x) $ om $ x = 0 $ , dat is $ \ sqrt {\ frac {2} {\ pi}} $ , we gaan verder met het plaatsen van $$ \ text {tanh} \ left (\ sqrt {\ frac { 2} {\ pi}} (x + ax ^ 2 + bx ^ 3 + cx ^ 4 + dx ^ 5) \ right) $$ (of met meer termen) naar een reeks punten $ \ left (x_i, \ text {erf} \ left (\ frac {x_i} {\ sqrt {2}} \ right) \ right) $ .
Ik heb deze functie aangepast aan 20 voorbeelden tussen $ (- 1.5, 1.5) $ ( met behulp van deze site ), en hier zijn de coëfficiënten:
Door $ a = c = d = 0 $ , $ b $ werd geschat op $ 0,04495641 $ . Met meer voorbeelden uit een groter bereik (die site stond er slechts 20 toe), zal de coëfficiënt $ b $ dichter bij de $ 0,044715 $ . Eindelijk krijgen we
$ \ text {GELU} (x) = x \ Phi (x) = 0,5x \ left (1 + \ text {erf} \ left (\ frac {x} {\ sqrt {2}} \ right) \ right) \ simeq 0,5x \ left (1+ \ text {tanh} \ left (\ sqrt {\ frac { 2} {\ pi}} (x + 0,044715x ^ 3) \ right) \ right) $
met gemiddelde kwadratische fout $ \ sim 10 ^ {- 8} $ voor $ x \ in [-10, 10] $ .
Merk op dat als we dat deden geen gebruik maken van de relatie tussen de eerste afgeleiden, zou de term $ \ sqrt {\ frac {2} {\ pi}} $ als volgt in de parameters zijn opgenomen $$ 0,5x \ left (1+ \ text {tanh} \ left (0,797885x + 0,035677x ^ 3 \ right) \ right) $$ wat minder mooi is (minder analytisch , numerieker)!
De pariteit gebruiken
Zoals voorgesteld door @BookYourLuck , kunnen we de pariteit van functies gebruiken om de ruimte van polynomen waarin we zoeken te beperken. Dat wil zeggen, aangezien $ \ text {erf} $ een vreemde functie is, dwz $ f (-x) = – f (x) $ , en $ \ text {tanh} $ is ook een vreemde functie, polynoomfunctie $ \ text {pol} (x) $ binnen $ \ text {tanh} $ moet ook oneven zijn (mag alleen oneven machten hebben van $ x $ ) om $$ \ text {erf} (- x) \ simeq \ text {tanh} (\ text {pol} (-x)) = \ text {tanh} (- \ text {pol} (x)) = – \ text {tanh} (\ text {pol} (x)) \ simeq- \ text {erf} (x) $$
Eerder hadden we het geluk dat we eindigden met (bijna) nulcoëfficiënten voor even machten $ x ^ 2 $ en $ x ^ 4 $ , maar in het algemeen kan dit leiden tot benaderingen van lage kwaliteit die bijvoorbeeld een term hebben als $ 0,23x ^ 2 $ dat wordt opgeheven door extra voorwaarden (even of oneven) in plaats van simpelweg te kiezen voor $ 0x ^ 2 $ .
Sigmoidbenadering
Een vergelijkbare relatie geldt tussen $ \ text {erf} (x) $ en $ 2 \ left (\ sigma (x) – \ frac {1} {2} \ right) $ (sigmoid), die in de paper wordt voorgesteld als een andere benadering, met een gemiddelde kwadratische fout $ \ sim 10 ^ {- 4} $ voor $ x \ in [-10, 10] $ .
Hier is een Python-code voor het genereren van datapunten, het aanpassen van de functies en het berekenen van de gemiddelde kwadratische fouten:
import math import numpy as np import scipy.optimize as optimize def tahn(xs, a): return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs] def sigmoid(xs, a): return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs] print_points = 0 np.random.seed(123) # xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0, # .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2] # xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8))) # xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6))) xs = np.arange(-10, 10, 0.001) erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs]) ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs]) # Fit tanh and sigmoid curves to erf points tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs) print("Tanh fit: a=%5.5f" % tuple(tanh_popt)) sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs) print("Sigmoid fit: a=%5.5f" % tuple(sig_popt)) # curves used in https://mycurvefit.com: # 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5)) # 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3)) y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs]) tanh_error_paper = (np.square(ys - y_paper_tanh)).mean() y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs]) tanh_error_alt = (np.square(ys - y_alt_tanh)).mean() # curve used in https://mycurvefit.com: # 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5) y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs]) sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean() y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs]) sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean() print("Paper tanh error:", tanh_error_paper) print("Alternative tanh error:", tanh_error_alt) print("Paper sigmoid error:", sigmoid_error_paper) print("Alternative sigmoid error:", sigmoid_error_alt) if print_points == 1: print(len(xs)) for x, erf in zip(xs, erfs): print(x, erf)
Uitvoer:
Tanh fit: a=0.04485 Sigmoid fit: a=1.70099 Paper tanh error: 2.4329173471294176e-08 Alternative tanh error: 2.698034519269613e-08 Paper sigmoid error: 5.6479106346814546e-05 Alternative sigmoid error: 5.704246564663601e-05
Reacties
- Waarom is de benadering nodig? Kunnen ' t ze gewoon de erf-functie gebruiken?
Answer
Merk allereerst op dat $$ \ Phi (x) = \ frac12 \ mathrm {erfc} \ left (- \ frac {x} {\ sqrt {2}} \ right) = \ frac12 \ left (1 + \ mathrm {erf} \ left (\ frac {x} {\ sqrt2} \ right) \ right) $$ door pariteit van $ \ mathrm {erf} $ . We moeten laten zien dat $$ \ mathrm {erf} \ left (\ frac x {\ sqrt2} \ right) \ approx \ tanh \ left (\ sqrt {\ frac2 \ pi} \ left (x + ax ^ 3 \ right) \ right) $$ voor $ a \ ca. 0,044715 $ .
Voor grote waarden van $ x $ zijn beide functies begrensd in $ [- 1, 1 ] $ . Voor kleine $ x $ , leest de respectieve Taylor-reeks $$ \ tanh (x) = x – \ frac {x ^ 3} {3} + o (x ^ 3) $$ en $$ \ mathrm {erf} (x) = \ frac {2} {\ sqrt {\ pi}} \ left (x – \ frac {x ^ 3} {3} \ right) + o (x ^ 3). $$ Vervanging, we krijgen dat $$ \ tanh \ left (\ sqrt {\ frac2 \ pi} \ left (x + ax ^ 3 \ right) \ right) = \ sqrt \ frac {2} {\ pi} \ left (x + \ left (a – \ frac {2} {3 \ pi} \ right) x ^ 3 \ right) + o (x ^ 3) $$ en $$ \ mathrm {erf } \ left (\ frac x {\ sqrt2} \ right) = \ sqrt \ frac2 \ pi \ left (x – \ frac {x ^ 3} {6} \ right) + o (x ^ 3). $$ Vergelijkende coëfficiënt voor $ x ^ 3 $ , we vinden $$ a \ circa 0,04553992412 $$ dicht bij de krant “s $ 0,044715 $ .