Estaba revisando papel BERT que usa GELU (Unidad lineal de error gaussiano) que establece la ecuación como $$ GELU (x) = xP (X ≤ x) = xΦ (x). $$ que a su vez se aproxima a $$ 0.5x (1 + tanh [\ sqrt {2 / π} (x + 0.044715x ^ 3)]) $$
¿Podría simplificar la ecuación y explicar cómo se ha aproximado?
Respuesta
Función GELU
Podemos expandir la distribución acumulativa de $ \ mathcal {N} (0, 1) $ , es decir, $ \ Phi (x) $ , de la siguiente manera: $$ \ text {GELU} (x): = x {\ Bbb P} (X \ le x) = x \ Phi (x) = 0.5x \ left (1+ \ text {erf} \ left (\ frac {x} {\ sqrt {2 }} \ right) \ right) $$
Tenga en cuenta que esta es una definición , no una ecuación (o una relación). Los autores han proporcionado algunas justificaciones para esta propuesta, p. Ej. una analogía estocástica, sin embargo matemáticamente, esto es solo una definición.
Aquí está la trama de GELU:
Aproximación de Tanh
Para este tipo de aproximaciones numéricas, la idea clave es encontrar una función similar (basada principalmente en la experiencia), parametrizarla y luego ajustarla a un conjunto de puntos de la función original.
Sabiendo que $ \ text {erf} (x) $ está muy cerca de $ \ text {tanh} (x) $
y primera derivada de $ \ text {erf} (\ frac {x} {\ sqrt {2}}) $ coincide con el de $ \ text {tanh} (\ sqrt { \ frac {2} {\ pi}} x) $ en $ x = 0 $ , que es $ \ sqrt {\ frac {2} {\ pi}} $ , procedemos a ajustar $$ \ text {tanh} \ left (\ sqrt {\ frac { 2} {\ pi}} (x + ax ^ 2 + bx ^ 3 + cx ^ 4 + dx ^ 5) \ right) $$ (o con más términos) a un conjunto de puntos $ \ left (x_i, \ text {erf} \ left (\ frac {x_i} {\ sqrt {2}} \ right) \ right) $ .
He ajustado esta función a 20 muestras entre $ (- 1,5, 1,5) $ ( usando este sitio ), y estos son los coeficientes:
Configurando $ a = c = d = 0 $ , $ b $ se estimó en $ 0.04495641 $ . Con más muestras de un rango más amplio (ese sitio solo permitió 20), el coeficiente $ b $ estará más cerca del papel «s $ 0.044715 $ . Finalmente obtenemos
$ \ text {GELU} (x) = x \ Phi (x) = 0.5x \ left (1 + \ text {erf} \ left (\ frac {x} {\ sqrt {2}} \ right) \ right) \ simeq 0.5x \ left (1+ \ text {tanh} \ left (\ sqrt {\ frac { 2} {\ pi}} (x + 0.044715x ^ 3) \ right) \ right) $
con error cuadrático medio $ \ sim 10 ^ {- 8} $ para $ x \ in [-10, 10] $ .
Tenga en cuenta que si lo hicimos no utilizar la relación entre las primeras derivadas, el término $ \ sqrt {\ frac {2} {\ pi}} $ se habría incluido en los parámetros de la siguiente manera $$ 0.5x \ left (1+ \ text {tanh} \ left (0.797885x + 0.035677x ^ 3 \ right) \ right) $$ que es menos hermoso (menos analítico , más numérico)!
Utilizando la paridad
Como sugiere @BookYourLuck , podemos utilizar la paridad de funciones para restringir el espacio de polinomios en el que buscamos. Es decir, dado que $ \ text {erf} $ es una función extraña, es decir, $ f (-x) = – f (x) $ y $ \ text {tanh} $ también es una función impar, función polinomial $ \ text {pol} (x) $ dentro de $ \ text {tanh} $ también debe ser impar (solo debe tener poderes impares de $ x $ ) para tener $$ \ text {erf} (- x) \ simeq \ text {tanh} (\ text {pol} (-x)) = \ text {tanh} (- \ text {pol} (x)) = – \ text {tanh} (\ text {pol} (x)) \ simeq- \ text {erf} (x) $$
Anteriormente, tuvimos la suerte de terminar con coeficientes (casi) cero para potencias pares $ x ^ 2 $ y $ x ^ 4 $ , sin embargo, en general, esto puede generar aproximaciones de baja calidad que, por ejemplo, tienen un término como $ 0.23x ^ 2 $ que se cancela con términos adicionales (pares o impares) en lugar de simplemente optar por $ 0x ^ 2 $ .
Aproximación sigmoide
Existe una relación similar entre $ \ text {erf} (x) $ y $ 2 \ left (\ sigma (x) – \ frac {1} {2} \ right) $ (sigmoide), que se propone en el documento como otra aproximación, con error cuadrático medio $ \ sim 10 ^ {- 4} $ para $ x \ in [-10, 10] $ .
Aquí hay un código Python para generar puntos de datos, ajustar las funciones y calcular los errores cuadrados medios:
import math import numpy as np import scipy.optimize as optimize def tahn(xs, a): return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs] def sigmoid(xs, a): return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs] print_points = 0 np.random.seed(123) # xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0, # .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2] # xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8))) # xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6))) xs = np.arange(-10, 10, 0.001) erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs]) ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs]) # Fit tanh and sigmoid curves to erf points tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs) print("Tanh fit: a=%5.5f" % tuple(tanh_popt)) sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs) print("Sigmoid fit: a=%5.5f" % tuple(sig_popt)) # curves used in https://mycurvefit.com: # 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5)) # 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3)) y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs]) tanh_error_paper = (np.square(ys - y_paper_tanh)).mean() y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs]) tanh_error_alt = (np.square(ys - y_alt_tanh)).mean() # curve used in https://mycurvefit.com: # 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5) y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs]) sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean() y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs]) sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean() print("Paper tanh error:", tanh_error_paper) print("Alternative tanh error:", tanh_error_alt) print("Paper sigmoid error:", sigmoid_error_paper) print("Alternative sigmoid error:", sigmoid_error_alt) if print_points == 1: print(len(xs)) for x, erf in zip(xs, erfs): print(x, erf)
Resultado:
Tanh fit: a=0.04485 Sigmoid fit: a=1.70099 Paper tanh error: 2.4329173471294176e-08 Alternative tanh error: 2.698034519269613e-08 Paper sigmoid error: 5.6479106346814546e-05 Alternative sigmoid error: 5.704246564663601e-05
Comentarios
- ¿Por qué se necesita la aproximación? ¿No pueden ' t solo usan la función erf?
Responder
Primero tenga en cuenta que $$ \ Phi (x) = \ frac12 \ mathrm {erfc} \ left (- \ frac {x} {\ sqrt {2}} \ right) = \ frac12 \ left (1 + \ mathrm {erf} \ left (\ frac {x} {\ sqrt2} \ right) \ right) $$ por paridad de $ \ mathrm {erf} $ . Necesitamos mostrar que $$ \ mathrm {erf} \ left (\ frac x {\ sqrt2} \ right) \ approx \ tanh \ left (\ sqrt {\ frac2 \ pi} \ left (x + ax ^ 3 \ right) \ right) $$ para $ a \ approx 0.044715 $ .
Para valores grandes de $ x $ , ambas funciones están delimitadas en $ [- 1, 1 ] $ . Para $ x $ pequeños, la serie de Taylor respectiva dice $$ \ tanh (x) = x – \ frac {x ^ 3} {3} + o (x ^ 3) $$ y $$ \ mathrm {erf} (x) = \ frac {2} {\ sqrt {\ pi}} \ left (x – \ frac {x ^ 3} {3} \ right) + o (x ^ 3). $$ Sustituyendo, obtenemos que $$ \ tanh \ left (\ sqrt {\ frac2 \ pi} \ left (x + ax ^ 3 \ right) \ right) = \ sqrt \ frac {2} {\ pi} \ left (x + \ left (a – \ frac {2} {3 \ pi} \ right) x ^ 3 \ right) + o (x ^ 3) $$ y $$ \ mathrm {erf } \ left (\ frac x {\ sqrt2} \ right) = \ sqrt \ frac2 \ pi \ left (x – \ frac {x ^ 3} {6} \ right) + o (x ^ 3). $$ Al igualar el coeficiente para $ x ^ 3 $ , encontramos $$ a \ approx 0.04553992412 $$ cerca del papel «s $ 0.044715 $ .