MuZero 종이 의사 코드 에는 다음과 같은 코드 줄이 있습니다.
hidden_state = tf.scale_gradient(hidden_state, 0.5)
이 기능은 무엇입니까? 왜 거기에 있습니까?
tf.scale_gradient
를 검색했는데 tensorflow에 존재하지 않습니다. 그리고 scalar_loss
와는 달리 자체 코드에서 정의하지 않은 것 같습니다.
컨텍스트를 위해 전체 함수는 다음과 같습니다.
def update_weights(optimizer: tf.train.Optimizer, network: Network, batch, weight_decay: float): loss = 0 for image, actions, targets in batch: # Initial step, from the real observation. value, reward, policy_logits, hidden_state = network.initial_inference( image) predictions = [(1.0, value, reward, policy_logits)] # Recurrent steps, from action and previous hidden state. for action in actions: value, reward, policy_logits, hidden_state = network.recurrent_inference( hidden_state, action) predictions.append((1.0 / len(actions), value, reward, policy_logits)) # THIS LINE HERE hidden_state = tf.scale_gradient(hidden_state, 0.5) for prediction, target in zip(predictions, targets): gradient_scale, value, reward, policy_logits = prediction target_value, target_reward, target_policy = target l = ( scalar_loss(value, target_value) + scalar_loss(reward, target_reward) + tf.nn.softmax_cross_entropy_with_logits( logits=policy_logits, labels=target_policy)) # AND AGAIN HERE loss += tf.scale_gradient(l, gradient_scale) for weights in network.get_weights(): loss += weight_decay * tf.nn.l2_loss(weights) optimizer.minimize(loss)
그라디언트 크기 조정은 어떤 역할을하며, 그 이유는 무엇입니까?
댓글
답변
여기에있는 논문의 저자-이것이 TensorFlow 함수가 아니라 Sonnet의 scale_gradient <와 동일하다는 사실을 놓쳤습니다. / a> 또는 다음 함수 :
def scale_gradient(tensor, scale): """Scales the gradient for the backward pass.""" return tensor * scale + tf.stop_gradient(tensor) * (1 - scale)
댓글
- 답장 해 주셔서 감사합니다! stackoverflow.com/q/60234530 (다른 MuZero 질문)을 살펴 보시면 대단히 감사하겠습니다.
답변
Pseude 코드를 알고 계십니까? (TF 2.0이 아니기 때문에) 그라디언트 클리핑 또는 일괄 정규화 를 사용합니다 ( “활성화 기능의 크기 조정”)
설명
- 제공 한 링크를 보면 그래디언트 표준 크기 조정이 될 가능성이 높습니다. 최적화 프로그램에서
clipnorm
매개 변수를 설정합니다. 그러나 코드에서는 매번 다른 값으로 코드에서 두 번 그래디언트 스케일링을 사용합니다.clipnorm
매개 변수로는이 작업을 수행 할 수 없습니다. 내가 어떻게 할 수 있는지 아시나요? - 또한 모델의 숨겨진 상태는 ' 잘 려야하는 것처럼 보이지 않습니다. (나는 ' 클리핑이 왜 도움이되는지 이해하지 못합니다.) 그래디언트 클리핑이 어떤 작업을 수행하는지 설명하면 귀하의 답변이 정확하다는 것을 확신하는 데 매우 도움이 될 것입니다.
tf.scale_gradient()
를 검색했습니다. 결과에서 알 수 있듯이 아무것도 나오지 않습니다. 지금은 폐기 된 이전 TF 버전의 함수 여야합니다. 확실히, ' TF 2.0에서는 더 이상 사용할 수 없습니다.