|
| 1 | +export QRDQNLearner, quantile_huber_loss |
| 2 | + |
"""
    quantile_huber_loss(ŷ, y; κ=1.0f0)

Asymmetric (quantile-weighted) Huber loss used by QR-DQN.

`ŷ` and `y` are both `(n_quantile, batch_size)` matrices of predicted and
target quantile values. Returns a scalar: the quantile regression loss of
Dabney et al. (2017), averaged over target samples and batch.
"""
function quantile_huber_loss(ŷ, y; κ=1.0f0)
    N, B = size(y)
    # Pairwise TD errors: Δ[j, i, b] = y[j, b] - ŷ[i, b].
    # dim 1 = target-sample index j, dim 2 = predicted-quantile index i.
    Δ = reshape(y, N, 1, B) .- reshape(ŷ, 1, N, B)
    abs_error = abs.(Δ)
    quadratic = min.(abs_error, κ)
    linear = abs_error .- quadratic
    huber_loss = 0.5f0 .* quadratic .* quadratic .+ κ .* linear

    # Quantile midpoints τᵢ = (2i - 1) / 2N, i = 1..N.
    cum_prob = send_to_device(device(y), range(0.5f0 / N; length=N, step=1.0f0 / N))
    # τ belongs to the *predicted* quantile, i.e. it must vary along dim 2.
    # The original broadcast the length-N vector along dim 1 (the target-sample
    # axis), pairing τⱼ with ŷᵢ — the wrong axis per Eq. 10 of the QR-DQN paper.
    τ = reshape(cum_prob, 1, N, 1)
    # The indicator weight carries no gradient; only huber_loss is differentiated.
    loss = Zygote.dropgrad(abs.(τ .- (Δ .< 0))) .* huber_loss
    # sum over dim 1 then mean over the rest ⇒ total / (N * B), the paper's
    # normalization (mean over target samples and batch, sum over quantiles).
    mean(sum(loss; dims=1))
end
| 15 | + |
# Learner for QR-DQN (Dabney et al., 2017). Instead of a single Q-value per
# action, `approximator` outputs `n_quantile` quantile values per action;
# action values are recovered as the mean over quantiles.
mutable struct QRDQNLearner{Tq <: AbstractApproximator,Tt <: AbstractApproximator,Tf,R} <: AbstractLearner
    approximator::Tq          # online network; trained every `update_freq` steps
    target_approximator::Tt   # frozen copy used to build Bellman targets
    min_replay_history::Int   # transitions to observe before any update
    update_freq::Int          # steps between gradient updates
    update_step::Int          # running counter of environment steps seen
    target_update_freq::Int   # steps between target-network syncs
    sampler::NStepBatchSampler
    n_quantile::Int           # N: number of quantiles per action
    loss_func::Tf             # e.g. `quantile_huber_loss`
    rng::R
    # for recording
    loss::Float32             # last training loss, stored for logging only
end
| 30 | + |
"""
    QRDQNLearner(;kwargs...)

See paper: [Distributional Reinforcement Learning with Quantile Regression](https://arxiv.org/pdf/1710.10044.pdf)

# Keywords

- `approximator`::[`AbstractApproximator`](@ref): used to get quantile-values of a batch of states. The output should be of size `(n_quantile, n_action)`.
- `target_approximator`::[`AbstractApproximator`](@ref): similar to `approximator`, but used to estimate the quantile values of the next state batch.
- `γ::Float32=0.99f0`: discount rate.
- `batch_size::Int=32`
- `update_horizon::Int=1`: length of update ('n' in n-step update).
- `min_replay_history::Int=32`: number of transitions that should be experienced before updating the `approximator`.
- `update_freq::Int=1`: the frequency of updating the `approximator`.
- `n_quantile::Int=1`: the number of quantiles.
- `target_update_freq::Int=100`: the frequency of syncing `target_approximator`.
- `stack_size::Union{Int, Nothing}=nothing`: use the recent `stack_size` frames to form a stacked state.
- `traces = SARTS`, set to `SLARTSL` if you are to apply to an environment of `FULL_ACTION_SET`.
- `loss_func`=[`quantile_huber_loss`](@ref).
"""
function QRDQNLearner(;
    approximator,
    target_approximator,
    stack_size::Union{Int,Nothing}=nothing,
    γ::Float32=0.99f0,
    batch_size::Int=32,
    update_horizon::Int=1,
    min_replay_history::Int=32,
    update_freq::Int=1,
    n_quantile::Int=1,
    target_update_freq::Int=100,
    traces=SARTS,
    update_step=0,
    loss_func=quantile_huber_loss,
    rng=Random.GLOBAL_RNG
)
    # Force the two networks to start from identical parameters.
    copyto!(approximator, target_approximator)
    sampler = NStepBatchSampler{traces}(;
        γ=γ,
        n=update_horizon,
        stack_size=stack_size,
        batch_size=batch_size,
    )

    N = n_quantile

    QRDQNLearner(
        approximator,
        target_approximator,
        min_replay_history,
        update_freq,
        update_step,
        target_update_freq,
        sampler,
        N,
        loss_func,
        rng,
        0.0f0, # initial recorded loss
    )
end
| 91 | + |
# Expose only the two networks as trainable children; everything else
# (counters, sampler, rng, recorded loss) is carried over unchanged on
# reconstruction.
function Flux.functor(x::QRDQNLearner)
    children = (Q = x.approximator, Qₜ = x.target_approximator)
    reconstruct = ys -> begin
        learner = @set x.approximator = ys.Q
        @set learner.target_approximator = ys.Qₜ
    end
    children, reconstruct
end
| 98 | + |
# Act on a single environment state: add a trailing batch dimension of 1,
# evaluate the network, and collapse the `n_quantile` quantiles per action
# into their mean, yielding one estimated Q-value per action (on the host).
function (learner::QRDQNLearner)(env)
    net = learner.approximator
    observation = send_to_device(device(net), state(env))
    batched = Flux.unsqueeze(observation, ndims(observation) + 1)
    quantiles = reshape(net(batched), learner.n_quantile, :)
    send_to_host(vec(mean(quantiles; dims=1)))
end
| 105 | + |
"""
    RLBase.update!(learner::QRDQNLearner, batch::NamedTuple)

Perform one gradient step on `learner.approximator` from a sampled `batch`
of SARTS traces, using the distributional (quantile) n-step Bellman target.
"""
function RLBase.update!(learner::QRDQNLearner, batch::NamedTuple)
    Q = learner.approximator
    Qₜ = learner.target_approximator
    γ = learner.sampler.γ
    n = learner.sampler.n
    batch_size = learner.sampler.batch_size
    N = learner.n_quantile
    D = device(Q)
    loss_func = learner.loss_func

    s, a, r, t, s′ = (send_to_device(D, batch[x]) for x in SARTS)
    a = CartesianIndex.(a, 1:batch_size)

    # Greedy next action under the target network: an action's Q-value is the
    # mean of its quantile values.
    target_quantiles = reshape(Qₜ(s′), N, :, batch_size)
    qₜ = dropdims(mean(target_quantiles; dims=1); dims=1)
    aₜ = dropdims(argmax(qₜ, dims=1); dims=1)
    @views target_quantile_aₜ = target_quantiles[:, aₜ]

    # n-step distributional Bellman target. The sampled `r` already
    # accumulates the discounted rewards of the n intermediate steps, so the
    # bootstrap term must be discounted by γ^n. (The original used plain γ,
    # which is only correct when update_horizon == 1.)
    y = reshape(r, 1, batch_size) .+
        γ^n .* reshape(1 .- t, 1, batch_size) .* target_quantile_aₜ

    gs = gradient(params(Q)) do
        q = reshape(Q(s), N, :, batch_size)
        @views ŷ = q[:, a]

        loss = loss_func(ŷ, y)

        ignore() do
            # Record for logging only; ignore() keeps this assignment out of
            # the differentiated computation.
            learner.loss = loss
        end
        loss
    end

    update!(Q, gs)
end
0 commit comments