Jeg gennemgik BERT-papir , der bruger GELU (Gaussian Error Linear Unit) , der angiver ligning som $$ GELU (x) = xP (X ≤ x) = xΦ (x). $$ som igen er tilnærmet $$ 0.5x (1 + tanh [\ sqrt {2 / π} (x + 0.044715x ^ 3)]) $$
Kunne du forenkle ligningen og forklare, hvordan den er tilnærmet.
Svar
GELU-funktion
Vi kan udvide kumulativ fordeling af $ \ mathcal {N} (0, 1) $ , dvs. $ \ Phi (x) $ , som følger: $$ \ text {GELU} (x): = x {\ Bbb P} (X \ le x) = x \ Phi (x) = 0.5x \ left (1+ \ text {erf} \ left (\ frac {x} {\ sqrt {2 }} \ right) \ right) $$
Bemærk, at dette er en definition , ikke en ligning (eller en relation). Forfattere har givet nogle begrundelser for dette forslag, f.eks. en stokastisk analogi men matematisk, dette er bare en definition.
Her er plottet for GELU:
Tanh-tilnærmelse
For denne type numeriske tilnærmelser er nøgleidéen at finde en lignende funktion (primært baseret på erfaring), parametrere den og derefter tilpasse den til et sæt punkter fra den oprindelige funktion.
At vide, at $ \ text {erf} (x) $ er meget tæt på $ \ text {tanh} (x) $
og første derivat af $ \ text {erf} (\ frac {x} {\ sqrt {2}}) $ falder sammen med $ \ text {tanh} (\ sqrt { \ frac {2} {\ pi}} x) $ ved $ x = 0 $ , hvilket er $ \ sqrt {\ frac {2} {\ pi}} $ , vi fortsætter med at passe til $$ \ text {tanh} \ left (\ sqrt {\ frac { 2} {\ pi}} (x + ax ^ 2 + bx ^ 3 + cx ^ 4 + dx ^ 5) \ right) $$ (eller med flere udtryk) til et sæt punkter $ \ left (x_i, \ text {erf} \ left (\ frac {x_i} {\ sqrt {2}} \ right) \ right) $ .
Jeg har tilpasset denne funktion til 20 prøver mellem $ (- 1.5, 1.5) $ ( ved hjælp af dette websted ), og her er koefficienterne:
Ved at indstille $ a = c = d = 0 $ , $ b $ blev anslået til at være $ 0,04495641 $ . Med flere prøver fra et bredere interval (dette websted tillader kun 20), vil koefficienten $ b $ være tættere på papiret “s $ 0,044715 $ . Endelig får vi
$ \ text {GELU} (x) = x \ Phi (x) = 0.5x \ left (1 + \ text {erf} \ left (\ frac {x} {\ sqrt {2}} \ right) \ right) \ simeq 0.5x \ left (1+ \ text {tanh} \ left (\ sqrt {\ frac { 2} {\ pi}} (x + 0,044715x ^ 3) \ højre) \ højre) $
med gennemsnitlig kvadratfejl $ \ sim 10 ^ {- 8} $ til $ x \ i [-10, 10] $ .
Bemærk at hvis vi gjorde ikke bruge forholdet mellem de første derivater, ville termen $ \ sqrt {\ frac {2} {\ pi}} $ være inkluderet i parametrene som følger $$ 0.5x \ left (1+ \ text {tanh} \ left (0.797885x + 0.035677x ^ 3 \ right) \ right) $$ hvilket er mindre smukt (mindre analytisk , mere numerisk)!
Brug af pariteten
Som foreslået af @BookYourLuck , vi kan bruge pariteten af funktioner til at begrænse rummet for polynomer, som vi søger efter. Det vil sige, da $ \ text {erf} $ er en ulige funktion, dvs. $ f (-x) = – f (x) $ , og $ \ text {tanh} $ er også en ulige funktion, polynomial funktion $ \ text {pol} (x) $ inde i $ \ text {tanh} $ skal også være ulige (skal kun have ulige kræfter på $ x $ ) skal have $$ \ text {erf} (- x) \ simeq \ text {tanh} (\ text {pol} (-x)) = \ text {tanh} (- \ text {pol} (x)) = – \ text {tanh} (\ text {pol} (x)) \ simeq- \ text {erf} (x) $$
Tidligere var vi heldige at ende med (næsten) nul koefficienter for lige kræfter $ x ^ 2 $ og $ x ^ 4 $ , men generelt kan dette føre til tilnærmelser af lav kvalitet, som f.eks. har et udtryk som $ 0,23x ^ 2 $ , der annulleres med ekstra vilkår (lige eller ulige) i stedet for blot at vælge $ 0x ^ 2 $ .
Sigmoid-tilnærmelse
Et lignende forhold gælder mellem $ \ text {erf} (x) $ og $ 2 \ left (\ sigma (x) – \ frac {1} {2} \ right) $ (sigmoid), som foreslås i papiret som en anden tilnærmelse med gennemsnitlig kvadratisk fejl $ \ sim 10 ^ {- 4} $ til $ x \ i [-10, 10] $ .
Her er en Python-kode til generering af datapunkter, tilpasning af funktionerne og beregning af de gennemsnitlige kvadratiske fejl:
import math import numpy as np import scipy.optimize as optimize def tahn(xs, a): return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs] def sigmoid(xs, a): return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs] print_points = 0 np.random.seed(123) # xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0, # .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2] # xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8))) # xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6))) xs = np.arange(-10, 10, 0.001) erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs]) ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs]) # Fit tanh and sigmoid curves to erf points tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs) print("Tanh fit: a=%5.5f" % tuple(tanh_popt)) sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs) print("Sigmoid fit: a=%5.5f" % tuple(sig_popt)) # curves used in https://mycurvefit.com: # 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5)) # 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3)) y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs]) tanh_error_paper = (np.square(ys - y_paper_tanh)).mean() y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs]) tanh_error_alt = (np.square(ys - y_alt_tanh)).mean() # curve used in https://mycurvefit.com: # 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5) y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs]) sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean() y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs]) sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean() print("Paper tanh error:", tanh_error_paper) print("Alternative tanh error:", tanh_error_alt) print("Paper sigmoid error:", sigmoid_error_paper) print("Alternative sigmoid error:", sigmoid_error_alt) if print_points == 1: print(len(xs)) for x, erf in zip(xs, erfs): print(x, erf)
Output:
Tanh fit: a=0.04485 Sigmoid fit: a=1.70099 Paper tanh error: 2.4329173471294176e-08 Alternative tanh error: 2.698034519269613e-08 Paper sigmoid error: 5.6479106346814546e-05 Alternative sigmoid error: 5.704246564663601e-05
Kommentarer
- Hvorfor er det nødvendigt med en tilnærmelse? Kunne de ikke ' t bare bruge erffunktion?
Svar
Første bemærkning om at $$ \ Phi (x) = \ frac12 \ mathrm {erfc} \ left (- \ frac {x} {\ sqrt {2}} \ right) = \ frac12 \ left (1 + \ mathrm {erf} \ left (\ frac {x} {\ sqrt2} \ right) \ right) $$ efter paritet $ \ mathrm {erf} $ . Vi skal vise, at $$ \ mathrm {erf} \ left (\ frac x {\ sqrt2} \ right) \ approx \ tanh \ left (\ sqrt {\ frac2 \ pi} \ left (x + ax ^ 3 \ right) \ right) $$ til $ a \ ca. 0,044715 $ .
For store værdier på $ x $ er begge funktioner afgrænset i $ [- 1, 1 ] $ . For lille $ x $ læses den respektive Taylor-serie $$ \ tanh (x) = x – \ frac {x ^ 3} {3} + o (x ^ 3) $$ og $$ \ mathrm {erf} (x) = \ frac {2} {\ sqrt {\ pi}} \ left (x – \ frac {x ^ 3} {3} \ right) + o (x ^ 3). $$ Når vi erstatter, får vi den $$ \ tanh \ left (\ sqrt {\ frac2 \ pi} \ left (x + ax ^ 3 \ right) \ right) = \ sqrt \ frac {2} {\ pi} \ left (x + \ left (a – \ frac {2} {3 \ pi} \ right) x ^ 3 \ right) + o (x ^ 3) $$ og $$ \ mathrm {erf } \ left (\ frac x {\ sqrt2} \ right) = \ sqrt \ frac2 \ pi \ left (x – \ frac {x ^ 3} {6} \ right) + o (x ^ 3). $$ Ligningskoefficient for $ x ^ 3 $ , vi finder $$ a \ ca. 0,04553992412 $$ tæt på papiret “s $ 0,044715 $ .