Jag gick igenom BERT-papper som använder GELU (Gaussian Error Linear Unit) som anger ekvation som $$ GELU (x) = xP (X ≤ x) = xΦ (x). $$ vilket i sin tur är ungefär $$ 0.5x (1 + tanh [\ sqrt {2 / π} (x + 0,044715x ^ 3)]) $$
Kan du förenkla ekvationen och förklara hur den har approximerats.
Svar
GELU-funktion
Vi kan utöka kumulativ fördelning av $ \ mathcal {N} (0, 1) $ , dvs. $ \ Phi (x) $ , enligt följande: $$ \ text {GELU} (x): = x {\ Bbb P} (X \ le x) = x \ Phi (x) = 0.5x \ left (1+ \ text {erf} \ left (\ frac {x} {\ sqrt {2 }} \ right) \ right) $$
Observera att detta är en definition , inte en ekvation (eller en relation). Författare har tillhandahållit vissa motiveringar för detta förslag, t.ex. en stokastisk analogi , hur matematiskt det än är, det här är bara en definition.
Här är plot av GELU:
Tanh-approximation
För denna typ av numeriska approximationer är nyckelidén att hitta en liknande funktion (främst baserat på erfarenhet), parametrisera den och sedan passa den till en uppsättning punkter från den ursprungliga funktionen.
Att veta att $ \ text {erf} (x) $ ligger mycket nära $ \ text {tanh} (x) $
och första derivatet av $ \ text {erf} (\ frac {x} {\ sqrt {2}}) $ sammanfaller med den för $ \ text {tanh} (\ sqrt { \ frac {2} {\ pi}} x) $ vid $ x = 0 $ , vilket är $ \ sqrt {\ frac {2} {\ pi}} $ , vi fortsätter för att passa $$ \ text {tanh} \ left (\ sqrt {\ frac { 2} {\ pi}} (x + ax ^ 2 + bx ^ 3 + cx ^ 4 + dx ^ 5) \ höger) $$ (eller med fler termer) till en uppsättning punkter $ \ left (x_i, \ text {erf} \ left (\ frac {x_i} {\ sqrt {2}} \ right) \ right) $ .
Jag har anpassat den här funktionen till 20 prover mellan $ (- 1.5, 1.5) $ ( använder den här webbplatsen ), och här är koefficienterna:
Genom att ställa in $ a = c = d = 0 $ , $ b $ beräknades vara $ 0,04495641 $ . Med fler prover från ett större intervall (den webbplatsen tillåts endast 20) kommer koefficienten $ b $ närmare papper ”s $ 0,044715 $ . Slutligen får vi
$ \ text {GELU} (x) = x \ Phi (x) = 0.5x \ left (1 + \ text {erf} \ left (\ frac {x} {\ sqrt {2}} \ right) \ right) \ simeq 0.5x \ left (1+ \ text {tanh} \ left (\ sqrt {\ frac {) 2} {\ pi}} (x + 0,044715x ^ 3) \ höger) \ höger) $
med medelkvadratfel $ \ sim 10 ^ {- 8} $ för $ x \ i [-10, 10] $ .
Observera att om vi gjorde det inte använda förhållandet mellan de första derivaten, termen $ \ sqrt {\ frac {2} {\ pi}} $ skulle ha inkluderats i parametrarna enligt följande $$ 0.5x \ left (1+ \ text {tanh} \ left (0.797885x + 0.035677x ^ 3 \ right) \ right) $$ vilket är mindre vackert (mindre analytiskt , mer numeriskt)!
Använda pariteten
Som föreslagits av @BookYourLuck , vi kan använda paritetsfunktionerna för att begränsa utrymmet för polynom som vi söker i. Det vill säga, eftersom $ \ text {erf} $ är en udda funktion, dvs. $ f (-x) = – f (x) $ och $ \ text {tanh} $ är också en udda funktion, polynomfunktion $ \ text {pol} (x) $ inuti $ \ text {tanh} $ bör också vara udda (bör bara ha udda befogenheter $ x $ ) för att ha $$ \ text {erf} (- x) \ simeq \ text {tanh} (\ text {pol} (-x)) = \ text {tanh} (- \ text {pol} (x)) = – \ text {tanh} (\ text {pol} (x)) \ simeq- \ text {erf} (x) $$
Tidigare hade vi turen att sluta med (nästan) nollkoefficienter för jämna krafter $ x ^ 2 $ och $ x ^ 4 $ , men i allmänhet kan detta leda till approximationer av låg kvalitet som till exempel har en term som $ 0,23x ^ 2 $ som avbryts med extra villkor (jämnt eller udda) istället för att helt enkelt välja $ 0x ^ 2 $ .
Sigmoid approximation
Ett liknande förhållande gäller mellan $ \ text {erf} (x) $ och $ 2 \ left (\ sigma (x) – \ frac {1} {2} \ right) $ (sigmoid), vilket föreslås i papperet som en annan approximation, med medelkvadratfel $ \ sim 10 ^ {- 4} $ för $ x \ i [-10, 10] $ .
Här är en Python-kod för att generera datapunkter, anpassa funktionerna och beräkna medelkvadratfel:
import math import numpy as np import scipy.optimize as optimize def tahn(xs, a): return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs] def sigmoid(xs, a): return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs] print_points = 0 np.random.seed(123) # xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0, # .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2] # xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8))) # xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6))) xs = np.arange(-10, 10, 0.001) erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs]) ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs]) # Fit tanh and sigmoid curves to erf points tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs) print("Tanh fit: a=%5.5f" % tuple(tanh_popt)) sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs) print("Sigmoid fit: a=%5.5f" % tuple(sig_popt)) # curves used in https://mycurvefit.com: # 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5)) # 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3)) y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs]) tanh_error_paper = (np.square(ys - y_paper_tanh)).mean() y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs]) tanh_error_alt = (np.square(ys - y_alt_tanh)).mean() # curve used in https://mycurvefit.com: # 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5) y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs]) sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean() y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs]) sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean() print("Paper tanh error:", tanh_error_paper) print("Alternative tanh error:", tanh_error_alt) print("Paper sigmoid error:", sigmoid_error_paper) print("Alternative sigmoid error:", sigmoid_error_alt) if print_points == 1: print(len(xs)) for x, erf in zip(xs, erfs): print(x, erf)
Output:
Tanh fit: a=0.04485 Sigmoid fit: a=1.70099 Paper tanh error: 2.4329173471294176e-08 Alternative tanh error: 2.698034519269613e-08 Paper sigmoid error: 5.6479106346814546e-05 Alternative sigmoid error: 5.704246564663601e-05
Kommentarer
- Varför behövs en approximation? Kunde inte ' t bara använda erf-funktionen?
Svar
Första notera att $$ \ Phi (x) = \ frac12 \ mathrm {erfc} \ left (- \ frac {x} {\ sqrt {2}} \ right) = \ frac12 \ left (1 + \ mathrm {erf} \ left (\ frac {x} {\ sqrt2} \ right) \ right) $$ efter paritet $ \ mathrm {erf} $ . Vi måste visa att $$ \ mathrm {erf} \ left (\ frac x {\ sqrt2} \ right) \ approx \ tanh \ left (\ sqrt {\ frac2 \ pi} \ left (x + ax ^ 3 \ höger) \ höger) $$ för $ a \ cirka 0,044715 $ .
För stora värden $ x $ begränsas båda funktionerna i $ [- 1, 1 ] $ . För små $ x $ läser respektive Taylor-serie $$ \ tanh (x) = x – \ frac {x ^ 3} {3} + o (x ^ 3) $$ och $$ \ mathrm {erf} (x) = \ frac {2} {\ sqrt {\ pi}} \ left (x – \ frac {x ^ 3} {3} \ right) + o (x ^ 3). $$ Om vi byter ut får vi den $$ \ tanh \ left (\ sqrt {\ frac2 \ pi} \ left (x + ax ^ 3 \ right) \ right) = \ sqrt \ frac {2} {\ pi} \ left (x + \ left (a – \ frac {2} {3 \ pi} \ höger) x ^ 3 \ höger) + o (x ^ 3) $$ och $$ \ mathrm {erf } \ left (\ frac x {\ sqrt2} \ right) = \ sqrt \ frac2 \ pi \ left (x – \ frac {x ^ 3} {6} \ right) + o (x ^ 3). $$ Likvärdighetskoefficient för $ x ^ 3 $ , vi hittar $$ a \ approx 0.04553992412 $$ nära papperet ”s $ 0,044715 $ .