diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/multi_thread_env.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/multi_thread_env.jl
index 11d175833..649242f44 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/multi_thread_env.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/multi_thread_env.jl
@@ -115,7 +115,16 @@ const MULTI_THREAD_ENV_CACHE = IdDict{AbstractEnv,Dict{Symbol,Array}}()
 function RLBase.state(env::MultiThreadEnv)
     N = ndims(env.states)
     @sync for i in 1:length(env)
-        @spawn selectdim(env.states, N, i) .= state(env[i])
+        # Each task fills only slot `i` of `env.states`; `@sync` waits for all tasks.
+        @spawn begin
+            if N == 1
+                # Rank-1 buffer: slot `i` is the single element at index `i`.
+                env.states[i] .= state(env[i])
+            else
+                # Higher-rank buffer: slot `i` is the slice along the last dimension.
+                selectdim(env.states, N, i) .= state(env[i])
+            end
+        end
     end
     env.states
 end
@@ -167,7 +176,7 @@ function RLBase.plan!(π::QBasedPolicy, env::MultiThreadEnv, ::FullActionSet, A)
     ]
 end
 
-function RLBase.plan!(π::QBasedPolicy, 
+function RLBase.plan!(π::QBasedPolicy,
     env::MultiThreadEnv,
     ::MinimalActionSet,
     ::Space{<:Vector{<:Base.OneTo{<:Integer}}},