Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 2f28cbc

Browse files
albheim and findmyway authored
Fix bug in multi action ppo (#169)
* Remove dimension in log_pa, fix entropy for multi
* Update src/algorithms/policy_gradient/ppo.jl

Co-authored-by: Jun Tian <[email protected]>
Co-authored-by: Jun Tian <[email protected]>
1 parent bc64e42 commit 2f28cbc

1 file changed

Lines changed: 2 additions & 3 deletions

File tree

  • src/algorithms/policy_gradient

src/algorithms/policy_gradient/ppo.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,11 +267,11 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory)
267267
if AC.actor isa GaussianNetwork
268268
μ, σ = AC.actor(s)
269269
if ndims(a) == 2
270-
log_p′ₐ = sum(normlogpdf(μ, σ, a), dims = 1)
270+
log_p′ₐ = vec(sum(normlogpdf(μ, σ, a), dims = 1))
271271
else
272272
log_p′ₐ = normlogpdf(μ, σ, a)
273273
end
274-
entropy_loss = mean((log(2.0f0π) + 1) / 2 .+ sum(log.(σ), dims = 1))
274+
entropy_loss = mean(size(σ, 1) * (log(2.0f0π) + 1) .+ sum(log, σ; dims = 1)) / 2
275275
else
276276
# actor is assumed to return discrete logits
277277
logit′ = AC.actor(s)
@@ -280,7 +280,6 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory)
280280
log_p′ₐ = log_p′[CartesianIndex.(a, 1:length(a))]
281281
entropy_loss = -sum(p′ .* log_p′) * 1 // size(p′, 2)
282282
end
283-
284283
ratio = exp.(log_p′ₐ .- log_p)
285284
surr1 = ratio .* adv
286285
surr2 = clamp.(ratio, 1.0f0 - clip_range, 1.0f0 + clip_range) .* adv

0 commit comments

Comments (0)