Je passais par BERT article qui utilise GELU (unité linéaire derreur gaussienne) qui énonce léquation comme $$ GELU (x) = xP (X ≤ x) = xΦ (x). $$ qui à son tour est approximé à $$ 0,5x (1 + tanh [\ sqrt {2 / π} (x + 0,044715x ^ 3)]) $$

Pourriez-vous simplifier léquation et expliquer comment elle a été approximée.

Réponse

Fonction GELU

Nous pouvons étendre la distribution cumulative de $ \ mathcal {N} (0, 1) $ , cest-à-dire $ \ Phi (x) $ , comme suit: $$ \ text {GELU} (x): = x {\ Bbb P} (X \ le x) = x \ Phi (x) = 0.5x \ left (1+ \ text {erf} \ left (\ frac {x} {\ sqrt {2 }} \ right) \ right) $$

Notez quil sagit dune définition , pas dune équation (ou dune relation). Les auteurs ont fourni quelques justifications pour cette proposition, par exemple une analogie stochastique, mais mathématiquement, ce n’est qu’une définition.

Voici le tracé de GELU:

Approximation de Tanh

Pour ce type dapproximations numériques, lidée clé est de trouver une fonction similaire (principalement basée sur lexpérience), de la paramétrer, puis de ladapter à un ensemble de points de la fonction dorigine.

Sachant que $ \ text {erf} (x) $ est très proche de $ \ text {tanh} (x) $

et le premier dérivé de $ \ text {erf} (\ frac {x} {\ sqrt {2}}) $ coïncide avec celui de $ \ text {tanh} (\ sqrt { \ frac {2} {\ pi}} x) $ à $ x = 0 $ , qui est $ \ sqrt {\ frac {2} {\ pi}} $ , nous procédons à lajustement de $$ \ text {tanh} \ left (\ sqrt {\ frac { 2} {\ pi}} (x + ax ^ 2 + bx ^ 3 + cx ^ 4 + dx ^ 5) \ right) $$ (ou avec plus de termes) à un ensemble de points $ \ left (x_i, \ text {erf} \ left (\ frac {x_i} {\ sqrt {2}} \ right) \ right) $ .

Jai installé cette fonction sur 20 échantillons entre $ (- 1,5, 1,5) $ ( en utilisant ce site ), et voici les coefficients:

En définissant $ a = c = d = 0 $ , $ b $ a été estimé à 0,04495641 $ $ . Avec plus déchantillons provenant dune gamme plus large (ce site nen autorisait que 20), le coefficient $ b $ sera plus proche de celui du papier « s 0,044715 $ . Enfin, nous obtenons

$ \ text {GELU} (x) = x \ Phi (x) = 0.5x \ left (1 + \ text {erf} \ left (\ frac {x} {\ sqrt {2}} \ right) \ right) \ simeq 0.5x \ left (1+ \ text {tanh} \ left (\ sqrt {\ frac { 2} {\ pi}} (x + 0,044715x ^ 3) \ right) \ right) $

avec erreur quadratique moyenne $ \ sim 10 ^ {- 8} $ pour $ x \ in [-10, 10] $ .

Notez que si nous lavons fait ne pas utiliser la relation entre les premières dérivées, le terme $ \ sqrt {\ frac {2} {\ pi}} $ aurait été inclus dans les paramètres comme suit $$ 0.5x \ left (1+ \ text {tanh} \ left (0.797885x + 0.035677x ^ 3 \ right) \ right) $$ qui est moins beau (moins analytique , plus numérique)!

Utilisation de la parité

Comme suggéré par @BookYourLuck , nous pouvons utiliser la parité des fonctions pour restreindre lespace des polynômes dans lesquels nous recherchons. Autrement dit, puisque $ \ text {erf} $ est une fonction impaire, cest-à-dire $ f (-x) = – f (x) $ , et $ \ text {tanh} $ est aussi une fonction impaire, fonction polynomiale $ \ text {pol} (x) $ à lintérieur de $ \ text {tanh} $ devrait également être impair (ne devrait avoir que des puissances impaires de $ x $ ) pour avoir $$ \ text {erf} (- x) \ simeq \ text {tanh} (\ text {pol} (-x)) = \ text {tanh} (- \ text {pol} (x)) = – \ text {tanh} (\ text {pol} (x)) \ simeq- \ text {erf} (x) $$

Auparavant, nous avions la chance de nous retrouver avec (presque) zéro coefficients pour des puissances paires $ x ^ 2 $ et $ x ^ 4 $ , cependant, en général, cela peut conduire à des approximations de faible qualité qui, par exemple, ont un terme comme 0,23 $ x ^ 2 $ qui est annulé par des conditions supplémentaires (paires ou impaires) au lieu dopter simplement pour $ 0x ^ 2 $ .

Approximation sigmoïde

Une relation similaire existe entre $ \ text {erf} (x) $ et $ 2 \ left (\ sigma (x) – \ frac {1} {2} \ right) $ (sigmoïde), qui est proposé dans larticle comme une autre approximation, avec une erreur quadratique moyenne $ \ sim 10 ^ {- 4} $ pour $ x \ in [-10, 10] $ .

Voici un code Python pour générer des points de données, ajuster les fonctions et calculer les erreurs quadratiques moyennes:

import math import numpy as np import scipy.optimize as optimize def tahn(xs, a): return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs] def sigmoid(xs, a): return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs] print_points = 0 np.random.seed(123) # xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0, # .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2] # xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8))) # xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6))) xs = np.arange(-10, 10, 0.001) erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs]) ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs]) # Fit tanh and sigmoid curves to erf points tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs) print("Tanh fit: a=%5.5f" % tuple(tanh_popt)) sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs) print("Sigmoid fit: a=%5.5f" % tuple(sig_popt)) # curves used in https://mycurvefit.com: # 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5)) # 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3)) y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs]) tanh_error_paper = (np.square(ys - y_paper_tanh)).mean() y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs]) tanh_error_alt = (np.square(ys - y_alt_tanh)).mean() # curve used in https://mycurvefit.com: # 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5) y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs]) sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean() y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs]) sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean() print("Paper tanh error:", tanh_error_paper) print("Alternative tanh error:", tanh_error_alt) print("Paper sigmoid error:", sigmoid_error_paper) print("Alternative sigmoid error:", sigmoid_error_alt) if print_points == 1: print(len(xs)) for x, erf in zip(xs, erfs): print(x, erf) 

Sortie:

Tanh fit: a=0.04485 Sigmoid fit: a=1.70099 Paper tanh error: 2.4329173471294176e-08 Alternative tanh error: 2.698034519269613e-08 Paper sigmoid error: 5.6479106346814546e-05 Alternative sigmoid error: 5.704246564663601e-05 

Commentaires

  • Pourquoi lapproximation est-elle nécessaire? Ne pourraient pas ' utiliser simplement la fonction erf?

Répondre

Notez dabord que $$ \ Phi (x) = \ frac12 \ mathrm {erfc} \ left (- \ frac {x} {\ sqrt {2}} \ right) = \ frac12 \ left (1 + \ mathrm {erf} \ left (\ frac {x} {\ sqrt2} \ right) \ right) $$ par parité de $ \ mathrm {erf} $ . Nous devons montrer que $$ \ mathrm {erf} \ left (\ frac x {\ sqrt2} \ right) \ approx \ tanh \ left (\ sqrt {\ frac2 \ pi} \ left (x + ax ^ 3 \ right) \ right) $$ pour $ a \ approx 0.044715 $ .

Pour les grandes valeurs de $ x $ , les deux fonctions sont délimitées dans $ [- 1, 1 ] $ . Pour les petits $ x $ , la série Taylor respective se lit $$ \ tanh (x) = x – \ frac {x ^ 3} {3} + o (x ^ 3) $$ et $$ \ mathrm {erf} (x) = \ frac {2} {\ sqrt {\ pi}} \ left (x – \ frac {x ^ 3} {3} \ right) + o (x ^ 3). $$ En remplaçant, nous obtenons ce $$ \ tanh \ left (\ sqrt {\ frac2 \ pi} \ left (x + ax ^ 3 \ right) \ right) = \ sqrt \ frac {2} {\ pi} \ left (x + \ left (a – \ frac {2} {3 \ pi} \ right) x ^ 3 \ right) + o (x ^ 3) $$ et $$ \ mathrm {erf } \ left (\ frac x {\ sqrt2} \ right) = \ sqrt \ frac2 \ pi \ left (x – \ frac {x ^ 3} {6} \ right) + o (x ^ 3). $$ Coefficient équivalent pour $ x ^ 3 $ , nous trouvons $$ a \ approx 0.04553992412 $$ près du papier « s $ 0.044715 $ .

Laisser un commentaire

Votre adresse e-mail ne sera pas publiée. Les champs obligatoires sont indiqués avec *