This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit ca8f347

albheim and findmyway authored
Fix GaussianNetwork stddev and replace SACPolicyNetwork (#172)
* Switch sigma to log_sigma
* Replace SAC network with gaussian network
* Missed a logsigma spot...
* Remove unwanted prints
* Remove na from example
* Remove StructArray
* Update src/algorithms/policy_gradient/sac.jl
* Add more missed logsigma spots

Co-authored-by: Jun Tian <[email protected]>
1 parent ec06a82 commit ca8f347

5 files changed

Lines changed: 26 additions & 29 deletions


src/algorithms/policy_gradient/ppo.jl

Lines changed: 6 additions & 7 deletions
@@ -148,9 +148,8 @@ function RLBase.prob(
     if p.update_step < p.n_random_start
         @error "todo"
     else
-        p.approximator.actor(send_to_device(device(p.approximator), state)) |>
-        send_to_host |>
-        StructArray{Normal}
+        μ, logσ = p.approximator.actor(send_to_device(device(p.approximator), state)) |> send_to_host
+        StructArray{Normal}((μ, exp.(logσ)))
     end
 end
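For reference, a minimal sketch (not part of the diff) of what the reworked `prob` returns: the actor now emits `(μ, logσ)`, and `exp` is applied exactly once when the distributions are built. The values below are hypothetical stand-ins for the actor's output:

    using Distributions, StructArrays

    μ = [0.1f0, -0.3f0]                       # per-dimension action means
    logσ = [-1.0f0, 0.5f0]                    # per-dimension log standard deviations
    dists = StructArray{Normal}((μ, exp.(logσ)))
    rand.(dists)                              # one sample per action dimension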

@@ -265,13 +264,13 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory)
     gs = gradient(ps) do
         v′ = AC.critic(s) |> vec
         if AC.actor isa GaussianNetwork
-            μ, σ = AC.actor(s)
+            μ, logσ = AC.actor(s)
             if ndims(a) == 2
-                log_p′ₐ = vec(sum(normlogpdf(μ, σ, a), dims = 1))
+                log_p′ₐ = vec(sum(normlogpdf(μ, exp.(logσ), a), dims = 1))
             else
-                log_p′ₐ = normlogpdf(μ, σ, a)
+                log_p′ₐ = normlogpdf(μ, exp.(logσ), a)
             end
-            entropy_loss = mean(size(σ, 1) * (log(2.0f0π) + 1) .+ sum(log, σ; dims = 1)) / 2
+            entropy_loss = mean(size(logσ, 1) * (log(2.0f0π) + 1) .+ sum(logσ; dims = 1)) / 2
         else
             # actor is assumed to return discrete logits
             logit′ = AC.actor(s)
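The rewritten entropy term uses the closed form of a Gaussian's entropy, H(N(μ, σ)) = (log(2π) + 1)/2 + logσ, which is why `sum(logσ)` replaces `sum(log, σ)` with no exp/log round-trip. A quick sketch (assuming `Distributions` is available) checking the identity for one dimension:

    using Distributions

    logσ = -0.5f0
    closed_form = (log(2.0f0π) + 1) / 2 + logσ
    @assert closed_form ≈ entropy(Normal(0.0f0, exp(logσ)))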

src/algorithms/policy_gradient/sac.jl

Lines changed: 7 additions & 12 deletions
@@ -1,13 +1,4 @@
-export SACPolicy, SACPolicyNetwork
-
-# Define SAC Actor
-struct SACPolicyNetwork
-    pre::Chain
-    mean::Chain
-    log_std::Chain
-end
-Flux.@functor SACPolicyNetwork
-(m::SACPolicyNetwork)(state) = (x = m.pre(state); (m.mean(x), m.log_std(x)))
+export SACPolicy
 
 mutable struct SACPolicy{
     BA<:NeuralNetworkApproximator,
@@ -54,6 +45,10 @@ end
 - `update_every = 50`,
 - `step = 0`,
 - `rng = Random.GLOBAL_RNG`,
+
+`policy` is expected to output a tuple `(μ, logσ)` of means and
+log standard deviations for the desired action distributions; this
+can be implemented using a `GaussianNetwork` in a `NeuralNetworkApproximator`.
 """
 function SACPolicy(;
     policy,
@@ -117,8 +112,8 @@ end
 This function is compatible with a multidimensional action space.
 """
 function evaluate(p::SACPolicy, state)
-    μ, log_σ = p.policy(state)
-    π_dist = Normal.(μ, exp.(log_σ))
+    μ, logσ = p.policy(state)
+    π_dist = Normal.(μ, exp.(logσ))
     z = rand.(p.rng, π_dist)
     logp_π = sum(logpdf.(π_dist, z), dims = 1)
     logp_π -= sum((2.0f0 .* (log(2.0f0) .- z - softplus.(-2.0f0 * z))), dims = 1)
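The last line above is the numerically stable form of the tanh-squashing correction: 2(log 2 - z - softplus(-2z)) = log(1 - tanh(z)^2), which avoids the catastrophic cancellation of the naive form for large |z|. A standalone sketch (not from the commit) verifying the identity:

    using Flux: softplus

    z = 3.0f0
    stable = 2.0f0 * (log(2.0f0) - z - softplus(-2.0f0 * z))
    naive = log(1 - tanh(z)^2)
    @assert stable ≈ naive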

src/algorithms/policy_gradient/vpg.jl

Lines changed: 6 additions & 4 deletions
@@ -1,21 +1,23 @@
 export VPGPolicy, GaussianNetwork
 
 """
-    GaussianNetwork(;pre=identity, μ, σ)
+    GaussianNetwork(;pre=identity, μ, logσ)
 
-`σ` should return the log of std, `exp` will be applied to it automatically.
+Returns `μ` and `logσ` when called.
+Create a distribution to sample from
+using `Normal.(μ, exp.(logσ))`.
 """
 Base.@kwdef struct GaussianNetwork{P,U,S}
     pre::P = identity
     μ::U
-    σ::S
+    logσ::S
 end
 
 Flux.@functor GaussianNetwork
 
 function (m::GaussianNetwork)(S)
     x = m.pre(S)
-    m.μ(x), m.σ(x) .|> exp
+    m.μ(x), m.logσ(x)
 end
 
 """

src/experiments/rl_envs/JuliaRL_PPO_Pendulum.jl

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ function RLCore.Experiment(
                 Dense(64, 64, relu; initW = glorot_uniform(rng)),
             ),
             μ = Chain(Dense(64, 1, tanh; initW = glorot_uniform(rng)), vec),
-            σ = Chain(Dense(64, 1; initW = glorot_uniform(rng)), vec),
+            logσ = Chain(Dense(64, 1; initW = glorot_uniform(rng)), vec),
         ),
         critic = Chain(
             Dense(ns, 64, relu; initW = glorot_uniform(rng)),

src/experiments/rl_envs/JuliaRL_SAC_Pendulum.jl

Lines changed: 6 additions & 5 deletions
@@ -26,12 +26,13 @@ function RLCore.Experiment(
     init = glorot_uniform(rng)
 
     create_policy_net() = NeuralNetworkApproximator(
-        model = SACPolicyNetwork(
-            Chain(Dense(ns, 30, relu), Dense(30, 30, relu)),
-            Chain(Dense(30, 1, initW = init)),
-            Chain(
-                Dense(30, 1, x -> clamp(x, typeof(x)(-2), typeof(x)(2)), initW = init),
+        model = GaussianNetwork(
+            pre = Chain(
+                Dense(ns, 30, relu),
+                Dense(30, 30, relu),
             ),
+            μ = Chain(Dense(30, 1, initW = init)),
+            logσ = Chain(Dense(30, 1, x -> clamp.(x, typeof(x)(-10), typeof(x)(2)), initW = init)),
         ),
         optimizer = ADAM(0.003),
     )
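Note the clamp on the new `logσ` head: bounding logσ to [-10, 2] keeps σ within roughly [4.5e-5, 7.4], a common trick in SAC implementations to stop the policy's standard deviation from collapsing to zero or blowing up early in training. The removed `SACPolicyNetwork` clamped its log-std head to the narrower [-2, 2] instead.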

0 commit comments
