Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit d78f327

Browse files
authored
Fix #251, ppo multidim action eval (#177)
1 parent 52a9c85 commit d78f327

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

  • src/algorithms/policy_gradient

src/algorithms/policy_gradient/ppo.jl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -173,11 +173,11 @@ RLBase.prob(p::PPOPolicy, env::MultiThreadEnv) = prob(p, state(env))
173173
function RLBase.prob(p::PPOPolicy, env::AbstractEnv)
174174
s = state(env)
175175
s = Flux.unsqueeze(s, ndims(s) + 1)
176-
prob(p, s)[1]
176+
prob(p, s)
177177
end
178178

179179
(p::PPOPolicy)(env::MultiThreadEnv) = rand.(p.rng, prob(p, env))
180-
(p::PPOPolicy)(env::AbstractEnv) = rand(p.rng, prob(p, env))
180+
(p::PPOPolicy)(env::AbstractEnv) = rand.(p.rng, prob(p, env))
181181

182182
function (agent::Agent{<:PPOPolicy})(env::MultiThreadEnv)
183183
dist = prob(agent.policy, env)

0 commit comments

Comments (0)