This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 022c1fd

albheim and findmyway authored
SAC multidimensional actions (#173)
* Switch sigma to log_sigma
* Replace SAC network with gaussian network
* Missed a logsigma spot...
* Remove unwanted prints
* Remove na from example
* Seems to be running multidim actions
* Remove StructArray
* Cleanup
* Cleanup
* Update src/algorithms/policy_gradient/sac.jl
  Co-authored-by: Jun Tian <[email protected]>
* Add more missed logsigma spots
* Undo erroneous change
* Track and log reward and entropy terms
* Add link to paper

Co-authored-by: Jun Tian <[email protected]>
1 parent 4a2417b commit 022c1fd

2 files changed

Lines changed: 33 additions & 16 deletions

File tree

src/algorithms/policy_gradient/sac.jl

Lines changed: 17 additions & 8 deletions
@@ -23,6 +23,9 @@ mutable struct SACPolicy{
     update_every::Int
     step::Int
     rng::R
+    # Logging
+    reward_term::Float32
+    entropy_term::Float32
 end

 """
@@ -49,6 +52,8 @@ end
 `policy` is expected to output a tuple `(μ, logσ)` of mean and
 log standard deviations for the desired action distributions, this
 can be implemented using a `GaussianNetwork` in a `NeuralNetworkApproximator`.
+
+Implemented based on http://arxiv.org/abs/1812.05905
 """
 function SACPolicy(;
     policy,
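
Note (not part of this commit): the `(μ, logσ)` interface described above corresponds to the squashed-Gaussian policy of the linked paper. Below is a minimal, self-contained sketch of how such a policy can sample an action and its log-probability with the reparameterization trick; the library's actual `GaussianNetwork`/`evaluate` may differ in details, and the helper name `sample_squashed` is hypothetical.

using Random

# Sketch only: reparameterized sampling for a squashed Gaussian policy
# (http://arxiv.org/abs/1812.05905). `model(s)` is assumed to return a
# tuple (μ, logσ), each of size (na, batch).
function sample_squashed(model, s; rng = Random.default_rng())
    μ, logσ = model(s)
    σ = exp.(logσ)
    ε = randn(rng, Float32, size(μ))
    u = μ .+ σ .* ε                       # reparameterization trick
    a = tanh.(u)                          # squash each action dimension into (-1, 1)
    # log π(a|s): diagonal Gaussian log-density plus the tanh change-of-variables term
    logπ = sum(-0.5f0 .* ((u .- μ) ./ σ) .^ 2 .- logσ .- 0.5f0 * log(2f0 * π), dims = 1) .-
           sum(log.(1f0 .- a .^ 2 .+ 1f-6), dims = 1)
    return a, logπ
end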
@@ -85,6 +90,8 @@ function SACPolicy(;
         update_every,
         step,
         rng,
+        0f0,
+        0f0,
     )
 end

@@ -99,12 +106,12 @@ function (p::SACPolicy)(env)
         s = state(env)
         s = Flux.unsqueeze(s, ndims(s) + 1)
         # trainmode:
-        action = evaluate(p, s)[1][] # returns action as scalar
+        action = dropdims(evaluate(p, s)[1], dims=2) # Single action vec, drop second dim

         # testmode:
         # if testing dont sample an action, but act deterministically by
         # taking the "mean" action
-        # action = p.policy(s)[1][] # returns action as scalar
+        # action = dropdims(p.policy(s)[1], dims=2)
     end
 end
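
Note on the shape change above (not part of the diff): with multidimensional actions, `evaluate(p, s)[1]` returns a matrix of size (na, batch) rather than a scalar, so the singleton batch dimension is dropped to hand the environment a plain action vector. A small illustrative sketch with made-up sizes:

using Flux

ns, na = 3, 2                             # hypothetical state/action dimensions
s = rand(Float32, ns)                     # raw state, size (ns,)
s = Flux.unsqueeze(s, ndims(s) + 1)       # size (ns, 1): a batch containing one state
a_batched = rand(Float32, na, 1)          # stand-in for evaluate(p, s)[1], size (na, 1)
a = dropdims(a_batched, dims = 2)         # size (na,): the action vector passed to the env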

@@ -137,17 +144,13 @@ function RLBase.update!(p::SACPolicy, batch::NamedTuple{SARTS})

     γ, ρ, α = p.γ, p.ρ, p.α

-    # !!! we have several assumptions here, need revisit when we have more complex environments
-    # state is vector
-    # action is scalar
     a′, log_π = evaluate(p, s′)
     q′_input = vcat(s′, a′)
     q′ = min.(p.target_qnetwork1(q′_input), p.target_qnetwork2(q′_input))

-    y = r .+ γ .* (1 .- t) .* vec((q′ .- α .* log_π))
+    y = r .+ γ .* (1 .- t) .* vec(q′ .- α .* log_π)

     # Train Q Networks
-    a = Flux.unsqueeze(a, 1)
     q_input = vcat(s, a)

     q_grad_1 = gradient(Flux.params(p.qnetwork1)) do
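
Note (not part of the diff): in the target computation above, `q′` and `log_π` come out of the networks as 1×batch matrices while `r` and `t` are length-batch vectors, so `vec` flattens the bracketed term before broadcasting. A shape sketch with dummy data:

# Illustrative only: the soft Bellman target y = r + γ(1 - t)(min Q′ - α⋅logπ)
γ, α = 0.99f0, 0.2f0
batch = 4
r = rand(Float32, batch)              # rewards, size (batch,)
t = zeros(Float32, batch)             # terminal flags, size (batch,)
q′ = rand(Float32, 1, batch)          # min of the two target critics, size (1, batch)
log_π = rand(Float32, 1, batch)       # log-probabilities from evaluate, size (1, batch)
y = r .+ γ .* (1 .- t) .* vec(q′ .- α .* log_π)   # size (batch,)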
@@ -166,7 +169,13 @@ function RLBase.update!(p::SACPolicy, batch::NamedTuple{SARTS})
         a, log_π = evaluate(p, s)
         q_input = vcat(s, a)
         q = min.(p.qnetwork1(q_input), p.qnetwork2(q_input))
-        mean(α .* log_π .- q)
+        reward = mean(q)
+        entropy = mean(log_π)
+        ignore() do
+            p.reward_term = reward
+            p.entropy_term = entropy
+        end
+        α * entropy - reward
     end
     update!(p.policy, p_grad)
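
Note (not part of the diff): the `ignore() do ... end` block above stashes the two terms on the policy purely for logging; with Zygote (Flux's AD backend) the enclosed code runs but is excluded from differentiation, so the returned loss is still `α * entropy - reward`. A minimal standalone sketch of the same pattern:

using Zygote

logged = Ref(0.0)
grad = Zygote.gradient(3.0) do x
    loss = x^2
    Zygote.ignore() do
        logged[] = loss       # side effect for logging; not differentiated
    end
    loss
end
# grad == (6.0,) and logged[] == 9.0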

src/experiments/rl_envs/JuliaRL_SAC_Pendulum.jl

Lines changed: 16 additions & 8 deletions
@@ -18,10 +18,11 @@ function RLCore.Experiment(
     low = A.left
     high = A.right
     ns = length(state(inner_env))
+    na = 1

     env = ActionTransformedEnv(
         inner_env;
-        action_mapping = x -> low + (x + 1) * 0.5 * (high - low),
+        action_mapping = x -> low + (x[1] + 1) * 0.5 * (high - low),
     )
     init = glorot_uniform(rng)
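
Note (not part of the diff): the policy now emits a length-1 action vector, so the mapping indexes `x[1]` before rescaling the squashed action from (-1, 1) to the environment's torque range [low, high] (roughly [-2, 2] for Pendulum). A quick check of the affine map, with illustrative bounds:

low, high = -2.0, 2.0    # illustrative; actual values come from the env's action space
action_mapping = x -> low + (x[1] + 1) * 0.5 * (high - low)
action_mapping([-1.0])   # -2.0, the lower bound
action_mapping([0.0])    #  0.0, the midpoint
action_mapping([1.0])    #  2.0, the upper bound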

@@ -31,15 +32,15 @@ function RLCore.Experiment(
                 Dense(ns, 30, relu),
                 Dense(30, 30, relu),
             ),
-            μ = Chain(Dense(30, 1, initW = init)),
-            logσ = Chain(Dense(30, 1, x -> clamp.(x, typeof(x)(-10), typeof(x)(2)), initW = init)),
+            μ = Chain(Dense(30, na, initW = init)),
+            logσ = Chain(Dense(30, na, x -> clamp.(x, typeof(x)(-10), typeof(x)(2)), initW = init)),
         ),
         optimizer = ADAM(0.003),
     )

     create_q_net() = NeuralNetworkApproximator(
         model = Chain(
-            Dense(ns + 1, 30, relu; initW = init),
+            Dense(ns + na, 30, relu; initW = init),
             Dense(30, 30, relu; initW = init),
             Dense(30, 1; initW = init),
         ),
@@ -58,15 +59,15 @@ function RLCore.Experiment(
             α = 0.2f0,
             batch_size = 64,
             start_steps = 1000,
-            start_policy = RandomPolicy(-1.0..1.0; rng = rng),
+            start_policy = RandomPolicy(Space([-1.0..1.0 for _ in 1:na]); rng = rng),
             update_after = 1000,
             update_every = 1,
             rng = rng,
         ),
         trajectory = CircularArraySARTTrajectory(
             capacity = 10000,
             state = Vector{Float32} => (ns,),
-            action = Float32 => (),
+            action = Vector{Float32} => (na,),
         ),
     )
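
Note (not part of the diff): the warm-up policy now samples from a `Space` of `na` intervals instead of a single interval, and the trajectory stores length-`na` action vectors, keeping the replay layout consistent with the new policy output. A hedged sketch of the idea, assuming the ReinforcementLearning.jl API of this period (the value of `na` here is made up):

using ReinforcementLearning, IntervalSets   # Space, RandomPolicy, and the .. interval syntax

na = 2                                                    # hypothetical action dimension
start_policy = RandomPolicy(Space([-1.0..1.0 for _ in 1:na]))
# Calling start_policy on an env draws one value per interval, i.e. a length-na
# vector in [-1, 1]^na, matching `action = Vector{Float32} => (na,)` in the trajectory.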

@@ -76,9 +77,16 @@ function RLCore.Experiment(
     hook = ComposedHook(
         total_reward_per_episode,
         time_per_step,
-        DoEveryNEpisode() do t, agent, env
+        DoEveryNStep() do t, agent, env
             with_logger(lg) do
-                @info "training" reward = total_reward_per_episode.rewards[end]
+                @info(
+                    "training",
+                    reward_term = agent.policy.reward_term,
+                    entropy_term = agent.policy.entropy_term,
+                )
+                if is_terminated(env)
+                    @info "training" reward = total_reward_per_episode.reward log_step_increment = 0
+                end
             end
         end,
     )
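
Note (not part of the diff): the hook now logs the reward and entropy terms every environment step and the episode return only when the episode terminates; `log_step_increment = 0` keeps that extra record on the same TensorBoard step. A minimal sketch of the pattern, assuming `lg` is a `TensorBoardLogger.TBLogger` as set up earlier in this experiment file (the log directory below is hypothetical):

using TensorBoardLogger, Logging

lg = TBLogger("tensorboard_logs/sac_demo")
with_logger(lg) do
    @info "training" reward_term = 1.0 entropy_term = -0.5      # advances the step counter
    @info "training" reward = 123.4 log_step_increment = 0      # logged at the same step
end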

0 commit comments
