
Commit 4a2417b

Add dueling network (#171)
* Add dueling network
* Add docs
* Some adjustment
1 parent ca8f347 commit 4a2417b

4 files changed: 40 additions & 11 deletions


src/algorithms/dqns/common.jl

Lines changed: 19 additions & 0 deletions
@@ -35,3 +35,22 @@ function RLBase.update!(
     push!(trajectory[:terminal], is_terminated(env))
     push!(trajectory[:priority], p.learner.default_priority)
 end
+
+"""
+    DuelingNetwork(base, val, adv)
+
+A dueling network produces separate estimates of the state value function and the advantage function; the expected output size of `val` is 1, and that of `adv` is the size of the action space.
+"""
+struct DuelingNetwork{B,V,A}
+    base::B
+    val::V
+    adv::A
+end
+
+Flux.@functor DuelingNetwork
+
+function (m::DuelingNetwork)(state)
+    x = m.base(state)
+    val = m.val(x)
+    return val .+ m.adv(x) .- mean(m.adv(x), dims=1)
+end
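
For illustration, here is a minimal sketch (not part of the commit) of how the new struct composes. The layer sizes are hypothetical, and `mean` requires `Statistics` to be in scope, as it is in the package module:

    using Flux, Statistics

    # Hypothetical sizes: a 4-dimensional state and 2 actions.
    m = DuelingNetwork(
        Dense(4, 16, relu), # base: shared feature extractor
        Dense(16, 1),       # val: scalar state-value head
        Dense(16, 2),       # adv: one advantage estimate per action
    )

    q = m(rand(Float32, 4, 32)) # batch of 32 states
    size(q)                     # (2, 32): one Q-value per action per state

Subtracting the mean advantage keeps the decomposition identifiable: adding a constant to `adv` and subtracting it from `val` would otherwise leave the Q-values unchanged.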

src/experiments/rl_envs/JuliaRL_DQN_CartPole.jl

Lines changed: 7 additions & 10 deletions
@@ -16,24 +16,21 @@ function RLCore.Experiment(
 
     env = CartPoleEnv(; T = Float32, rng = rng)
     ns, na = length(state(env)), length(action_space(env))
+    base_model = Chain(
+        Dense(ns, 128, relu; initW = glorot_uniform(rng)),
+        Dense(128, 128, relu; initW = glorot_uniform(rng)),
+        Dense(128, na; initW = glorot_uniform(rng))
+    )
 
     agent = Agent(
         policy = QBasedPolicy(
             learner = DQNLearner(
                 approximator = NeuralNetworkApproximator(
-                    model = Chain(
-                        Dense(ns, 128, relu; initW = glorot_uniform(rng)),
-                        Dense(128, 128, relu; initW = glorot_uniform(rng)),
-                        Dense(128, na; initW = glorot_uniform(rng)),
-                    ) |> cpu,
+                    model = build_dueling_network(base_model) |> cpu,
                     optimizer = ADAM(),
                 ),
                 target_approximator = NeuralNetworkApproximator(
-                    model = Chain(
-                        Dense(ns, 128, relu; initW = glorot_uniform(rng)),
-                        Dense(128, 128, relu; initW = glorot_uniform(rng)),
-                        Dense(128, na; initW = glorot_uniform(rng)),
-                    ) |> cpu,
+                    model = build_dueling_network(base_model) |> cpu,
                 ),
                 loss_func = huber_loss,
                 stack_size = nothing,
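
One subtlety: both the online and the target approximator are built from the same `base_model`, but `build_dueling_network` (added in `rl_envs.jl` below) `deepcopy`s every layer it uses, so the two networks start from identical weights without sharing parameters. A quick check, under the definitions above:

    online = build_dueling_network(base_model)
    target = build_dueling_network(base_model)
    online.base[1] === target.base[1] # false: the layers are independent copies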

src/experiments/rl_envs/JuliaRL_REMDQN_CartPole.jl

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ function RLCore.Experiment(
 
     env = CartPoleEnv(; T = Float32, rng = rng)
     ns, na = length(state(env)), length(action_space(env))
-    ensemble_num = 6
+    ensemble_num = 16
 
     agent = Agent(
         policy = QBasedPolicy(

src/experiments/rl_envs/rl_envs.jl

Lines changed: 13 additions & 0 deletions
@@ -5,3 +5,16 @@ for f in readdir(@__DIR__)
         include(f)
     end
 end
+
+# Build Dueling Network
+function build_dueling_network(network::Chain)
+    lm = length(network)
+    if !(network[lm] isa Dense) || !(network[lm-1] isa Dense)
+        error("The Q-network provided is incompatible with dueling.")
+    end
+    base = Chain([deepcopy(network[i]) for i = 1:lm-2]...)
+    last_layer_dims = size(network[lm].W, 2)
+    val = Chain(deepcopy(network[lm-1]), Dense(last_layer_dims, 1))
+    adv = Chain([deepcopy(network[i]) for i = lm-1:lm]...)
+    return DuelingNetwork(base, val, adv)
+end
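
As a sketch of what the helper produces, assuming sizes matching the CartPole experiment (`ns = 4`, `na = 2`) and the Flux version used here (where a `Dense` layer exposes its weight matrix as `W`):

    using Flux

    q_net = Chain(
        Dense(4, 128, relu),   # becomes `base`
        Dense(128, 128, relu), # copied into both `val` and `adv`
        Dense(128, 2),         # original Q head, kept only in `adv`
    )

    dn = build_dueling_network(q_net)
    # `val` re-heads the copied hidden layer with Dense(128, 1), while `adv`
    # keeps the original head, so the output is still one Q-value per action.
    size(dn(rand(Float32, 4, 8))) # (2, 8)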
