This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit b2a27f3

QRDQN implementation (#176)

Authored by Mobius1D, Prasidh Srikumar, and findmyway
* QRDQN implementation: initial implementation with a CartPole experiment and a few bugs.
* Fix TD errors: fix the mistake in TD_error.
* Working non-optimal QRDQN: has bugs but runs the experiment.
* Fixed a few errors: corrected errors in the huber loss and raw loss.
* Fixed tau.
* Fix notations and typos.
* Fix state used in calculation of quantiles.
* Fixed a few bugs: made quantile_huber_loss a separate function, changed reshaping, changed ensemble_num to quantile_num.
* Fixed a few issues.
* Fixed issues.
* Fix qrdqn.

Co-authored-by: Prasidh Srikumar <prsdhsk@gmail.com>
Co-authored-by: Jun Tian <[email protected]>
1 parent 022c1fd · commit b2a27f3

5 files changed: 206 additions & 2 deletions


src/algorithms/dqns/common.jl

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@

 const PERLearners = Union{PrioritizedDQNLearner,RainbowLearner,IQNLearner}

-function RLBase.update!(learner::Union{DQNLearner,REMDQNLearner,PERLearners}, t::AbstractTrajectory)
+function RLBase.update!(learner::Union{DQNLearner,QRDQNLearner,REMDQNLearner,PERLearners}, t::AbstractTrajectory)
     length(t[:terminal]) - learner.sampler.n <= learner.min_replay_history && return
     learner.update_step += 1
src/algorithms/dqns/dqns.jl

Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 include("basic_dqn.jl")
 include("dqn.jl")
 include("prioritized_dqn.jl")
+include("qr_dqn.jl")
 include("rem_dqn.jl")
 include("rainbow.jl")
 include("iqn.jl")

src/algorithms/dqns/qr_dqn.jl

Lines changed: 138 additions & 0 deletions

@@ -0,0 +1,138 @@
export QRDQNLearner, quantile_huber_loss

function quantile_huber_loss(ŷ, y; κ = 1.0f0)
    N, B = size(y)
    # pairwise TD errors between target and predicted quantiles, size (N, N, B)
    Δ = reshape(y, N, 1, B) .- reshape(ŷ, 1, N, B)
    abs_error = abs.(Δ)
    quadratic = min.(abs_error, κ)
    linear = abs_error .- quadratic
    huber_loss = 0.5f0 .* quadratic .* quadratic .+ κ .* linear

    # midpoints of the N quantile fractions
    cum_prob = send_to_device(device(y), range(0.5f0 / N; length = N, step = 1.0f0 / N))
    loss = Zygote.dropgrad(abs.(cum_prob .- (Δ .< 0))) .* huber_loss
    mean(sum(loss; dims = 1))
end

mutable struct QRDQNLearner{Tq<:AbstractApproximator,Tt<:AbstractApproximator,Tf,R} <: AbstractLearner
    approximator::Tq
    target_approximator::Tt
    min_replay_history::Int
    update_freq::Int
    update_step::Int
    target_update_freq::Int
    sampler::NStepBatchSampler
    n_quantile::Int
    loss_func::Tf
    rng::R
    # for recording
    loss::Float32
end

"""
    QRDQNLearner(; kwargs...)

See paper: [Distributional Reinforcement Learning with Quantile Regression](https://arxiv.org/pdf/1710.10044.pdf)

# Keywords

- `approximator`::[`AbstractApproximator`](@ref): used to get the quantile values of a batch of states. The output should be of size `(n_quantile, n_action)` per state.
- `target_approximator`::[`AbstractApproximator`](@ref): similar to `approximator`, but used to estimate the quantile values of the next state batch.
- `γ::Float32=0.99f0`: discount rate.
- `batch_size::Int=32`
- `update_horizon::Int=1`: length of the update ('n' in the n-step update).
- `min_replay_history::Int=32`: number of transitions that should be experienced before updating the `approximator`.
- `update_freq::Int=1`: the frequency of updating the `approximator`.
- `n_quantile::Int=1`: the number of quantiles.
- `target_update_freq::Int=100`: the frequency of syncing the `target_approximator`.
- `stack_size::Union{Int,Nothing}=4`: use the recent `stack_size` frames to form a stacked state.
- `traces=SARTS`: set to `SLARTSL` if you apply it to an environment of `FULL_ACTION_SET`.
- `loss_func=`[`quantile_huber_loss`](@ref).
"""
function QRDQNLearner(;
    approximator,
    target_approximator,
    stack_size::Union{Int,Nothing} = nothing,
    γ::Float32 = 0.99f0,
    batch_size::Int = 32,
    update_horizon::Int = 1,
    min_replay_history::Int = 32,
    update_freq::Int = 1,
    n_quantile::Int = 1,
    target_update_freq::Int = 100,
    traces = SARTS,
    update_step = 0,
    loss_func = quantile_huber_loss,
    rng = Random.GLOBAL_RNG,
)
    copyto!(approximator, target_approximator)  # force sync
    sampler = NStepBatchSampler{traces}(;
        γ = γ,
        n = update_horizon,
        stack_size = stack_size,
        batch_size = batch_size,
    )

    N = n_quantile

    QRDQNLearner(
        approximator,
        target_approximator,
        min_replay_history,
        update_freq,
        update_step,
        target_update_freq,
        sampler,
        N,
        loss_func,
        rng,
        0.0f0,
    )
end

Flux.functor(x::QRDQNLearner) = (Q = x.approximator, Qₜ = x.target_approximator),
y -> begin
    x = @set x.approximator = y.Q
    x = @set x.target_approximator = y.Qₜ
    x
end

function (learner::QRDQNLearner)(env)
    s = send_to_device(device(learner.approximator), state(env))
    s = Flux.unsqueeze(s, ndims(s) + 1)
    q = reshape(learner.approximator(s), learner.n_quantile, :)
    # act greedily w.r.t. the mean over quantiles
    vec(mean(q, dims = 1)) |> send_to_host
end

function RLBase.update!(learner::QRDQNLearner, batch::NamedTuple)
    Q = learner.approximator
    Qₜ = learner.target_approximator
    γ = learner.sampler.γ
    n = learner.sampler.n
    batch_size = learner.sampler.batch_size
    N = learner.n_quantile
    D = device(Q)
    loss_func = learner.loss_func

    s, a, r, t, s′ = (send_to_device(D, batch[x]) for x in SARTS)
    a = CartesianIndex.(a, 1:batch_size)

    # pick the target quantiles of the greedy next action
    target_quantiles = reshape(Qₜ(s′), N, :, batch_size)
    qₜ = dropdims(mean(target_quantiles; dims = 1); dims = 1)
    aₜ = dropdims(argmax(qₜ, dims = 1); dims = 1)
    @views target_quantile_aₜ = target_quantiles[:, aₜ]
    # n-step distributional Bellman target
    y = reshape(r, 1, batch_size) .+ γ^n .* reshape(1 .- t, 1, batch_size) .* target_quantile_aₜ

    gs = gradient(params(Q)) do
        q = reshape(Q(s), N, :, batch_size)
        @views ŷ = q[:, a]

        loss = loss_func(ŷ, y)

        ignore() do
            learner.loss = loss
        end
        loss
    end

    update!(Q, gs)
end
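For intuition, here is a minimal CPU-only sketch of the quantile Huber loss above, with the Zygote and device plumbing stripped out. The name `quantile_huber_loss_cpu` and the toy sizes (N = 3 quantiles, batch size B = 2) are illustrative only, not part of this commit.

using Statistics: mean

# Mirrors quantile_huber_loss in qr_dqn.jl, minus send_to_device/dropgrad.
function quantile_huber_loss_cpu(ŷ, y; κ = 1.0f0)
    N, B = size(y)
    Δ = reshape(y, N, 1, B) .- reshape(ŷ, 1, N, B)  # pairwise TD errors, size (N, N, B)
    abs_error = abs.(Δ)
    quadratic = min.(abs_error, κ)
    linear = abs_error .- quadratic
    huber = 0.5f0 .* quadratic .* quadratic .+ κ .* linear
    τ̂ = range(0.5f0 / N; length = N, step = 1.0f0 / N)  # quantile midpoints
    loss = abs.(τ̂ .- (Δ .< 0)) .* huber  # τ̂ broadcasts along the first axis
    mean(sum(loss; dims = 1))
end

ŷ = rand(Float32, 3, 2)  # predicted quantiles, (n_quantile, batch_size)
y = rand(Float32, 3, 2)  # target quantiles, same shape
quantile_huber_loss_cpu(ŷ, y)  # scalar Float32 loss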
Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
function RLCore.Experiment(
    ::Val{:JuliaRL},
    ::Val{:QRDQN},
    ::Val{:CartPole},
    ::Nothing;
    save_dir = nothing,
    seed = 123,
)
    N = 10

    rng = StableRNG(seed)
    env = CartPoleEnv(; T = Float32, rng = rng)
    ns, na = length(state(env)), length(action_space(env))

    agent = Agent(
        policy = QBasedPolicy(
            learner = QRDQNLearner(
                approximator = NeuralNetworkApproximator(
                    model = Chain(
                        Dense(ns, 128, relu; initW = glorot_uniform(rng)),
                        Dense(128, 128, relu; initW = glorot_uniform(rng)),
                        Dense(128, N * na; initW = glorot_uniform(rng)),
                    ) |> cpu,
                    optimizer = ADAM(),
                ),
                target_approximator = NeuralNetworkApproximator(
                    model = Chain(
                        Dense(ns, 128, relu; initW = glorot_uniform(rng)),
                        Dense(128, 128, relu; initW = glorot_uniform(rng)),
                        Dense(128, N * na; initW = glorot_uniform(rng)),
                    ) |> cpu,
                ),
                stack_size = nothing,
                batch_size = 32,
                update_horizon = 1,
                min_replay_history = 100,
                update_freq = 1,
                target_update_freq = 100,
                n_quantile = N,
            ),
            explorer = EpsilonGreedyExplorer(
                kind = :exp,
                ϵ_stable = 0.01,
                decay_steps = 500,
                rng = rng,
            ),
        ),
        trajectory = CircularArraySARTTrajectory(
            capacity = 1000,
            state = Vector{Float32} => (ns,),
        ),
    )

    stop_condition = StopAfterStep(10_000)

    hook = ComposedHook(TotalRewardPerEpisode())

    description = """
    This experiment uses the `QRDQNLearner` method with three dense layers to approximate the quantile values.
    The testing environment is CartPoleEnv.
    """

    Experiment(agent, env, stop_condition, hook, description)
end
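For reference, a sketch of launching this experiment, assuming the package exports above are available and that an `Experiment` can be run directly with `run` (as the test/runtests.jl change below does):

using ReinforcementLearningZoo  # assumed entry point for RLCore, the learners, and the experiments

ex = RLCore.Experiment(Val(:JuliaRL), Val(:QRDQN), Val(:CartPole), nothing; seed = 123)
run(ex)  # trains for 10_000 steps on CartPole, recording total reward per episode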

test/runtests.jl

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ end
 @testset "training" begin
     mktempdir() do dir
-        for method in (:BasicDQN, :BC, :DQN, :PrioritizedDQN, :Rainbow, :REMDQN, :IQN, :VPG)
+        for method in (:BasicDQN, :BC, :DQN, :PrioritizedDQN, :Rainbow, :QRDQN, :REMDQN, :IQN, :VPG)
             res = run(
                 Experiment(
                     Val(:JuliaRL),
