MuZero紙の擬似コードには、次のコード行があります。
hidden_state = tf.scale_gradient(hidden_state, 0.5)
これは何をしますか?なぜそこにあるのですか?
tf.scale_gradient
を検索しましたが、テンソルフローに存在しません。また、scalar_loss
とは異なり、独自のコードで定義されていないようです。
コンテキストとして、関数全体を次に示します。
def update_weights(optimizer: tf.train.Optimizer, network: Network, batch, weight_decay: float): loss = 0 for image, actions, targets in batch: # Initial step, from the real observation. value, reward, policy_logits, hidden_state = network.initial_inference( image) predictions = [(1.0, value, reward, policy_logits)] # Recurrent steps, from action and previous hidden state. for action in actions: value, reward, policy_logits, hidden_state = network.recurrent_inference( hidden_state, action) predictions.append((1.0 / len(actions), value, reward, policy_logits)) # THIS LINE HERE hidden_state = tf.scale_gradient(hidden_state, 0.5) for prediction, target in zip(predictions, targets): gradient_scale, value, reward, policy_logits = prediction target_value, target_reward, target_policy = target l = ( scalar_loss(value, target_value) + scalar_loss(reward, target_reward) + tf.nn.softmax_cross_entropy_with_logits( logits=policy_logits, labels=target_policy)) # AND AGAIN HERE loss += tf.scale_gradient(l, gradient_scale) for weights in network.get_weights(): loss += weight_decay * tf.nn.l2_loss(weights) optimizer.minimize(loss)
グラデーションのスケーリングは何をし、なぜそこで行うのですか?
コメント
回答
ここでの論文の著者-これは明らかにTensorFlow関数ではなく、Sonnetの scale_gradient <と同等であることがわかりませんでした。 / a>、または次の関数:
def scale_gradient(tensor, scale): """Scales the gradient for the backward pass.""" return tensor * scale + tf.stop_gradient(tensor) * (1 - scale)
コメント
- 返信ありがとうございます! stackoverflow.com/q/60234530 (別のMuZeroの質問)をご覧になりたい場合は、よろしくお願いします。
回答
その疑似コードを考えますか? (TF 2.0にはないため)グラデーションクリッピングまたはバッチ正規化を使用します( 「活性化関数のスケーリング」)
コメント
- 指定したリンクから、これは勾配ノルムスケーリングである可能性が高いようです。オプティマイザーで
clipnorm
パラメーターを設定します。ただし、コードでは、毎回異なる値を持つ勾配スケーリングをコードで2回使用します。clipnorm
パラメーターではこれを実行できません。どうすればよいかわかりますか? - また、モデルの非表示状態は'クリップする必要があるようには見えません。 ('クリッピングが役立つ理由がわかりません。)グラデーションクリッピングがどのように機能するかを説明することは、あなたの答えが正しいことを確認するのに非常に役立ちます。
tf.scale_gradient()
を検索しました。 結果が示すように、何も表示されません。これは、現在は廃止されている古いTFバージョンの関数である必要があります。確かに、'はTF2.0では利用できなくなりました。