Ich habe BERT-Papier durchlaufen, das GELU (Gaußsche Fehlerlineareinheit) , die die Gleichung als $$ GELU (x) = xP (X ≤ x) = xΦ (x). $$ angibt , was wiederum ungefähr $$ 0,5x (1 + tanh [\ sqrt {2 / π} (x + 0,044715x ^ 3)]) $$
Können Sie die Gleichung vereinfachen und erklären, wie sie angenähert wurde?
Antwort
GELU-Funktion
Wir können die kumulative Verteilung von $ \ mathcal {N} (0, 1) $ , dh $ \ Phi (x) $ , wie folgt: $$ \ text {GELU} (x): = x {\ Bbb P} (X \ le x) = x \ Phi (x) = 0,5x \ left (1+ \ text {erf} \ left (\ frac {x} {\ sqrt {2) }} \ right) \ right) $$
Beachten Sie, dass dies eine Definition ist, keine Gleichung (oder Beziehung). Die Autoren haben einige Begründungen für diesen Vorschlag geliefert, z. Eine stochastische Analogie , jedoch mathematisch gesehen ist dies nur eine Definition.
Hier ist die Darstellung von GELU:
Tanh-Näherung
Für diese Art von numerischen Näherungen besteht die Schlüsselidee darin, eine ähnliche Funktion zu finden (hauptsächlich basierend auf Erfahrung), sie zu parametrisieren und dann anzupassen eine Reihe von Punkten aus der ursprünglichen Funktion.
Zu wissen, dass $ \ text {erf} (x) $ sehr nahe an $ \ text liegt {tanh} (x) $
und erste Ableitung von $ \ text {erf} (\ frac {x} {\ sqrt {2}}) $ stimmt mit dem von $ \ text {tanh} (\ sqrt {überein) \ frac {2} {\ pi}} x) $ bei $ x = 0 $ , was ist $ \ sqrt {\ frac {2} {\ pi}} $ passen wir $$ \ text {tanh} \ left (\ sqrt {\ frac { 2} {\ pi}} (x + ax ^ 2 + bx ^ 3 + cx ^ 4 + dx ^ 5) \ right) $$ (oder mit mehr Begriffen) zu einer Menge von Punkten $ \ left (x_i, \ text {erf} \ left (\ frac {x_i} {\ sqrt {2}} \ right) \ right) $ .
Ich habe diese Funktion an 20 Stichproben zwischen $ (- 1,5, 1,5) $ ( Verwenden dieser Site ) und hier sind die Koeffizienten:
Durch Setzen von $ a = c = d = 0 $ , $ b $ wurde auf $ 0.04495641 geschätzt $ . Bei mehr Proben aus einem größeren Bereich (diese Site erlaubte nur 20) liegt der Koeffizient $ b $ näher am des Papiers $ 0.044715 $ . Schließlich erhalten wir
$ \ text {GELU} (x) = x \ Phi (x) = 0.5x \ left (1 + \ text {erf} \ left (\ frac {x} {\ sqrt {2}} \ right) \ right) \ simeq 0.5x \ left (1+ \ text {tanh} \ left (\ sqrt {\ frac { 2} {\ pi}} (x + 0.044715x ^ 3) \ right) \ right) $
mit mittlerem quadratischen Fehler $ \ sim 10 ^ {- 8} $ für $ x \ in [-10, 10] $ .
Beachten Sie, dass wir dies getan haben Ohne die Beziehung zwischen den ersten Ableitungen zu verwenden, wäre der Begriff $ \ sqrt {\ frac {2} {\ pi}} $ wie folgt in die Parameter aufgenommen worden $$ 0.5x \ left (1+ \ text {tanh} \ left (0.797885x + 0.035677x ^ 3 \ right) \ right) $$ ist weniger schön (weniger analytisch) , numerischer)!
Verwenden der Parität
Wie von @BookYourLuck können wir die Parität von Funktionen verwenden, um den Raum der Polynome, in denen wir suchen, einzuschränken. Das heißt, da $ \ text {erf} $ eine ungerade Funktion ist, dh $ f (-x) = – f (x) $ und $ \ text {tanh} $ ist ebenfalls eine ungerade Funktion, die Polynomfunktion $ \ text {pol} (x) $ in $ \ text {tanh} $ sollte ebenfalls ungerade sein (sollte nur ungerade Potenzen von $ x $ ), um $$ \ text {erf} (- x) \ simeq \ text {tanh} (\ text {pol}) zu haben. (-x)) = \ text {tanh} (- \ text {pol} (x)) = – \ text {tanh} (\ text {pol} (x)) \ simeq- \ text {erf} (x) $$
Bisher hatten wir das Glück, (fast) Null-Koeffizienten für gerade Potenzen zu erhalten. $ x ^ 2 $ und $ x ^ 4 $ Dies kann jedoch im Allgemeinen zu Approximationen geringer Qualität führen, die beispielsweise einen Begriff wie haben $ 0.23x ^ 2 $ , das durch zusätzliche Bedingungen (gerade oder ungerade) aufgehoben wird anstatt sich einfach für $ 0x ^ 2 $ zu entscheiden.
Sigmoid-Approximation
Eine ähnliche Beziehung besteht zwischen $ \ text {erf} (x) $ und $ 2 \ left (\ sigma (x) – \ frac {1} {2} \ right) $ (Sigmoid), das in der Arbeit als weitere Annäherung mit mittlerem quadratischen Fehler vorgeschlagen wird $ \ sim 10 ^ {- 4} $ für $ x \ in [-10, 10] $
Hier ist ein Python-Code zum Generieren von Datenpunkten, Anpassen der Funktionen und Berechnen der mittleren quadratischen Fehler:
import math import numpy as np import scipy.optimize as optimize def tahn(xs, a): return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs] def sigmoid(xs, a): return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs] print_points = 0 np.random.seed(123) # xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0, # .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2] # xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8))) # xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6))) xs = np.arange(-10, 10, 0.001) erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs]) ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs]) # Fit tanh and sigmoid curves to erf points tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs) print("Tanh fit: a=%5.5f" % tuple(tanh_popt)) sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs) print("Sigmoid fit: a=%5.5f" % tuple(sig_popt)) # curves used in https://mycurvefit.com: # 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5)) # 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3)) y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs]) tanh_error_paper = (np.square(ys - y_paper_tanh)).mean() y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs]) tanh_error_alt = (np.square(ys - y_alt_tanh)).mean() # curve used in https://mycurvefit.com: # 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5) y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs]) sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean() y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs]) sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean() print("Paper tanh error:", tanh_error_paper) print("Alternative tanh error:", tanh_error_alt) print("Paper sigmoid error:", sigmoid_error_paper) print("Alternative sigmoid error:", sigmoid_error_alt) if print_points == 1: print(len(xs)) for x, erf in zip(xs, erfs): print(x, erf)
Ausgabe:
Tanh fit: a=0.04485 Sigmoid fit: a=1.70099 Paper tanh error: 2.4329173471294176e-08 Alternative tanh error: 2.698034519269613e-08 Paper sigmoid error: 5.6479106346814546e-05 Alternative sigmoid error: 5.704246564663601e-05
Kommentare
- Warum wird die Annäherung benötigt? Könnte ' nicht nur die erf-Funktion verwendet werden?
Antwort
Beachten Sie zunächst, dass $$ \ Phi (x) = \ frac12 \ mathrm {erfc} \ left (- \ frac {x} {\ sqrt {2}} \ right) = \ frac12 \ left (1 + \ mathrm {erf} \ left (\ frac {x} {\ sqrt2} \ right) \ right) $$ nach Parität von $ \ mathrm {erf} $ . Wir müssen zeigen, dass $$ \ mathrm {erf} \ left (\ frac x {\ sqrt2} \ right) \ approx \ tanh \ left (\ sqrt {\ frac2 \ pi}) \ left (x + ax ^ 3 \ right) \ right) $$ für $ a \ ca. 0.044715 $ .
Bei großen Werten von $ x $ sind beide Funktionen in $ [- 1, 1 begrenzt ] $ . Für kleine $ x $ lautete die jeweilige Taylor-Reihe $$ \ tanh (x) = x – \ frac {x ^ 3} {3} + o (x ^ 3) $$ und $$ \ mathrm {erf} (x) = \ frac {2} {\ sqrt {\ pi}} \ left (x – \ frac {x ^ 3} {3} \ right) + o (x ^ 3). $$ Durch Einsetzen erhalten wir diesen $$ \ tanh \ left (\ sqrt {\ frac2 \ pi} \ left (x + ax ^ 3 \ right) \ right) = \ sqrt \ frac {2} {\ pi} \ left (x + \ left (a – \ frac {2} {3 \ pi} \ rechts) x ^ 3 \ rechts) + o (x ^ 3) $$ und $$ \ mathrm {erf } \ left (\ frac x {\ sqrt2} \ right) = \ sqrt \ frac2 \ pi \ left (x – \ frac {x ^ 3} {6} \ right) + o (x ^ 3). $$ Wenn der Koeffizient für $ x ^ 3 $ gleichgesetzt wird, finden wir $$ a \ ca. 0.04553992412 $$ in der Nähe des Papiers „s $ 0.044715 $ .